diff options
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/src/continuedev/core/abstract_sdk.py | 4 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 9 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/config.py | 34 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/main.py | 6 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/policy.py | 73 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 6 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/util/step_name_to_steps.py | 27 | ||||
-rw-r--r-- | continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py | 2 | ||||
-rw-r--r-- | continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py | 3 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 7 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/gui_protocol.py | 6 | ||||
-rw-r--r-- | continuedev/src/continuedev/steps/main.py | 4 | ||||
-rw-r--r-- | continuedev/src/continuedev/steps/steps_on_startup.py | 24 |
13 files changed, 109 insertions, 96 deletions
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py index 3b85708d..0658f1b8 100644 --- a/continuedev/src/continuedev/core/abstract_sdk.py +++ b/continuedev/src/continuedev/core/abstract_sdk.py @@ -76,8 +76,8 @@ class AbstractContinueSDK(ABC): async def get_user_secret(self, env_var: str, prompt: str) -> str: pass - @abstractmethod - async def get_config(self) -> ContinueConfig: + @abstractproperty + def config(self) -> ContinueConfig: pass @abstractmethod diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 1642003c..0874bbc5 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -40,6 +40,9 @@ class Autopilot(ContinueBaseModel): def get_full_state(self) -> FullState: return FullState(history=self.history, active=self._active, user_input_queue=self._main_user_input_queue) + async def get_available_slash_commands(self) -> List[Dict]: + return list(map(lambda x: {"name": x.name, "description": x.description}, self.continue_sdk.config.slash_commands)) or [] + async def clear_history(self): self.history = History.from_empty() self._main_user_input_queue = [] @@ -202,7 +205,7 @@ class Autopilot(ContinueBaseModel): await self._run_singular_step(next_step, is_future_step) - if next_step := self.policy.next(self.history): + if next_step := self.policy.next(self.continue_sdk.config, self.history): is_future_step = False elif next_step := self.history.take_next_step(): is_future_step = True @@ -215,11 +218,11 @@ class Autopilot(ContinueBaseModel): await self.update_subscribers() async def run_from_observation(self, observation: Observation): - next_step = self.policy.next(self.history) + next_step = self.policy.next(self.continue_sdk.config, self.history) await self.run_from_step(next_step) async def run_policy(self): - first_step = self.policy.next(self.history) + first_step = self.policy.next(self.continue_sdk.config, self.history) await self.run_from_step(first_step) async def _request_halt(self): diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 8ed41a82..cf723984 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -1,9 +1,18 @@ import json import os -from pydantic import BaseModel +from pydantic import BaseModel, validator from typing import List, Optional, Dict import yaml +from .main import Step + + +class SlashCommand(BaseModel): + name: str + description: str + step_name: str + params: Optional[Dict] = {} + class ContinueConfig(BaseModel): """ @@ -12,6 +21,29 @@ class ContinueConfig(BaseModel): steps_on_startup: Optional[Dict[str, Dict]] = {} server_url: Optional[str] = None allow_anonymous_telemetry: Optional[bool] = True + slash_commands: Optional[List[SlashCommand]] = [ + # SlashCommand( + # name="pytest", + # description="Write pytest unit tests for the current file", + # step_name="WritePytestsRecipe", + # params=??) + + SlashCommand( + name="dlt", + description="Create a dlt pipeline", + step_name="CreatePipelineRecipe", + ), + SlashCommand( + name="ddtobq", + description="Create a dlt pipeline to load data from a data source into BigQuery", + step_name="DDtoBQRecipe", + ), + SlashCommand( + name="deployairflow", + description="Deploy a dlt pipeline to Airflow", + step_name="DeployPipelineAirflowRecipe", + ), + ] def load_config(config_file: str) -> ContinueConfig: diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 81aaaf2e..f6b26d69 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -118,11 +118,15 @@ class Models: pass +class ContinueConfig: + pass + + class Policy(ContinueBaseModel): """A rule that determines which step to take next""" # Note that history is mutable, kinda sus - def next(self, history: History = History.from_empty()) -> "Step": + def next(self, config: ContinueConfig, history: History = History.from_empty()) -> "Step": raise NotImplementedError diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index 00b5427c..37a10e36 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -1,5 +1,6 @@ from typing import List, Tuple, Type +from .config import ContinueConfig from ..steps.chroma import AnswerQuestionChroma, EditFileChroma, CreateCodebaseIndexChroma from ..steps.steps_on_startup import StepsOnStartupStep from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe @@ -15,12 +16,13 @@ from ..steps.react import NLDecisionStep from ..steps.chat import SimpleChatStep from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe from ..steps.core.core import MessageStep +from ..libs.util.step_name_to_steps import get_step_from_name class DemoPolicy(Policy): ran_code_last: bool = False - def next(self, history: History) -> Step: + def next(self, config: ContinueConfig, history: History) -> Step: # At the very start, run initial Steps spcecified in the config if history.get_current() is None: return ( @@ -33,20 +35,18 @@ class DemoPolicy(Policy): if observation is not None and isinstance(observation, UserInputObservation): # This could be defined with ObservationTypePolicy. Ergonomics not right though. user_input = observation.user_input + + if user_input.startswith("/"): + command_name = user_input.split(" ")[0] + after_command = " ".join(user_input.split(" ")[1:]) + for slash_command in config.slash_commands: + if slash_command.name == command_name[1:]: + return get_step_from_name(slash_command.step_name, slash_command.params) + if "/pytest" in user_input.lower(): return WritePytestsRecipe(instructions=user_input) - elif "/dlt" in user_input.lower() or " dlt" in user_input.lower(): - return CreatePipelineRecipe() if "/pytest" in observation.user_input.lower(): return WritePytestsRecipe(instructions=observation.user_input) - elif "/dlt" in observation.user_input.lower(): - return CreatePipelineRecipe() - elif "/ddtobq" in observation.user_input.lower(): - return DDtoBQRecipe() - elif "/airflow" in observation.user_input.lower(): - return DeployPipelineAirflowRecipe() - elif "/transform" in observation.user_input.lower(): - return AddTransformRecipe() elif "/comment" in observation.user_input.lower(): return CommentCodeStep() elif "/ask" in user_input: @@ -72,54 +72,3 @@ class DemoPolicy(Policy): return SolveTracebackStep(traceback=observation.traceback) else: return None - - -class ObservationTypePolicy(Policy): - def __init__(self, base_policy: Policy, observation_type: Type[Observation], step_type: Type[Step]): - self.observation_type = observation_type - self.step_type = step_type - self.base_policy = base_policy - - def next(self, history: History) -> Step: - observation = history.last_observation() - if observation is not None and isinstance(observation, self.observation_type): - return self.step_type(observation) - return self.base_policy.next(history) - - -class PolicyWrappedWithValidators(Policy): - """Default is to stop, unless the validator tells what to do next""" - index: int - stage: int - - def __init__(self, base_policy: Policy, pairs: List[Tuple[Validator, Type[Step]]]): - # Want to pass Type[Validator], or just the Validator? Question of where params are coming from. - self.pairs = pairs - self.index = len(pairs) - self.validating = 0 - self.base_policy = base_policy - - def next(self, history: History) -> Step: - if self.index == len(self.pairs): - self.index = 0 - return self.base_policy.next(history) - - if self.stage == 0: - # Running the validator at the current index for the first time - validator, step = self.pairs[self.index] - self.stage = 1 - return validator - elif self.stage == 1: - # Previously ran the validator at the current index, now receiving its ValidatorObservation - observation = history.last_observation() - if observation.passed: - self.stage = 0 - self.index += 1 - if self.index == len(self.pairs): - self.index = 0 - return self.base_policy.next(history) - else: - return self.pairs[self.index][0] - else: - _, step_type = self.pairs[self.index] - return step_type(observation) diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 2849b0c8..1f4cdfb2 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod import asyncio from functools import cached_property from typing import Coroutine, Union @@ -119,8 +118,9 @@ class ContinueSDK(AbstractContinueSDK): async def get_user_secret(self, env_var: str, prompt: str) -> str: return await self.ide.getUserSecret(env_var) - async def get_config(self) -> ContinueConfig: - dir = await self.ide.getWorkspaceDirectory() + @property + def config(self) -> ContinueConfig: + dir = self.ide.workspace_directory yaml_path = os.path.join(dir, '.continue', 'config.yaml') json_path = os.path.join(dir, '.continue', 'config.json') if os.path.exists(yaml_path): diff --git a/continuedev/src/continuedev/libs/util/step_name_to_steps.py b/continuedev/src/continuedev/libs/util/step_name_to_steps.py new file mode 100644 index 00000000..4023b73b --- /dev/null +++ b/continuedev/src/continuedev/libs/util/step_name_to_steps.py @@ -0,0 +1,27 @@ +from typing import Dict + +from ...core.main import Step +from ...steps.core.core import UserInputStep +from ...recipes.CreatePipelineRecipe.main import CreatePipelineRecipe +from ...recipes.DDtoBQRecipe.main import DDtoBQRecipe +from ...recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe +from ...recipes.DDtoBQRecipe.main import DDtoBQRecipe +from ...recipes.AddTransformRecipe.main import AddTransformRecipe + +step_name_to_step_class = { + "UserInputStep": UserInputStep, + "CreatePipelineRecipe": CreatePipelineRecipe, + "DDtoBQRecipe": DDtoBQRecipe, + "DeployPipelineAirflowRecipe": DeployPipelineAirflowRecipe, + "AddTransformRecipe": AddTransformRecipe, + "DDtoBQRecipe": DDtoBQRecipe +} + + +def get_step_from_name(step_name: str, params: Dict) -> Step: + try: + return step_name_to_step_class[step_name](**params) + except: + print( + f"Incorrect parameters for step {step_name}. Parameters provided were: {params}") + raise diff --git a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py index 818168ba..92bddc98 100644 --- a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py +++ b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py @@ -1,7 +1,7 @@ from textwrap import dedent -from ...core.main import Step from ...core.sdk import ContinueSDK +from ...core.main import Step from ...steps.core.core import WaitForUserInputStep from ...steps.core.core import MessageStep from .steps import SetupPipelineStep, ValidatePipelineStep, RunQueryStep diff --git a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py index e59cc51c..ea4607da 100644 --- a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py +++ b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py @@ -6,11 +6,10 @@ import time from ...models.main import Range from ...models.filesystem import RangeInFile from ...steps.core.core import MessageStep -from ...core.sdk import Models from ...core.observation import DictObservation, InternalErrorObservation from ...models.filesystem_edit import AddFile, FileEdit from ...core.main import Step -from ...core.sdk import ContinueSDK +from ...core.sdk import ContinueSDK, Models AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)" diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index e8b52004..cf046734 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -90,6 +90,12 @@ class GUIProtocolServer(AbstractGUIProtocolServer): "state": state }) + async def send_available_slash_commands(self): + commands = await self.session.autopilot.get_available_slash_commands() + await self._send_json("available_slash_commands", { + "commands": commands + }) + def on_main_input(self, input: str): # Do something with user input asyncio.create_task(self.session.autopilot.accept_user_input(input)) @@ -127,6 +133,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we protocol.websocket = websocket # Update any history that may have happened before connection + await protocol.send_available_slash_commands() await protocol.send_state_update() while AppStatus.should_exit is False: diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py index 889c6761..d9506c6f 100644 --- a/continuedev/src/continuedev/server/gui_protocol.py +++ b/continuedev/src/continuedev/server/gui_protocol.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Dict, List from abc import ABC, abstractmethod @@ -28,6 +28,10 @@ class AbstractGUIProtocolServer(ABC): """Send a state update to the client""" @abstractmethod + async def send_available_slash_commands(self, commands: List[Dict]): + """Send a list of available slash commands to the client""" + + @abstractmethod def on_retry_at_index(self, index: int): """Called when the user requests a retry at a previous index""" diff --git a/continuedev/src/continuedev/steps/main.py b/continuedev/src/continuedev/steps/main.py index 9634c726..36e4f519 100644 --- a/continuedev/src/continuedev/steps/main.py +++ b/continuedev/src/continuedev/steps/main.py @@ -63,10 +63,10 @@ class RunPolicyUntilDoneStep(Step): policy: "Policy" async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: - next_step = self.policy.next(sdk.history) + next_step = self.policy.next(sdk.config, sdk.history) while next_step is not None: observation = await sdk.run_step(next_step) - next_step = self.policy.next(sdk.history) + next_step = self.policy.next(sdk.config, sdk.history) return observation diff --git a/continuedev/src/continuedev/steps/steps_on_startup.py b/continuedev/src/continuedev/steps/steps_on_startup.py index eae8b558..365cbe1a 100644 --- a/continuedev/src/continuedev/steps/steps_on_startup.py +++ b/continuedev/src/continuedev/steps/steps_on_startup.py @@ -1,19 +1,12 @@ -from ..core.main import ContinueSDK, Models, Step +from ..core.main import Step +from ..core.sdk import Models, ContinueSDK from .main import UserInputStep from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe from ..recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe from ..recipes.AddTransformRecipe.main import AddTransformRecipe - -step_name_to_step_class = { - "UserInputStep": UserInputStep, - "CreatePipelineRecipe": CreatePipelineRecipe, - "DDtoBQRecipe": DDtoBQRecipe, - "DeployPipelineAirflowRecipe": DeployPipelineAirflowRecipe, - "AddTransformRecipe": AddTransformRecipe, - "DDtoBQRecipe": DDtoBQRecipe -} +from ..libs.util.step_name_to_steps import get_step_from_name class StepsOnStartupStep(Step): @@ -23,13 +16,8 @@ class StepsOnStartupStep(Step): return "Running steps on startup" async def run(self, sdk: ContinueSDK): - steps_descriptions = (await sdk.get_config()).steps_on_startup + steps_on_startup = sdk.config.steps_on_startup - for step_name, step_params in steps_descriptions.items(): - try: - step = step_name_to_step_class[step_name](**step_params) - except: - print( - f"Incorrect parameters for step {step_name}. Parameters provided were: {step_params}") - continue + for step_name, step_params in steps_on_startup.items(): + step = get_step_from_name(step_name, step_params) await sdk.run_step(step) |