summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--continuedev/src/continuedev/core/policy.py18
-rw-r--r--continuedev/src/continuedev/libs/util/count_tokens.py4
-rw-r--r--continuedev/src/continuedev/libs/util/templating.py10
-rw-r--r--continuedev/src/continuedev/steps/custom_command.py8
4 files changed, 16 insertions, 24 deletions
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
index d007c92b..1000f0f4 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/core/policy.py
@@ -1,22 +1,12 @@
from textwrap import dedent
-from typing import List, Tuple, Type, Union
+from typing import Union
+from ..steps.chat import SimpleChatStep
from ..steps.welcome import WelcomeStep
from .config import ContinueConfig
-from ..steps.chroma import AnswerQuestionChroma, EditFileChroma, CreateCodebaseIndexChroma
from ..steps.steps_on_startup import StepsOnStartupStep
-from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe
-from ..recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe
-from ..recipes.AddTransformRecipe.main import AddTransformRecipe
-from .main import Step, Validator, History, Policy
-from .observation import Observation, TracebackObservation, UserInputObservation
-from ..steps.main import EditHighlightedCodeStep, SolveTracebackStep
-from ..recipes.WritePytestsRecipe.main import WritePytestsRecipe
-from ..recipes.ContinueRecipeRecipe.main import ContinueStepStep
-from ..steps.comment_code import CommentCodeStep
-from ..steps.react import NLDecisionStep
-from ..steps.chat import SimpleChatStep, ChatWithFunctions, EditFileStep, AddFileStep
-from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe
+from .main import Step, History, Policy
+from .observation import UserInputObservation
from ..steps.core.core import MessageStep
from ..libs.util.step_name_to_steps import get_step_from_name
from ..steps.custom_command import CustomCommandStep
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index c81d8aa4..987aa722 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -1,7 +1,7 @@
import json
from typing import Dict, List, Union
from ...core.main import ChatMessage
-from .templating import render_system_message
+from .templating import render_templated_string
import tiktoken
aliases = {
@@ -112,7 +112,7 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int,
if system_message is not None:
# NOTE: System message takes second precedence to user prompt, so it is placed just before
# but move back to start after processing
- rendered_system_message = render_system_message(system_message)
+ rendered_system_message = render_templated_string(system_message)
system_chat_msg = ChatMessage(
role="system", content=rendered_system_message, summary=rendered_system_message)
# insert at second-to-last position
diff --git a/continuedev/src/continuedev/libs/util/templating.py b/continuedev/src/continuedev/libs/util/templating.py
index ebfc2e31..bb922ad7 100644
--- a/continuedev/src/continuedev/libs/util/templating.py
+++ b/continuedev/src/continuedev/libs/util/templating.py
@@ -16,19 +16,19 @@ def escape_var(var: str) -> str:
return var.replace(os.path.sep, '').replace('.', '')
-def render_system_message(system_message: str) -> str:
+def render_templated_string(template: str) -> str:
"""
- Render system message with mustache syntax.
+ Render system message or other templated string with mustache syntax.
Right now it only supports rendering absolute file paths as their contents.
"""
- vars = get_vars_in_template(system_message)
+ vars = get_vars_in_template(template)
args = {}
for var in vars:
if var.startswith(os.path.sep):
# Escape vars which are filenames, because mustache doesn't allow / in variable names
escaped_var = escape_var(var)
- system_message = system_message.replace(
+ template = template.replace(
var, escaped_var)
if os.path.exists(var):
@@ -36,4 +36,4 @@ def render_system_message(system_message: str) -> str:
else:
args[escaped_var] = ''
- return chevron.render(system_message, args)
+ return chevron.render(template, args)
diff --git a/continuedev/src/continuedev/steps/custom_command.py b/continuedev/src/continuedev/steps/custom_command.py
index 5a56efb0..d96ac8e2 100644
--- a/continuedev/src/continuedev/steps/custom_command.py
+++ b/continuedev/src/continuedev/steps/custom_command.py
@@ -1,7 +1,7 @@
+from ..libs.util.templating import render_templated_string
from ..core.main import Step
from ..core.sdk import ContinueSDK
-from ..steps.core.core import UserInputStep
-from ..steps.chat import ChatWithFunctions, SimpleChatStep
+from ..steps.chat import SimpleChatStep
class CustomCommandStep(Step):
@@ -15,7 +15,9 @@ class CustomCommandStep(Step):
return self.prompt
async def run(self, sdk: ContinueSDK):
- prompt_user_input = f"Task: {self.prompt}. Additional info: {self.user_input}"
+ task = render_templated_string(self.prompt)
+
+ prompt_user_input = f"Task: {task}. Additional info: {self.user_input}"
messages = await sdk.get_chat_context()
# Find the last chat message with this slash command and replace it with the user input
for i in range(len(messages) - 1, -1, -1):