diff options
| author | Nate Sesti <sestinj@gmail.com> | 2023-07-29 22:40:04 -0700 |
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-07-29 22:40:04 -0700 |
| commit | c6a12550ffca1ffe35630e7aa9af6913ddbe0675 (patch) | |
| tree | 6450574104c42c76ef168c54a92af9ecbb7337c5 /continuedev/src/continuedev/core | |
| parent | 17566c66e0a01ad3c38ece974e44c1c71a9188de (diff) | |
| download | sncontinue-c6a12550ffca1ffe35630e7aa9af6913ddbe0675.tar.gz sncontinue-c6a12550ffca1ffe35630e7aa9af6913ddbe0675.tar.bz2 sncontinue-c6a12550ffca1ffe35630e7aa9af6913ddbe0675.zip | |
feat: :sparkles: EmbeddingContextProvider
Diffstat (limited to 'continuedev/src/continuedev/core')
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 24 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/config.py | 11 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/policy.py | 80 |
3 files changed, 17 insertions, 98 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 3f25e64e..12339f9b 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -9,6 +9,7 @@ from ..models.filesystem import RangeInFileWithContents from ..models.filesystem_edit import FileEditWithFullContents from .observation import Observation, InternalErrorObservation from .context import ContextManager +from ..plugins.policies.default import DefaultPolicy from ..plugins.context_providers.file import FileContextProvider from ..plugins.context_providers.highlighted_code import HighlightedCodeContextProvider from ..server.ide_protocol import AbstractIdeProtocolServer @@ -47,8 +48,9 @@ def get_error_title(e: Exception) -> str: class Autopilot(ContinueBaseModel): - policy: Policy ide: AbstractIdeProtocolServer + + policy: Policy = DefaultPolicy() history: History = History.from_empty() context: Context = Context() full_state: Union[FullState, None] = None @@ -64,20 +66,18 @@ class Autopilot(ContinueBaseModel): _user_input_queue = AsyncSubscriptionQueue() _retry_queue = AsyncSubscriptionQueue() - @classmethod - async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot": - autopilot = cls(ide=ide, policy=policy) - autopilot.continue_sdk = await ContinueSDK.create(autopilot) + async def start(self): + self.continue_sdk = await ContinueSDK.create(self) + if override_policy := self.continue_sdk.config.policy_override: + self.policy = override_policy # Load documents into the search index - autopilot.context_manager = await ContextManager.create( - autopilot.continue_sdk.config.context_providers + [ - HighlightedCodeContextProvider(ide=ide), - FileContextProvider(workspace_dir=ide.workspace_directory) + self.context_manager = await ContextManager.create( + self.continue_sdk.config.context_providers + [ + HighlightedCodeContextProvider(ide=self.ide), + FileContextProvider(workspace_dir=self.ide.workspace_directory) ]) - await autopilot.context_manager.load_index() - - return autopilot + await self.context_manager.load_index() class Config: arbitrary_types_allowed = True diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 9fbda824..fe0946cd 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -1,10 +1,8 @@ -import json -import os -from .main import Step -from .context import ContextProvider from pydantic import BaseModel, validator -from typing import List, Literal, Optional, Dict, Type, Union -import yaml +from typing import List, Literal, Optional, Dict, Type + +from .main import Policy, Step +from .context import ContextProvider class SlashCommand(BaseModel): @@ -51,6 +49,7 @@ class ContinueConfig(BaseModel): on_traceback: Optional[List[OnTracebackSteps]] = [] system_message: Optional[str] = None openai_server_info: Optional[OpenAIServerInfo] = None + policy_override: Optional[Policy] = None context_providers: List[ContextProvider] = [] diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py deleted file mode 100644 index d90177b5..00000000 --- a/continuedev/src/continuedev/core/policy.py +++ /dev/null @@ -1,80 +0,0 @@ -from textwrap import dedent -from typing import Union - -from ..plugins.steps.chat import SimpleChatStep -from ..plugins.steps.welcome import WelcomeStep -from .config import ContinueConfig -from ..plugins.steps.steps_on_startup import StepsOnStartupStep -from .main import Step, History, Policy -from .observation import UserInputObservation -from ..plugins.steps.core.core import MessageStep -from ..plugins.steps.custom_command import CustomCommandStep -from ..plugins.steps.main import EditHighlightedCodeStep - - -def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]: - """ - Parses a slash command, returning the command name and the rest of the input. - """ - if inp.startswith("/"): - command_name = inp.split(" ")[0] - after_command = " ".join(inp.split(" ")[1:]) - - for slash_command in config.slash_commands: - if slash_command.name == command_name[1:]: - params = slash_command.params - params["user_input"] = after_command - try: - return slash_command.step(**params) - except TypeError as e: - raise Exception( - f"Incorrect params used for slash command '{command_name}': {e}") - return None - - -def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]: - command_name = inp.split(" ")[0] - after_command = " ".join(inp.split(" ")[1:]) - for custom_cmd in config.custom_commands: - if custom_cmd.name == command_name[1:]: - slash_command = parse_slash_command(custom_cmd.prompt, config) - if slash_command is not None: - return slash_command - return CustomCommandStep(name=custom_cmd.name, description=custom_cmd.description, prompt=custom_cmd.prompt, user_input=after_command, slash_command=command_name) - return None - - -class DefaultPolicy(Policy): - ran_code_last: bool = False - - def next(self, config: ContinueConfig, history: History) -> Step: - # At the very start, run initial Steps spcecified in the config - if history.get_current() is None: - return ( - MessageStep(name="Welcome to Continue", message=dedent("""\ - - Highlight code section and ask a question or give instructions - - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue - - Use `/help` to ask questions about how to use Continue""")) >> - WelcomeStep() >> - # CreateCodebaseIndexChroma() >> - StepsOnStartupStep()) - - observation = history.get_current().observation - if observation is not None and isinstance(observation, UserInputObservation): - # This could be defined with ObservationTypePolicy. Ergonomics not right though. - user_input = observation.user_input - - slash_command = parse_slash_command(user_input, config) - if slash_command is not None: - return slash_command - - custom_command = parse_custom_command(user_input, config) - if custom_command is not None: - return custom_command - - if user_input.startswith("/edit"): - return EditHighlightedCodeStep(user_input=user_input[5:]) - - return SimpleChatStep() - - return None |
