summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/core/autopilot.py4
-rw-r--r--continuedev/src/continuedev/core/config.py94
-rw-r--r--continuedev/src/continuedev/core/policy.py10
-rw-r--r--continuedev/src/continuedev/core/sdk.py33
-rw-r--r--continuedev/src/continuedev/libs/util/step_name_to_steps.py43
-rw-r--r--continuedev/src/continuedev/steps/custom_command.py4
-rw-r--r--continuedev/src/continuedev/steps/steps_on_startup.py5
7 files changed, 73 insertions, 120 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index afbfc7ed..abda50b0 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -16,7 +16,6 @@ from .main import Context, ContinueCustomException, HighlightedRangeContext, Pol
from ..steps.core.core import ReversibleStep, ManualEditStep, UserInputStep
from ..libs.util.telemetry import capture_event
from .sdk import ContinueSDK
-from ..libs.util.step_name_to_steps import get_step_from_name
from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback
from openai import error as openai_errors
from ..libs.util.create_async_task import create_async_task
@@ -157,8 +156,7 @@ class Autopilot(ContinueBaseModel):
traceback = get_tb_func(output)
if traceback is not None:
for tb_step in self.continue_sdk.config.on_traceback:
- step = get_step_from_name(
- tb_step.step_name, {"output": output, **tb_step.params})
+ step = tb_step.step({"output": output, **tb_step.params})
await self._run_singular_step(step)
_highlighted_ranges: List[HighlightedRangeContext] = []
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 70c4876e..54f15143 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -1,14 +1,15 @@
import json
import os
+from .main import Step
from pydantic import BaseModel, validator
-from typing import List, Literal, Optional, Dict
+from typing import List, Literal, Optional, Dict, Type, Union
import yaml
class SlashCommand(BaseModel):
name: str
description: str
- step_name: str
+ step: Type[Step]
params: Optional[Dict] = {}
@@ -19,54 +20,10 @@ class CustomCommand(BaseModel):
class OnTracebackSteps(BaseModel):
- step_name: str
+ step: Type[Step]
params: Optional[Dict] = {}
-DEFAULT_SLASH_COMMANDS = [
- # SlashCommand(
- # name="pytest",
- # description="Write pytest unit tests for the current file",
- # step_name="WritePytestsRecipe",
- # params=??)
- SlashCommand(
- name="edit",
- description="Edit code in the current file or the highlighted code",
- step_name="EditHighlightedCodeStep",
- ),
- # SlashCommand(
- # name="explain",
- # description="Reply to instructions or a question with previous steps and the highlighted code or current file as context",
- # step_name="SimpleChatStep",
- # ),
- SlashCommand(
- name="config",
- description="Open the config file to create new and edit existing slash commands",
- step_name="OpenConfigStep",
- ),
- SlashCommand(
- name="help",
- description="Ask a question like '/help what is given to the llm as context?'",
- step_name="HelpStep",
- ),
- SlashCommand(
- name="comment",
- description="Write comments for the current file or highlighted code",
- step_name="CommentCodeStep",
- ),
- SlashCommand(
- name="feedback",
- description="Send feedback to improve Continue",
- step_name="FeedbackStep",
- ),
- SlashCommand(
- name="clear",
- description="Clear step history",
- step_name="ClearHistoryStep",
- )
-]
-
-
class AzureInfo(BaseModel):
endpoint: str
engine: str
@@ -77,7 +34,7 @@ class ContinueConfig(BaseModel):
"""
A pydantic class for the continue config file.
"""
- steps_on_startup: Optional[Dict[str, Dict]] = {}
+ steps_on_startup: List[Step] = []
disallowed_steps: Optional[List[str]] = []
allow_anonymous_telemetry: Optional[bool] = True
default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
@@ -88,16 +45,49 @@ class ContinueConfig(BaseModel):
description="This is an example custom command. Use /config to edit it and create more",
prompt="Write a comprehensive set of unit tests for the selected code. It should setup, run tests that check for correctness including important edge cases, and teardown. Ensure that the tests are complete and sophisticated. Give the tests just as chat output, don't edit any file.",
)]
- slash_commands: Optional[List[SlashCommand]] = DEFAULT_SLASH_COMMANDS
- on_traceback: Optional[List[OnTracebackSteps]] = [
- OnTracebackSteps(step_name="DefaultOnTracebackStep")]
+ slash_commands: Optional[List[SlashCommand]] = []
+ on_traceback: Optional[List[OnTracebackSteps]] = []
system_message: Optional[str] = None
azure_openai_info: Optional[AzureInfo] = None
# Want to force these to be the slash commands for now
@validator('slash_commands', pre=True)
def default_slash_commands_validator(cls, v):
- return DEFAULT_SLASH_COMMANDS
+ from ..steps.open_config import OpenConfigStep
+ from ..steps.clear_history import ClearHistoryStep
+ from ..steps.feedback import FeedbackStep
+ from ..steps.comment_code import CommentCodeStep
+ from ..steps.main import EditHighlightedCodeStep
+
+ DEFAULT_SLASH_COMMANDS = [
+ SlashCommand(
+ name="edit",
+ description="Edit code in the current file or the highlighted code",
+ step=EditHighlightedCodeStep,
+ ),
+ SlashCommand(
+ name="config",
+ description="Open the config file to create new and edit existing slash commands",
+ step=OpenConfigStep,
+ ),
+ SlashCommand(
+ name="comment",
+ description="Write comments for the current file or highlighted code",
+ step=CommentCodeStep,
+ ),
+ SlashCommand(
+ name="feedback",
+ description="Send feedback to improve Continue",
+ step=FeedbackStep,
+ ),
+ SlashCommand(
+ name="clear",
+ description="Clear step history",
+ step=ClearHistoryStep,
+ )
+ ]
+
+ return DEFAULT_SLASH_COMMANDS + v
@validator('temperature', pre=True)
def temperature_validator(cls, v):
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
index 1000f0f4..53e482fa 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/core/policy.py
@@ -8,7 +8,6 @@ from ..steps.steps_on_startup import StepsOnStartupStep
from .main import Step, History, Policy
from .observation import UserInputObservation
from ..steps.core.core import MessageStep
-from ..libs.util.step_name_to_steps import get_step_from_name
from ..steps.custom_command import CustomCommandStep
@@ -24,7 +23,11 @@ def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]:
if slash_command.name == command_name[1:]:
params = slash_command.params
params["user_input"] = after_command
- return get_step_from_name(slash_command.step_name, params)
+ try:
+ return slash_command.step(**params)
+ except TypeError as e:
+ raise Exception(
+ f"Incorrect params used for slash command '{command_name}': {e}")
return None
@@ -69,6 +72,9 @@ class DemoPolicy(Policy):
if custom_command is not None:
return custom_command
+ if user_input.startswith("/edit"):
+ return EditHighlightedCodeStep(user_input=user_input[5:])
+
return SimpleChatStep()
return None
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 37a51efa..4100efa6 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -144,20 +144,20 @@ class ContinueSDK(AbstractContinueSDK):
ide: AbstractIdeProtocolServer
models: Models
context: Context
+ config: ContinueConfig
__autopilot: Autopilot
def __init__(self, autopilot: Autopilot):
self.ide = autopilot.ide
self.__autopilot = autopilot
self.context = autopilot.context
- self.config = self._load_config()
@classmethod
async def create(cls, autopilot: Autopilot) -> "ContinueSDK":
sdk = ContinueSDK(autopilot)
try:
- config = sdk._load_config()
+ config = sdk._load_config_dot_py()
sdk.config = config
except Exception as e:
print(e)
@@ -175,19 +175,6 @@ class ContinueSDK(AbstractContinueSDK):
sdk.models = await Models.create(sdk)
return sdk
- config: ContinueConfig
-
- def _load_config(self) -> ContinueConfig:
- dir = self.ide.workspace_directory
- yaml_path = os.path.join(dir, '.continue', 'config.yaml')
- json_path = os.path.join(dir, '.continue', 'config.json')
- if os.path.exists(yaml_path):
- return load_config(yaml_path)
- elif os.path.exists(json_path):
- return load_config(json_path)
- else:
- return load_global_config()
-
@property
def history(self) -> History:
return self.__autopilot.history
@@ -267,6 +254,22 @@ class ContinueSDK(AbstractContinueSDK):
async def get_user_secret(self, env_var: str, prompt: str) -> str:
return await self.ide.getUserSecret(env_var)
+ _last_valid_config: ContinueConfig = None
+
+ def _load_config_dot_py(self) -> ContinueConfig:
+ # Use importlib to load the config file config.py at the given path
+ path = os.path.join(os.path.expanduser("~"), ".continue", "config.py")
+ try:
+ import importlib.util
+ spec = importlib.util.spec_from_file_location("config", path)
+ config = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(config)
+ self._last_valid_config = config.config
+ return config.config
+ except Exception as e:
+ print("Error loading config.py: ", e)
+ return ContinueConfig() if self._last_valid_config is None else self._last_valid_config
+
def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]:
context = list(filter(lambda x: x.editing, self.__autopilot._highlighted_ranges)
) if only_editing else self.__autopilot._highlighted_ranges
diff --git a/continuedev/src/continuedev/libs/util/step_name_to_steps.py b/continuedev/src/continuedev/libs/util/step_name_to_steps.py
deleted file mode 100644
index 49056c81..00000000
--- a/continuedev/src/continuedev/libs/util/step_name_to_steps.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from typing import Dict
-
-from ...core.main import Step
-from ...steps.core.core import UserInputStep
-from ...steps.main import EditHighlightedCodeStep
-from ...steps.chat import SimpleChatStep
-from ...steps.comment_code import CommentCodeStep
-from ...steps.feedback import FeedbackStep
-from ...recipes.AddTransformRecipe.main import AddTransformRecipe
-from ...recipes.CreatePipelineRecipe.main import CreatePipelineRecipe
-from ...recipes.DDtoBQRecipe.main import DDtoBQRecipe
-from ...recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe
-from ...steps.on_traceback import DefaultOnTracebackStep
-from ...steps.clear_history import ClearHistoryStep
-from ...steps.open_config import OpenConfigStep
-from ...steps.help import HelpStep
-
-# This mapping is used to convert from string in ContinueConfig json to corresponding Step class.
-# Used for example in slash_commands and steps_on_startup
-step_name_to_step_class = {
- "UserInputStep": UserInputStep,
- "EditHighlightedCodeStep": EditHighlightedCodeStep,
- "SimpleChatStep": SimpleChatStep,
- "CommentCodeStep": CommentCodeStep,
- "FeedbackStep": FeedbackStep,
- "AddTransformRecipe": AddTransformRecipe,
- "CreatePipelineRecipe": CreatePipelineRecipe,
- "DDtoBQRecipe": DDtoBQRecipe,
- "DeployPipelineAirflowRecipe": DeployPipelineAirflowRecipe,
- "DefaultOnTracebackStep": DefaultOnTracebackStep,
- "ClearHistoryStep": ClearHistoryStep,
- "OpenConfigStep": OpenConfigStep,
- "HelpStep": HelpStep,
-}
-
-
-def get_step_from_name(step_name: str, params: Dict) -> Step:
- try:
- return step_name_to_step_class[step_name](**params)
- except:
- print(
- f"Incorrect parameters for step {step_name}. Parameters provided were: {params}")
- raise
diff --git a/continuedev/src/continuedev/steps/custom_command.py b/continuedev/src/continuedev/steps/custom_command.py
index d96ac8e2..375900c1 100644
--- a/continuedev/src/continuedev/steps/custom_command.py
+++ b/continuedev/src/continuedev/steps/custom_command.py
@@ -1,5 +1,5 @@
from ..libs.util.templating import render_templated_string
-from ..core.main import Step
+from ..core.main import Models, Step
from ..core.sdk import ContinueSDK
from ..steps.chat import SimpleChatStep
@@ -11,7 +11,7 @@ class CustomCommandStep(Step):
slash_command: str
hide: bool = True
- async def describe(self):
+ async def describe(self, models: Models):
return self.prompt
async def run(self, sdk: ContinueSDK):
diff --git a/continuedev/src/continuedev/steps/steps_on_startup.py b/continuedev/src/continuedev/steps/steps_on_startup.py
index 365cbe1a..318c28df 100644
--- a/continuedev/src/continuedev/steps/steps_on_startup.py
+++ b/continuedev/src/continuedev/steps/steps_on_startup.py
@@ -6,7 +6,6 @@ from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe
from ..recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe
from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe
from ..recipes.AddTransformRecipe.main import AddTransformRecipe
-from ..libs.util.step_name_to_steps import get_step_from_name
class StepsOnStartupStep(Step):
@@ -18,6 +17,6 @@ class StepsOnStartupStep(Step):
async def run(self, sdk: ContinueSDK):
steps_on_startup = sdk.config.steps_on_startup
- for step_name, step_params in steps_on_startup.items():
- step = get_step_from_name(step_name, step_params)
+ for step_type in steps_on_startup:
+ step = step_type()
await sdk.run_step(step)