diff options
Diffstat (limited to 'continuedev/src')
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 3 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/config.py | 94 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/policy.py | 10 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 33 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/util/step_name_to_steps.py | 43 | ||||
-rw-r--r-- | continuedev/src/continuedev/steps/steps_on_startup.py | 5 |
6 files changed, 71 insertions, 117 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index afbfc7ed..1f3e6323 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -157,8 +157,7 @@ class Autopilot(ContinueBaseModel): traceback = get_tb_func(output) if traceback is not None: for tb_step in self.continue_sdk.config.on_traceback: - step = get_step_from_name( - tb_step.step_name, {"output": output, **tb_step.params}) + step = tb_step.step({"output": output, **tb_step.params}) await self._run_singular_step(step) _highlighted_ranges: List[HighlightedRangeContext] = [] diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 70c4876e..54f15143 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -1,14 +1,15 @@ import json import os +from .main import Step from pydantic import BaseModel, validator -from typing import List, Literal, Optional, Dict +from typing import List, Literal, Optional, Dict, Type, Union import yaml class SlashCommand(BaseModel): name: str description: str - step_name: str + step: Type[Step] params: Optional[Dict] = {} @@ -19,54 +20,10 @@ class CustomCommand(BaseModel): class OnTracebackSteps(BaseModel): - step_name: str + step: Type[Step] params: Optional[Dict] = {} -DEFAULT_SLASH_COMMANDS = [ - # SlashCommand( - # name="pytest", - # description="Write pytest unit tests for the current file", - # step_name="WritePytestsRecipe", - # params=??) - SlashCommand( - name="edit", - description="Edit code in the current file or the highlighted code", - step_name="EditHighlightedCodeStep", - ), - # SlashCommand( - # name="explain", - # description="Reply to instructions or a question with previous steps and the highlighted code or current file as context", - # step_name="SimpleChatStep", - # ), - SlashCommand( - name="config", - description="Open the config file to create new and edit existing slash commands", - step_name="OpenConfigStep", - ), - SlashCommand( - name="help", - description="Ask a question like '/help what is given to the llm as context?'", - step_name="HelpStep", - ), - SlashCommand( - name="comment", - description="Write comments for the current file or highlighted code", - step_name="CommentCodeStep", - ), - SlashCommand( - name="feedback", - description="Send feedback to improve Continue", - step_name="FeedbackStep", - ), - SlashCommand( - name="clear", - description="Clear step history", - step_name="ClearHistoryStep", - ) -] - - class AzureInfo(BaseModel): endpoint: str engine: str @@ -77,7 +34,7 @@ class ContinueConfig(BaseModel): """ A pydantic class for the continue config file. """ - steps_on_startup: Optional[Dict[str, Dict]] = {} + steps_on_startup: List[Step] = [] disallowed_steps: Optional[List[str]] = [] allow_anonymous_telemetry: Optional[bool] = True default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k", @@ -88,16 +45,49 @@ class ContinueConfig(BaseModel): description="This is an example custom command. Use /config to edit it and create more", prompt="Write a comprehensive set of unit tests for the selected code. It should setup, run tests that check for correctness including important edge cases, and teardown. Ensure that the tests are complete and sophisticated. Give the tests just as chat output, don't edit any file.", )] - slash_commands: Optional[List[SlashCommand]] = DEFAULT_SLASH_COMMANDS - on_traceback: Optional[List[OnTracebackSteps]] = [ - OnTracebackSteps(step_name="DefaultOnTracebackStep")] + slash_commands: Optional[List[SlashCommand]] = [] + on_traceback: Optional[List[OnTracebackSteps]] = [] system_message: Optional[str] = None azure_openai_info: Optional[AzureInfo] = None # Want to force these to be the slash commands for now @validator('slash_commands', pre=True) def default_slash_commands_validator(cls, v): - return DEFAULT_SLASH_COMMANDS + from ..steps.open_config import OpenConfigStep + from ..steps.clear_history import ClearHistoryStep + from ..steps.feedback import FeedbackStep + from ..steps.comment_code import CommentCodeStep + from ..steps.main import EditHighlightedCodeStep + + DEFAULT_SLASH_COMMANDS = [ + SlashCommand( + name="edit", + description="Edit code in the current file or the highlighted code", + step=EditHighlightedCodeStep, + ), + SlashCommand( + name="config", + description="Open the config file to create new and edit existing slash commands", + step=OpenConfigStep, + ), + SlashCommand( + name="comment", + description="Write comments for the current file or highlighted code", + step=CommentCodeStep, + ), + SlashCommand( + name="feedback", + description="Send feedback to improve Continue", + step=FeedbackStep, + ), + SlashCommand( + name="clear", + description="Clear step history", + step=ClearHistoryStep, + ) + ] + + return DEFAULT_SLASH_COMMANDS + v @validator('temperature', pre=True) def temperature_validator(cls, v): diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index 1000f0f4..53e482fa 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -8,7 +8,6 @@ from ..steps.steps_on_startup import StepsOnStartupStep from .main import Step, History, Policy from .observation import UserInputObservation from ..steps.core.core import MessageStep -from ..libs.util.step_name_to_steps import get_step_from_name from ..steps.custom_command import CustomCommandStep @@ -24,7 +23,11 @@ def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]: if slash_command.name == command_name[1:]: params = slash_command.params params["user_input"] = after_command - return get_step_from_name(slash_command.step_name, params) + try: + return slash_command.step(**params) + except TypeError as e: + raise Exception( + f"Incorrect params used for slash command '{command_name}': {e}") return None @@ -69,6 +72,9 @@ class DemoPolicy(Policy): if custom_command is not None: return custom_command + if user_input.startswith("/edit"): + return EditHighlightedCodeStep(user_input=user_input[5:]) + return SimpleChatStep() return None diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 37a51efa..4100efa6 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -144,20 +144,20 @@ class ContinueSDK(AbstractContinueSDK): ide: AbstractIdeProtocolServer models: Models context: Context + config: ContinueConfig __autopilot: Autopilot def __init__(self, autopilot: Autopilot): self.ide = autopilot.ide self.__autopilot = autopilot self.context = autopilot.context - self.config = self._load_config() @classmethod async def create(cls, autopilot: Autopilot) -> "ContinueSDK": sdk = ContinueSDK(autopilot) try: - config = sdk._load_config() + config = sdk._load_config_dot_py() sdk.config = config except Exception as e: print(e) @@ -175,19 +175,6 @@ class ContinueSDK(AbstractContinueSDK): sdk.models = await Models.create(sdk) return sdk - config: ContinueConfig - - def _load_config(self) -> ContinueConfig: - dir = self.ide.workspace_directory - yaml_path = os.path.join(dir, '.continue', 'config.yaml') - json_path = os.path.join(dir, '.continue', 'config.json') - if os.path.exists(yaml_path): - return load_config(yaml_path) - elif os.path.exists(json_path): - return load_config(json_path) - else: - return load_global_config() - @property def history(self) -> History: return self.__autopilot.history @@ -267,6 +254,22 @@ class ContinueSDK(AbstractContinueSDK): async def get_user_secret(self, env_var: str, prompt: str) -> str: return await self.ide.getUserSecret(env_var) + _last_valid_config: ContinueConfig = None + + def _load_config_dot_py(self) -> ContinueConfig: + # Use importlib to load the config file config.py at the given path + path = os.path.join(os.path.expanduser("~"), ".continue", "config.py") + try: + import importlib.util + spec = importlib.util.spec_from_file_location("config", path) + config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config) + self._last_valid_config = config.config + return config.config + except Exception as e: + print("Error loading config.py: ", e) + return ContinueConfig() if self._last_valid_config is None else self._last_valid_config + def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]: context = list(filter(lambda x: x.editing, self.__autopilot._highlighted_ranges) ) if only_editing else self.__autopilot._highlighted_ranges diff --git a/continuedev/src/continuedev/libs/util/step_name_to_steps.py b/continuedev/src/continuedev/libs/util/step_name_to_steps.py deleted file mode 100644 index 49056c81..00000000 --- a/continuedev/src/continuedev/libs/util/step_name_to_steps.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import Dict - -from ...core.main import Step -from ...steps.core.core import UserInputStep -from ...steps.main import EditHighlightedCodeStep -from ...steps.chat import SimpleChatStep -from ...steps.comment_code import CommentCodeStep -from ...steps.feedback import FeedbackStep -from ...recipes.AddTransformRecipe.main import AddTransformRecipe -from ...recipes.CreatePipelineRecipe.main import CreatePipelineRecipe -from ...recipes.DDtoBQRecipe.main import DDtoBQRecipe -from ...recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe -from ...steps.on_traceback import DefaultOnTracebackStep -from ...steps.clear_history import ClearHistoryStep -from ...steps.open_config import OpenConfigStep -from ...steps.help import HelpStep - -# This mapping is used to convert from string in ContinueConfig json to corresponding Step class. -# Used for example in slash_commands and steps_on_startup -step_name_to_step_class = { - "UserInputStep": UserInputStep, - "EditHighlightedCodeStep": EditHighlightedCodeStep, - "SimpleChatStep": SimpleChatStep, - "CommentCodeStep": CommentCodeStep, - "FeedbackStep": FeedbackStep, - "AddTransformRecipe": AddTransformRecipe, - "CreatePipelineRecipe": CreatePipelineRecipe, - "DDtoBQRecipe": DDtoBQRecipe, - "DeployPipelineAirflowRecipe": DeployPipelineAirflowRecipe, - "DefaultOnTracebackStep": DefaultOnTracebackStep, - "ClearHistoryStep": ClearHistoryStep, - "OpenConfigStep": OpenConfigStep, - "HelpStep": HelpStep, -} - - -def get_step_from_name(step_name: str, params: Dict) -> Step: - try: - return step_name_to_step_class[step_name](**params) - except: - print( - f"Incorrect parameters for step {step_name}. Parameters provided were: {params}") - raise diff --git a/continuedev/src/continuedev/steps/steps_on_startup.py b/continuedev/src/continuedev/steps/steps_on_startup.py index 365cbe1a..318c28df 100644 --- a/continuedev/src/continuedev/steps/steps_on_startup.py +++ b/continuedev/src/continuedev/steps/steps_on_startup.py @@ -6,7 +6,6 @@ from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe from ..recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe from ..recipes.AddTransformRecipe.main import AddTransformRecipe -from ..libs.util.step_name_to_steps import get_step_from_name class StepsOnStartupStep(Step): @@ -18,6 +17,6 @@ class StepsOnStartupStep(Step): async def run(self, sdk: ContinueSDK): steps_on_startup = sdk.config.steps_on_startup - for step_name, step_params in steps_on_startup.items(): - step = get_step_from_name(step_name, step_params) + for step_type in steps_on_startup: + step = step_type() await sdk.run_step(step) |