diff options
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 4 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/config.py | 96 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/policy.py | 7 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 30 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/util/step_name_to_steps.py | 41 | ||||
-rw-r--r-- | continuedev/src/continuedev/steps/steps_on_startup.py | 5 |
6 files changed, 79 insertions, 104 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 8b3fb97d..118744f9 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -17,7 +17,6 @@ from ..steps.core.core import ReversibleStep, ManualEditStep, UserInputStep from ..libs.util.telemetry import capture_event from .sdk import ContinueSDK import asyncio -from ..libs.util.step_name_to_steps import get_step_from_name from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback from openai import error as openai_errors @@ -136,8 +135,7 @@ class Autopilot(ContinueBaseModel): traceback = get_tb_func(output) if traceback is not None: for tb_step in self.continue_sdk.config.on_traceback: - step = get_step_from_name( - tb_step.step_name, {"output": output, **tb_step.params}) + step = tb_step.step({"output": output, **tb_step.params}) await self._run_singular_step(step) _highlighted_ranges: List[HighlightedRangeContext] = [] diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 55f5bc60..8c7ed2fd 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -1,14 +1,15 @@ import json import os +from .main import Step from pydantic import BaseModel, validator -from typing import List, Literal, Optional, Dict +from typing import List, Literal, Optional, Dict, Type, Union import yaml class SlashCommand(BaseModel): name: str description: str - step_name: str + step: Type[Step] params: Optional[Dict] = {} @@ -19,54 +20,15 @@ class CustomCommand(BaseModel): class OnTracebackSteps(BaseModel): - step_name: str + step: Type[Step] params: Optional[Dict] = {} -DEFAULT_SLASH_COMMANDS = [ - # SlashCommand( - # name="pytest", - # description="Write pytest unit tests for the current file", - # step_name="WritePytestsRecipe", - # params=??) - SlashCommand( - name="edit", - description="Edit code in the current file or the highlighted code", - step_name="EditHighlightedCodeStep", - ), - # SlashCommand( - # name="explain", - # description="Reply to instructions or a question with previous steps and the highlighted code or current file as context", - # step_name="SimpleChatStep", - # ), - SlashCommand( - name="config", - description="Open the config file to create new and edit existing slash commands", - step_name="OpenConfigStep", - ), - SlashCommand( - name="comment", - description="Write comments for the current file or highlighted code", - step_name="CommentCodeStep", - ), - SlashCommand( - name="feedback", - description="Send feedback to improve Continue", - step_name="FeedbackStep", - ), - SlashCommand( - name="clear", - description="Clear step history", - step_name="ClearHistoryStep", - ) -] - - class ContinueConfig(BaseModel): """ A pydantic class for the continue config file. """ - steps_on_startup: Optional[Dict[str, Dict]] = {} + steps_on_startup: List[Step] = [] disallowed_steps: Optional[List[str]] = [] server_url: Optional[str] = None allow_anonymous_telemetry: Optional[bool] = True @@ -77,14 +39,54 @@ class ContinueConfig(BaseModel): description="This is an example custom command. Use /config to edit it and create more", prompt="Write a comprehensive set of unit tests for the selected code. It should setup, run tests that check for correctness including important edge cases, and teardown. Ensure that the tests are complete and sophisticated. Give the tests just as chat output, don't edit any file.", )] - slash_commands: Optional[List[SlashCommand]] = DEFAULT_SLASH_COMMANDS - on_traceback: Optional[List[OnTracebackSteps]] = [ - OnTracebackSteps(step_name="DefaultOnTracebackStep")] + slash_commands: Optional[List[SlashCommand]] = [] + on_traceback: Optional[List[OnTracebackSteps]] = [] # Want to force these to be the slash commands for now @validator('slash_commands', pre=True) def default_slash_commands_validator(cls, v): - return DEFAULT_SLASH_COMMANDS + from ..steps.core.core import UserInputStep + from ..steps.open_config import OpenConfigStep + from ..steps.clear_history import ClearHistoryStep + from ..steps.on_traceback import DefaultOnTracebackStep + from ..recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe + from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe + from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe + from ..recipes.AddTransformRecipe.main import AddTransformRecipe + from ..steps.feedback import FeedbackStep + from ..steps.comment_code import CommentCodeStep + from ..steps.chat import SimpleChatStep + from ..steps.main import EditHighlightedCodeStep + + DEFAULT_SLASH_COMMANDS = [ + SlashCommand( + name="edit", + description="Edit code in the current file or the highlighted code", + step=EditHighlightedCodeStep, + ), + SlashCommand( + name="config", + description="Open the config file to create new and edit existing slash commands", + step=OpenConfigStep, + ), + SlashCommand( + name="comment", + description="Write comments for the current file or highlighted code", + step=CommentCodeStep, + ), + SlashCommand( + name="feedback", + description="Send feedback to improve Continue", + step=FeedbackStep, + ), + SlashCommand( + name="clear", + description="Clear step history", + step=ClearHistoryStep, + ) + ] + + return DEFAULT_SLASH_COMMANDS + v def load_config(config_file: str) -> ContinueConfig: diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index b8363df2..93946cb6 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -18,7 +18,6 @@ from ..steps.react import NLDecisionStep from ..steps.chat import SimpleChatStep, ChatWithFunctions, EditFileStep, AddFileStep from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe from ..steps.core.core import MessageStep -from ..libs.util.step_name_to_steps import get_step_from_name from ..steps.custom_command import CustomCommandStep @@ -34,7 +33,11 @@ def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]: if slash_command.name == command_name[1:]: params = slash_command.params params["user_input"] = after_command - return get_step_from_name(slash_command.step_name, params) + try: + return slash_command.step(**params) + except TypeError as e: + raise Exception( + f"Incorrect params used for slash command '{command_name}': {e}") return None diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index ed670799..f67198a8 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -166,17 +166,31 @@ class ContinueSDK(AbstractContinueSDK): async def get_user_secret(self, env_var: str, prompt: str) -> str: return await self.ide.getUserSecret(env_var) + @staticmethod + def load_config_dot_py(path: str) -> ContinueConfig: + # Use importlib to load the config file config.py at the given path + import importlib.util + spec = importlib.util.spec_from_file_location("config", path) + config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config) + return config.config + @property def config(self) -> ContinueConfig: + # TODO: Workspace config files should override global dir = self.ide.workspace_directory - yaml_path = os.path.join(dir, '.continue', 'config.yaml') - json_path = os.path.join(dir, '.continue', 'config.json') - if os.path.exists(yaml_path): - return load_config(yaml_path) - elif os.path.exists(json_path): - return load_config(json_path) - else: - return load_global_config() + path = os.path.join(dir, '.continue', 'config.py') + if not os.path.exists(path): + global_dir = os.path.expanduser('~/.continue') + if not os.path.exists(global_dir): + os.mkdir(global_dir) + path = os.path.join(global_dir, 'config.py') + if not os.path.exists(path): + # Need to copy over the default config + return ContinueConfig() + + config = ContinueSDK.load_config_dot_py(path) + return config def update_default_model(self, model: str): config = self.config diff --git a/continuedev/src/continuedev/libs/util/step_name_to_steps.py b/continuedev/src/continuedev/libs/util/step_name_to_steps.py deleted file mode 100644 index d329e110..00000000 --- a/continuedev/src/continuedev/libs/util/step_name_to_steps.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Dict - -from ...core.main import Step -from ...steps.core.core import UserInputStep -from ...steps.main import EditHighlightedCodeStep -from ...steps.chat import SimpleChatStep -from ...steps.comment_code import CommentCodeStep -from ...steps.feedback import FeedbackStep -from ...recipes.AddTransformRecipe.main import AddTransformRecipe -from ...recipes.CreatePipelineRecipe.main import CreatePipelineRecipe -from ...recipes.DDtoBQRecipe.main import DDtoBQRecipe -from ...recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe -from ...steps.on_traceback import DefaultOnTracebackStep -from ...steps.clear_history import ClearHistoryStep -from ...steps.open_config import OpenConfigStep - -# This mapping is used to convert from string in ContinueConfig json to corresponding Step class. -# Used for example in slash_commands and steps_on_startup -step_name_to_step_class = { - "UserInputStep": UserInputStep, - "EditHighlightedCodeStep": EditHighlightedCodeStep, - "SimpleChatStep": SimpleChatStep, - "CommentCodeStep": CommentCodeStep, - "FeedbackStep": FeedbackStep, - "AddTransformRecipe": AddTransformRecipe, - "CreatePipelineRecipe": CreatePipelineRecipe, - "DDtoBQRecipe": DDtoBQRecipe, - "DeployPipelineAirflowRecipe": DeployPipelineAirflowRecipe, - "DefaultOnTracebackStep": DefaultOnTracebackStep, - "ClearHistoryStep": ClearHistoryStep, - "OpenConfigStep": OpenConfigStep -} - - -def get_step_from_name(step_name: str, params: Dict) -> Step: - try: - return step_name_to_step_class[step_name](**params) - except: - print( - f"Incorrect parameters for step {step_name}. Parameters provided were: {params}") - raise diff --git a/continuedev/src/continuedev/steps/steps_on_startup.py b/continuedev/src/continuedev/steps/steps_on_startup.py index 365cbe1a..318c28df 100644 --- a/continuedev/src/continuedev/steps/steps_on_startup.py +++ b/continuedev/src/continuedev/steps/steps_on_startup.py @@ -6,7 +6,6 @@ from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe from ..recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe from ..recipes.AddTransformRecipe.main import AddTransformRecipe -from ..libs.util.step_name_to_steps import get_step_from_name class StepsOnStartupStep(Step): @@ -18,6 +17,6 @@ class StepsOnStartupStep(Step): async def run(self, sdk: ContinueSDK): steps_on_startup = sdk.config.steps_on_startup - for step_name, step_params in steps_on_startup.items(): - step = get_step_from_name(step_name, step_params) + for step_type in steps_on_startup: + step = step_type() await sdk.run_step(step) |