diff options
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/core/abstract_sdk.py | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 9 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/config.py | 34 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/main.py | 6 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/policy.py | 73 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 6 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/util/step_name_to_steps.py | 27 | ||||
| -rw-r--r-- | continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py | 2 | ||||
| -rw-r--r-- | continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py | 3 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/gui.py | 7 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/gui_protocol.py | 6 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/main.py | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/steps_on_startup.py | 24 | 
13 files changed, 109 insertions, 96 deletions
| diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py index 3b85708d..0658f1b8 100644 --- a/continuedev/src/continuedev/core/abstract_sdk.py +++ b/continuedev/src/continuedev/core/abstract_sdk.py @@ -76,8 +76,8 @@ class AbstractContinueSDK(ABC):      async def get_user_secret(self, env_var: str, prompt: str) -> str:          pass -    @abstractmethod -    async def get_config(self) -> ContinueConfig: +    @abstractproperty +    def config(self) -> ContinueConfig:          pass      @abstractmethod diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 1642003c..0874bbc5 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -40,6 +40,9 @@ class Autopilot(ContinueBaseModel):      def get_full_state(self) -> FullState:          return FullState(history=self.history, active=self._active, user_input_queue=self._main_user_input_queue) +    async def get_available_slash_commands(self) -> List[Dict]: +        return list(map(lambda x: {"name": x.name, "description": x.description}, self.continue_sdk.config.slash_commands)) or [] +      async def clear_history(self):          self.history = History.from_empty()          self._main_user_input_queue = [] @@ -202,7 +205,7 @@ class Autopilot(ContinueBaseModel):              await self._run_singular_step(next_step, is_future_step) -            if next_step := self.policy.next(self.history): +            if next_step := self.policy.next(self.continue_sdk.config, self.history):                  is_future_step = False              elif next_step := self.history.take_next_step():                  is_future_step = True @@ -215,11 +218,11 @@ class Autopilot(ContinueBaseModel):          await self.update_subscribers()      async def run_from_observation(self, observation: Observation): -        next_step = self.policy.next(self.history) +        next_step = self.policy.next(self.continue_sdk.config, self.history)          await self.run_from_step(next_step)      async def run_policy(self): -        first_step = self.policy.next(self.history) +        first_step = self.policy.next(self.continue_sdk.config, self.history)          await self.run_from_step(first_step)      async def _request_halt(self): diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 8ed41a82..cf723984 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -1,9 +1,18 @@  import json  import os -from pydantic import BaseModel +from pydantic import BaseModel, validator  from typing import List, Optional, Dict  import yaml +from .main import Step + + +class SlashCommand(BaseModel): +    name: str +    description: str +    step_name: str +    params: Optional[Dict] = {} +  class ContinueConfig(BaseModel):      """ @@ -12,6 +21,29 @@ class ContinueConfig(BaseModel):      steps_on_startup: Optional[Dict[str, Dict]] = {}      server_url: Optional[str] = None      allow_anonymous_telemetry: Optional[bool] = True +    slash_commands: Optional[List[SlashCommand]] = [ +        # SlashCommand( +        #     name="pytest", +        #     description="Write pytest unit tests for the current file", +        #     step_name="WritePytestsRecipe", +        #     params=??) + +        SlashCommand( +            name="dlt", +            description="Create a dlt pipeline", +            step_name="CreatePipelineRecipe", +        ), +        SlashCommand( +            name="ddtobq", +            description="Create a dlt pipeline to load data from a data source into BigQuery", +            step_name="DDtoBQRecipe", +        ), +        SlashCommand( +            name="deployairflow", +            description="Deploy a dlt pipeline to Airflow", +            step_name="DeployPipelineAirflowRecipe", +        ), +    ]  def load_config(config_file: str) -> ContinueConfig: diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 81aaaf2e..f6b26d69 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -118,11 +118,15 @@ class Models:      pass +class ContinueConfig: +    pass + +  class Policy(ContinueBaseModel):      """A rule that determines which step to take next"""      # Note that history is mutable, kinda sus -    def next(self, history: History = History.from_empty()) -> "Step": +    def next(self, config: ContinueConfig, history: History = History.from_empty()) -> "Step":          raise NotImplementedError diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index 00b5427c..37a10e36 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -1,5 +1,6 @@  from typing import List, Tuple, Type +from .config import ContinueConfig  from ..steps.chroma import AnswerQuestionChroma, EditFileChroma, CreateCodebaseIndexChroma  from ..steps.steps_on_startup import StepsOnStartupStep  from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe @@ -15,12 +16,13 @@ from ..steps.react import NLDecisionStep  from ..steps.chat import SimpleChatStep  from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe  from ..steps.core.core import MessageStep +from ..libs.util.step_name_to_steps import get_step_from_name  class DemoPolicy(Policy):      ran_code_last: bool = False -    def next(self, history: History) -> Step: +    def next(self, config: ContinueConfig, history: History) -> Step:          # At the very start, run initial Steps spcecified in the config          if history.get_current() is None:              return ( @@ -33,20 +35,18 @@ class DemoPolicy(Policy):          if observation is not None and isinstance(observation, UserInputObservation):              # This could be defined with ObservationTypePolicy. Ergonomics not right though.              user_input = observation.user_input + +            if user_input.startswith("/"): +                command_name = user_input.split(" ")[0] +                after_command = " ".join(user_input.split(" ")[1:]) +                for slash_command in config.slash_commands: +                    if slash_command.name == command_name[1:]: +                        return get_step_from_name(slash_command.step_name, slash_command.params) +              if "/pytest" in user_input.lower():                  return WritePytestsRecipe(instructions=user_input) -            elif "/dlt" in user_input.lower() or " dlt" in user_input.lower(): -                return CreatePipelineRecipe()              if "/pytest" in observation.user_input.lower():                  return WritePytestsRecipe(instructions=observation.user_input) -            elif "/dlt" in observation.user_input.lower(): -                return CreatePipelineRecipe() -            elif "/ddtobq" in observation.user_input.lower(): -                return DDtoBQRecipe() -            elif "/airflow" in observation.user_input.lower(): -                return DeployPipelineAirflowRecipe() -            elif "/transform" in observation.user_input.lower(): -                return AddTransformRecipe()              elif "/comment" in observation.user_input.lower():                  return CommentCodeStep()              elif "/ask" in user_input: @@ -72,54 +72,3 @@ class DemoPolicy(Policy):              return SolveTracebackStep(traceback=observation.traceback)          else:              return None - - -class ObservationTypePolicy(Policy): -    def __init__(self, base_policy: Policy, observation_type: Type[Observation], step_type: Type[Step]): -        self.observation_type = observation_type -        self.step_type = step_type -        self.base_policy = base_policy - -    def next(self, history: History) -> Step: -        observation = history.last_observation() -        if observation is not None and isinstance(observation, self.observation_type): -            return self.step_type(observation) -        return self.base_policy.next(history) - - -class PolicyWrappedWithValidators(Policy): -    """Default is to stop, unless the validator tells what to do next""" -    index: int -    stage: int - -    def __init__(self, base_policy: Policy, pairs: List[Tuple[Validator, Type[Step]]]): -        # Want to pass Type[Validator], or just the Validator? Question of where params are coming from. -        self.pairs = pairs -        self.index = len(pairs) -        self.validating = 0 -        self.base_policy = base_policy - -    def next(self, history: History) -> Step: -        if self.index == len(self.pairs): -            self.index = 0 -            return self.base_policy.next(history) - -        if self.stage == 0: -            # Running the validator at the current index for the first time -            validator, step = self.pairs[self.index] -            self.stage = 1 -            return validator -        elif self.stage == 1: -            # Previously ran the validator at the current index, now receiving its ValidatorObservation -            observation = history.last_observation() -            if observation.passed: -                self.stage = 0 -                self.index += 1 -                if self.index == len(self.pairs): -                    self.index = 0 -                    return self.base_policy.next(history) -                else: -                    return self.pairs[self.index][0] -            else: -                _, step_type = self.pairs[self.index] -                return step_type(observation) diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 2849b0c8..1f4cdfb2 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod  import asyncio  from functools import cached_property  from typing import Coroutine, Union @@ -119,8 +118,9 @@ class ContinueSDK(AbstractContinueSDK):      async def get_user_secret(self, env_var: str, prompt: str) -> str:          return await self.ide.getUserSecret(env_var) -    async def get_config(self) -> ContinueConfig: -        dir = await self.ide.getWorkspaceDirectory() +    @property +    def config(self) -> ContinueConfig: +        dir = self.ide.workspace_directory          yaml_path = os.path.join(dir, '.continue', 'config.yaml')          json_path = os.path.join(dir, '.continue', 'config.json')          if os.path.exists(yaml_path): diff --git a/continuedev/src/continuedev/libs/util/step_name_to_steps.py b/continuedev/src/continuedev/libs/util/step_name_to_steps.py new file mode 100644 index 00000000..4023b73b --- /dev/null +++ b/continuedev/src/continuedev/libs/util/step_name_to_steps.py @@ -0,0 +1,27 @@ +from typing import Dict + +from ...core.main import Step +from ...steps.core.core import UserInputStep +from ...recipes.CreatePipelineRecipe.main import CreatePipelineRecipe +from ...recipes.DDtoBQRecipe.main import DDtoBQRecipe +from ...recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe +from ...recipes.DDtoBQRecipe.main import DDtoBQRecipe +from ...recipes.AddTransformRecipe.main import AddTransformRecipe + +step_name_to_step_class = { +    "UserInputStep": UserInputStep, +    "CreatePipelineRecipe": CreatePipelineRecipe, +    "DDtoBQRecipe": DDtoBQRecipe, +    "DeployPipelineAirflowRecipe": DeployPipelineAirflowRecipe, +    "AddTransformRecipe": AddTransformRecipe, +    "DDtoBQRecipe": DDtoBQRecipe +} + + +def get_step_from_name(step_name: str, params: Dict) -> Step: +    try: +        return step_name_to_step_class[step_name](**params) +    except: +        print( +            f"Incorrect parameters for step {step_name}. Parameters provided were: {params}") +        raise diff --git a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py index 818168ba..92bddc98 100644 --- a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py +++ b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py @@ -1,7 +1,7 @@  from textwrap import dedent -from ...core.main import Step  from ...core.sdk import ContinueSDK +from ...core.main import Step  from ...steps.core.core import WaitForUserInputStep  from ...steps.core.core import MessageStep  from .steps import SetupPipelineStep, ValidatePipelineStep, RunQueryStep diff --git a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py index e59cc51c..ea4607da 100644 --- a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py +++ b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py @@ -6,11 +6,10 @@ import time  from ...models.main import Range  from ...models.filesystem import RangeInFile  from ...steps.core.core import MessageStep -from ...core.sdk import Models  from ...core.observation import DictObservation, InternalErrorObservation  from ...models.filesystem_edit import AddFile, FileEdit  from ...core.main import Step -from ...core.sdk import ContinueSDK +from ...core.sdk import ContinueSDK, Models  AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)" diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index e8b52004..cf046734 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -90,6 +90,12 @@ class GUIProtocolServer(AbstractGUIProtocolServer):              "state": state          }) +    async def send_available_slash_commands(self): +        commands = await self.session.autopilot.get_available_slash_commands() +        await self._send_json("available_slash_commands", { +            "commands": commands +        }) +      def on_main_input(self, input: str):          # Do something with user input          asyncio.create_task(self.session.autopilot.accept_user_input(input)) @@ -127,6 +133,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we      protocol.websocket = websocket      # Update any history that may have happened before connection +    await protocol.send_available_slash_commands()      await protocol.send_state_update()      while AppStatus.should_exit is False: diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py index 889c6761..d9506c6f 100644 --- a/continuedev/src/continuedev/server/gui_protocol.py +++ b/continuedev/src/continuedev/server/gui_protocol.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Dict, List  from abc import ABC, abstractmethod @@ -28,6 +28,10 @@ class AbstractGUIProtocolServer(ABC):          """Send a state update to the client"""      @abstractmethod +    async def send_available_slash_commands(self, commands: List[Dict]): +        """Send a list of available slash commands to the client""" + +    @abstractmethod      def on_retry_at_index(self, index: int):          """Called when the user requests a retry at a previous index""" diff --git a/continuedev/src/continuedev/steps/main.py b/continuedev/src/continuedev/steps/main.py index 9634c726..36e4f519 100644 --- a/continuedev/src/continuedev/steps/main.py +++ b/continuedev/src/continuedev/steps/main.py @@ -63,10 +63,10 @@ class RunPolicyUntilDoneStep(Step):      policy: "Policy"      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: -        next_step = self.policy.next(sdk.history) +        next_step = self.policy.next(sdk.config, sdk.history)          while next_step is not None:              observation = await sdk.run_step(next_step) -            next_step = self.policy.next(sdk.history) +            next_step = self.policy.next(sdk.config, sdk.history)          return observation diff --git a/continuedev/src/continuedev/steps/steps_on_startup.py b/continuedev/src/continuedev/steps/steps_on_startup.py index eae8b558..365cbe1a 100644 --- a/continuedev/src/continuedev/steps/steps_on_startup.py +++ b/continuedev/src/continuedev/steps/steps_on_startup.py @@ -1,19 +1,12 @@ -from ..core.main import ContinueSDK, Models, Step +from ..core.main import Step +from ..core.sdk import Models, ContinueSDK  from .main import UserInputStep  from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe  from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe  from ..recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe  from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe  from ..recipes.AddTransformRecipe.main import AddTransformRecipe - -step_name_to_step_class = { -    "UserInputStep": UserInputStep, -    "CreatePipelineRecipe": CreatePipelineRecipe, -    "DDtoBQRecipe": DDtoBQRecipe, -    "DeployPipelineAirflowRecipe": DeployPipelineAirflowRecipe, -    "AddTransformRecipe": AddTransformRecipe, -    "DDtoBQRecipe": DDtoBQRecipe -} +from ..libs.util.step_name_to_steps import get_step_from_name  class StepsOnStartupStep(Step): @@ -23,13 +16,8 @@ class StepsOnStartupStep(Step):          return "Running steps on startup"      async def run(self, sdk: ContinueSDK): -        steps_descriptions = (await sdk.get_config()).steps_on_startup +        steps_on_startup = sdk.config.steps_on_startup -        for step_name, step_params in steps_descriptions.items(): -            try: -                step = step_name_to_step_class[step_name](**step_params) -            except: -                print( -                    f"Incorrect parameters for step {step_name}. Parameters provided were: {step_params}") -                continue +        for step_name, step_params in steps_on_startup.items(): +            step = get_step_from_name(step_name, step_params)              await sdk.run_step(step) | 
