summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/core/abstract_sdk.py4
-rw-r--r--continuedev/src/continuedev/core/autopilot.py9
-rw-r--r--continuedev/src/continuedev/core/config.py34
-rw-r--r--continuedev/src/continuedev/core/main.py6
-rw-r--r--continuedev/src/continuedev/core/policy.py73
-rw-r--r--continuedev/src/continuedev/core/sdk.py6
-rw-r--r--continuedev/src/continuedev/libs/util/step_name_to_steps.py27
-rw-r--r--continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py2
-rw-r--r--continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py3
-rw-r--r--continuedev/src/continuedev/server/gui.py7
-rw-r--r--continuedev/src/continuedev/server/gui_protocol.py6
-rw-r--r--continuedev/src/continuedev/steps/main.py4
-rw-r--r--continuedev/src/continuedev/steps/steps_on_startup.py24
13 files changed, 109 insertions, 96 deletions
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py
index 3b85708d..0658f1b8 100644
--- a/continuedev/src/continuedev/core/abstract_sdk.py
+++ b/continuedev/src/continuedev/core/abstract_sdk.py
@@ -76,8 +76,8 @@ class AbstractContinueSDK(ABC):
async def get_user_secret(self, env_var: str, prompt: str) -> str:
pass
- @abstractmethod
- async def get_config(self) -> ContinueConfig:
+ @abstractproperty
+ def config(self) -> ContinueConfig:
pass
@abstractmethod
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 1642003c..0874bbc5 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -40,6 +40,9 @@ class Autopilot(ContinueBaseModel):
def get_full_state(self) -> FullState:
return FullState(history=self.history, active=self._active, user_input_queue=self._main_user_input_queue)
+ async def get_available_slash_commands(self) -> List[Dict]:
+ return list(map(lambda x: {"name": x.name, "description": x.description}, self.continue_sdk.config.slash_commands)) or []
+
async def clear_history(self):
self.history = History.from_empty()
self._main_user_input_queue = []
@@ -202,7 +205,7 @@ class Autopilot(ContinueBaseModel):
await self._run_singular_step(next_step, is_future_step)
- if next_step := self.policy.next(self.history):
+ if next_step := self.policy.next(self.continue_sdk.config, self.history):
is_future_step = False
elif next_step := self.history.take_next_step():
is_future_step = True
@@ -215,11 +218,11 @@ class Autopilot(ContinueBaseModel):
await self.update_subscribers()
async def run_from_observation(self, observation: Observation):
- next_step = self.policy.next(self.history)
+ next_step = self.policy.next(self.continue_sdk.config, self.history)
await self.run_from_step(next_step)
async def run_policy(self):
- first_step = self.policy.next(self.history)
+ first_step = self.policy.next(self.continue_sdk.config, self.history)
await self.run_from_step(first_step)
async def _request_halt(self):
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 8ed41a82..cf723984 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -1,9 +1,18 @@
import json
import os
-from pydantic import BaseModel
+from pydantic import BaseModel, validator
from typing import List, Optional, Dict
import yaml
+from .main import Step
+
+
+class SlashCommand(BaseModel):
+ name: str
+ description: str
+ step_name: str
+ params: Optional[Dict] = {}
+
class ContinueConfig(BaseModel):
"""
@@ -12,6 +21,29 @@ class ContinueConfig(BaseModel):
steps_on_startup: Optional[Dict[str, Dict]] = {}
server_url: Optional[str] = None
allow_anonymous_telemetry: Optional[bool] = True
+ slash_commands: Optional[List[SlashCommand]] = [
+ # SlashCommand(
+ # name="pytest",
+ # description="Write pytest unit tests for the current file",
+ # step_name="WritePytestsRecipe",
+ # params=??)
+
+ SlashCommand(
+ name="dlt",
+ description="Create a dlt pipeline",
+ step_name="CreatePipelineRecipe",
+ ),
+ SlashCommand(
+ name="ddtobq",
+ description="Create a dlt pipeline to load data from a data source into BigQuery",
+ step_name="DDtoBQRecipe",
+ ),
+ SlashCommand(
+ name="deployairflow",
+ description="Deploy a dlt pipeline to Airflow",
+ step_name="DeployPipelineAirflowRecipe",
+ ),
+ ]
def load_config(config_file: str) -> ContinueConfig:
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index 81aaaf2e..f6b26d69 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -118,11 +118,15 @@ class Models:
pass
+class ContinueConfig:
+ pass
+
+
class Policy(ContinueBaseModel):
"""A rule that determines which step to take next"""
# Note that history is mutable, kinda sus
- def next(self, history: History = History.from_empty()) -> "Step":
+ def next(self, config: ContinueConfig, history: History = History.from_empty()) -> "Step":
raise NotImplementedError
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
index 00b5427c..37a10e36 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/core/policy.py
@@ -1,5 +1,6 @@
from typing import List, Tuple, Type
+from .config import ContinueConfig
from ..steps.chroma import AnswerQuestionChroma, EditFileChroma, CreateCodebaseIndexChroma
from ..steps.steps_on_startup import StepsOnStartupStep
from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe
@@ -15,12 +16,13 @@ from ..steps.react import NLDecisionStep
from ..steps.chat import SimpleChatStep
from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe
from ..steps.core.core import MessageStep
+from ..libs.util.step_name_to_steps import get_step_from_name
class DemoPolicy(Policy):
ran_code_last: bool = False
- def next(self, history: History) -> Step:
+ def next(self, config: ContinueConfig, history: History) -> Step:
# At the very start, run initial Steps spcecified in the config
if history.get_current() is None:
return (
@@ -33,20 +35,18 @@ class DemoPolicy(Policy):
if observation is not None and isinstance(observation, UserInputObservation):
# This could be defined with ObservationTypePolicy. Ergonomics not right though.
user_input = observation.user_input
+
+ if user_input.startswith("/"):
+ command_name = user_input.split(" ")[0]
+ after_command = " ".join(user_input.split(" ")[1:])
+ for slash_command in config.slash_commands:
+ if slash_command.name == command_name[1:]:
+ return get_step_from_name(slash_command.step_name, slash_command.params)
+
if "/pytest" in user_input.lower():
return WritePytestsRecipe(instructions=user_input)
- elif "/dlt" in user_input.lower() or " dlt" in user_input.lower():
- return CreatePipelineRecipe()
if "/pytest" in observation.user_input.lower():
return WritePytestsRecipe(instructions=observation.user_input)
- elif "/dlt" in observation.user_input.lower():
- return CreatePipelineRecipe()
- elif "/ddtobq" in observation.user_input.lower():
- return DDtoBQRecipe()
- elif "/airflow" in observation.user_input.lower():
- return DeployPipelineAirflowRecipe()
- elif "/transform" in observation.user_input.lower():
- return AddTransformRecipe()
elif "/comment" in observation.user_input.lower():
return CommentCodeStep()
elif "/ask" in user_input:
@@ -72,54 +72,3 @@ class DemoPolicy(Policy):
return SolveTracebackStep(traceback=observation.traceback)
else:
return None
-
-
-class ObservationTypePolicy(Policy):
- def __init__(self, base_policy: Policy, observation_type: Type[Observation], step_type: Type[Step]):
- self.observation_type = observation_type
- self.step_type = step_type
- self.base_policy = base_policy
-
- def next(self, history: History) -> Step:
- observation = history.last_observation()
- if observation is not None and isinstance(observation, self.observation_type):
- return self.step_type(observation)
- return self.base_policy.next(history)
-
-
-class PolicyWrappedWithValidators(Policy):
- """Default is to stop, unless the validator tells what to do next"""
- index: int
- stage: int
-
- def __init__(self, base_policy: Policy, pairs: List[Tuple[Validator, Type[Step]]]):
- # Want to pass Type[Validator], or just the Validator? Question of where params are coming from.
- self.pairs = pairs
- self.index = len(pairs)
- self.validating = 0
- self.base_policy = base_policy
-
- def next(self, history: History) -> Step:
- if self.index == len(self.pairs):
- self.index = 0
- return self.base_policy.next(history)
-
- if self.stage == 0:
- # Running the validator at the current index for the first time
- validator, step = self.pairs[self.index]
- self.stage = 1
- return validator
- elif self.stage == 1:
- # Previously ran the validator at the current index, now receiving its ValidatorObservation
- observation = history.last_observation()
- if observation.passed:
- self.stage = 0
- self.index += 1
- if self.index == len(self.pairs):
- self.index = 0
- return self.base_policy.next(history)
- else:
- return self.pairs[self.index][0]
- else:
- _, step_type = self.pairs[self.index]
- return step_type(observation)
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 2849b0c8..1f4cdfb2 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -1,4 +1,3 @@
-from abc import ABC, abstractmethod
import asyncio
from functools import cached_property
from typing import Coroutine, Union
@@ -119,8 +118,9 @@ class ContinueSDK(AbstractContinueSDK):
async def get_user_secret(self, env_var: str, prompt: str) -> str:
return await self.ide.getUserSecret(env_var)
- async def get_config(self) -> ContinueConfig:
- dir = await self.ide.getWorkspaceDirectory()
+ @property
+ def config(self) -> ContinueConfig:
+ dir = self.ide.workspace_directory
yaml_path = os.path.join(dir, '.continue', 'config.yaml')
json_path = os.path.join(dir, '.continue', 'config.json')
if os.path.exists(yaml_path):
diff --git a/continuedev/src/continuedev/libs/util/step_name_to_steps.py b/continuedev/src/continuedev/libs/util/step_name_to_steps.py
new file mode 100644
index 00000000..4023b73b
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/step_name_to_steps.py
@@ -0,0 +1,27 @@
+from typing import Dict
+
+from ...core.main import Step
+from ...steps.core.core import UserInputStep
+from ...recipes.CreatePipelineRecipe.main import CreatePipelineRecipe
+from ...recipes.DDtoBQRecipe.main import DDtoBQRecipe
+from ...recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe
+from ...recipes.DDtoBQRecipe.main import DDtoBQRecipe
+from ...recipes.AddTransformRecipe.main import AddTransformRecipe
+
+step_name_to_step_class = {
+ "UserInputStep": UserInputStep,
+ "CreatePipelineRecipe": CreatePipelineRecipe,
+ "DDtoBQRecipe": DDtoBQRecipe,
+ "DeployPipelineAirflowRecipe": DeployPipelineAirflowRecipe,
+ "AddTransformRecipe": AddTransformRecipe,
+ "DDtoBQRecipe": DDtoBQRecipe
+}
+
+
+def get_step_from_name(step_name: str, params: Dict) -> Step:
+ try:
+ return step_name_to_step_class[step_name](**params)
+ except:
+ print(
+ f"Incorrect parameters for step {step_name}. Parameters provided were: {params}")
+ raise
diff --git a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py
index 818168ba..92bddc98 100644
--- a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py
+++ b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py
@@ -1,7 +1,7 @@
from textwrap import dedent
-from ...core.main import Step
from ...core.sdk import ContinueSDK
+from ...core.main import Step
from ...steps.core.core import WaitForUserInputStep
from ...steps.core.core import MessageStep
from .steps import SetupPipelineStep, ValidatePipelineStep, RunQueryStep
diff --git a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py
index e59cc51c..ea4607da 100644
--- a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py
+++ b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py
@@ -6,11 +6,10 @@ import time
from ...models.main import Range
from ...models.filesystem import RangeInFile
from ...steps.core.core import MessageStep
-from ...core.sdk import Models
from ...core.observation import DictObservation, InternalErrorObservation
from ...models.filesystem_edit import AddFile, FileEdit
from ...core.main import Step
-from ...core.sdk import ContinueSDK
+from ...core.sdk import ContinueSDK, Models
AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)"
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index e8b52004..cf046734 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -90,6 +90,12 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
"state": state
})
+ async def send_available_slash_commands(self):
+ commands = await self.session.autopilot.get_available_slash_commands()
+ await self._send_json("available_slash_commands", {
+ "commands": commands
+ })
+
def on_main_input(self, input: str):
# Do something with user input
asyncio.create_task(self.session.autopilot.accept_user_input(input))
@@ -127,6 +133,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
protocol.websocket = websocket
# Update any history that may have happened before connection
+ await protocol.send_available_slash_commands()
await protocol.send_state_update()
while AppStatus.should_exit is False:
diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py
index 889c6761..d9506c6f 100644
--- a/continuedev/src/continuedev/server/gui_protocol.py
+++ b/continuedev/src/continuedev/server/gui_protocol.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Dict, List
from abc import ABC, abstractmethod
@@ -28,6 +28,10 @@ class AbstractGUIProtocolServer(ABC):
"""Send a state update to the client"""
@abstractmethod
+ async def send_available_slash_commands(self, commands: List[Dict]):
+ """Send a list of available slash commands to the client"""
+
+ @abstractmethod
def on_retry_at_index(self, index: int):
"""Called when the user requests a retry at a previous index"""
diff --git a/continuedev/src/continuedev/steps/main.py b/continuedev/src/continuedev/steps/main.py
index 9634c726..36e4f519 100644
--- a/continuedev/src/continuedev/steps/main.py
+++ b/continuedev/src/continuedev/steps/main.py
@@ -63,10 +63,10 @@ class RunPolicyUntilDoneStep(Step):
policy: "Policy"
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
- next_step = self.policy.next(sdk.history)
+ next_step = self.policy.next(sdk.config, sdk.history)
while next_step is not None:
observation = await sdk.run_step(next_step)
- next_step = self.policy.next(sdk.history)
+ next_step = self.policy.next(sdk.config, sdk.history)
return observation
diff --git a/continuedev/src/continuedev/steps/steps_on_startup.py b/continuedev/src/continuedev/steps/steps_on_startup.py
index eae8b558..365cbe1a 100644
--- a/continuedev/src/continuedev/steps/steps_on_startup.py
+++ b/continuedev/src/continuedev/steps/steps_on_startup.py
@@ -1,19 +1,12 @@
-from ..core.main import ContinueSDK, Models, Step
+from ..core.main import Step
+from ..core.sdk import Models, ContinueSDK
from .main import UserInputStep
from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe
from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe
from ..recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe
from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe
from ..recipes.AddTransformRecipe.main import AddTransformRecipe
-
-step_name_to_step_class = {
- "UserInputStep": UserInputStep,
- "CreatePipelineRecipe": CreatePipelineRecipe,
- "DDtoBQRecipe": DDtoBQRecipe,
- "DeployPipelineAirflowRecipe": DeployPipelineAirflowRecipe,
- "AddTransformRecipe": AddTransformRecipe,
- "DDtoBQRecipe": DDtoBQRecipe
-}
+from ..libs.util.step_name_to_steps import get_step_from_name
class StepsOnStartupStep(Step):
@@ -23,13 +16,8 @@ class StepsOnStartupStep(Step):
return "Running steps on startup"
async def run(self, sdk: ContinueSDK):
- steps_descriptions = (await sdk.get_config()).steps_on_startup
+ steps_on_startup = sdk.config.steps_on_startup
- for step_name, step_params in steps_descriptions.items():
- try:
- step = step_name_to_step_class[step_name](**step_params)
- except:
- print(
- f"Incorrect parameters for step {step_name}. Parameters provided were: {step_params}")
- continue
+ for step_name, step_params in steps_on_startup.items():
+ step = get_step_from_name(step_name, step_params)
await sdk.run_step(step)