diff options
| author | Nate Sesti <sestinj@gmail.com> | 2023-07-31 00:59:42 -0700 |
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-07-31 00:59:42 -0700 |
| commit | 078b7d9a40d9cd0dd93a89184b3a00e2ff651d2a (patch) | |
| tree | 8d047405164b9157f4e948f68d8b000a1083efdf /continuedev/src/continuedev/core | |
| parent | 72e83325a8eb5032c448a5e891c157987921ced2 (diff) | |
| parent | c51ad538deff06af6c9e5498b23e3536e18bfc4c (diff) | |
| download | sncontinue-078b7d9a40d9cd0dd93a89184b3a00e2ff651d2a.tar.gz sncontinue-078b7d9a40d9cd0dd93a89184b3a00e2ff651d2a.tar.bz2 sncontinue-078b7d9a40d9cd0dd93a89184b3a00e2ff651d2a.zip | |
Merge branch 'main' into llm-object-config
Diffstat (limited to 'continuedev/src/continuedev/core')
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 23 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/config.py | 7 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/policy.py | 78 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 2 |
4 files changed, 18 insertions, 92 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index b4c951b8..de95a259 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -9,6 +9,7 @@ from ..models.filesystem import RangeInFileWithContents from ..models.filesystem_edit import FileEditWithFullContents from .observation import Observation, InternalErrorObservation from .context import ContextManager +from ..plugins.policies.default import DefaultPolicy from ..plugins.context_providers.file import FileContextProvider from ..plugins.context_providers.highlighted_code import HighlightedCodeContextProvider from ..server.ide_protocol import AbstractIdeProtocolServer @@ -47,8 +48,9 @@ def get_error_title(e: Exception) -> str: class Autopilot(ContinueBaseModel): - policy: Policy ide: AbstractIdeProtocolServer + + policy: Policy = DefaultPolicy() history: History = History.from_empty() context: Context = Context() full_state: Union[FullState, None] = None @@ -64,20 +66,19 @@ class Autopilot(ContinueBaseModel): _user_input_queue = AsyncSubscriptionQueue() _retry_queue = AsyncSubscriptionQueue() - @classmethod - async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot": - autopilot = cls(ide=ide, policy=policy) - autopilot.continue_sdk = await ContinueSDK.create(autopilot) + async def start(self): + self.continue_sdk = await ContinueSDK.create(self) + if override_policy := self.continue_sdk.config.policy_override: + self.policy = override_policy # Load documents into the search index - autopilot.context_manager = await ContextManager.create( - autopilot.continue_sdk.config.context_providers + [ - HighlightedCodeContextProvider(ide=ide), - FileContextProvider(workspace_dir=ide.workspace_directory) + self.context_manager = await ContextManager.create( + self.continue_sdk.config.context_providers + [ + HighlightedCodeContextProvider(ide=self.ide), + FileContextProvider(workspace_dir=self.ide.workspace_directory) ]) - await autopilot.context_manager.load_index(ide.workspace_directory) - return autopilot + await self.context_manager.load_index(self.ide.workspace_directory) class Config: arbitrary_types_allowed = True diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 4fcab588..84b6b10b 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -5,8 +5,10 @@ from .context import ContextProvider from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI from .models import Models from pydantic import BaseModel, validator -from typing import List, Literal, Optional, Dict, Type, Union -import yaml +from typing import List, Literal, Optional, Dict, Type + +from .main import Policy, Step +from .context import ContextProvider class SlashCommand(BaseModel): @@ -46,6 +48,7 @@ class ContinueConfig(BaseModel): slash_commands: Optional[List[SlashCommand]] = [] on_traceback: Optional[List[OnTracebackSteps]] = [] system_message: Optional[str] = None + policy_override: Optional[Policy] = None context_providers: List[ContextProvider] = [] diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py deleted file mode 100644 index 7c2a8ce0..00000000 --- a/continuedev/src/continuedev/core/policy.py +++ /dev/null @@ -1,78 +0,0 @@ -from textwrap import dedent -from typing import Union - -from ..plugins.steps.chat import SimpleChatStep -from ..plugins.steps.welcome import WelcomeStep -from .config import ContinueConfig -from ..plugins.steps.steps_on_startup import StepsOnStartupStep -from .main import Step, History, Policy -from .observation import UserInputObservation -from ..plugins.steps.core.core import MessageStep -from ..plugins.steps.custom_command import CustomCommandStep -from ..plugins.steps.main import EditHighlightedCodeStep - - -def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]: - """ - Parses a slash command, returning the command name and the rest of the input. - """ - if inp.startswith("/"): - command_name = inp.split(" ")[0] - after_command = " ".join(inp.split(" ")[1:]) - - for slash_command in config.slash_commands: - if slash_command.name == command_name[1:]: - params = slash_command.params - params["user_input"] = after_command - try: - return slash_command.step(**params) - except TypeError as e: - raise Exception( - f"Incorrect params used for slash command '{command_name}': {e}") - return None - - -def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]: - command_name = inp.split(" ")[0] - after_command = " ".join(inp.split(" ")[1:]) - for custom_cmd in config.custom_commands: - if custom_cmd.name == command_name[1:]: - slash_command = parse_slash_command(custom_cmd.prompt, config) - if slash_command is not None: - return slash_command - return CustomCommandStep(name=custom_cmd.name, description=custom_cmd.description, prompt=custom_cmd.prompt, user_input=after_command, slash_command=command_name) - return None - - -class DefaultPolicy(Policy): - def next(self, config: ContinueConfig, history: History) -> Step: - # At the very start, run initial Steps spcecified in the config - if history.get_current() is None: - return ( - MessageStep(name="Welcome to Continue", message=dedent("""\ - - Highlight code section and ask a question or give instructions - - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue - - Use `/help` to ask questions about how to use Continue""")) >> - WelcomeStep() >> - # CreateCodebaseIndexChroma() >> - StepsOnStartupStep()) - - observation = history.get_current().observation - if observation is not None and isinstance(observation, UserInputObservation): - # This could be defined with ObservationTypePolicy. Ergonomics not right though. - user_input = observation.user_input - - slash_command = parse_slash_command(user_input, config) - if slash_command is not None: - return slash_command - - custom_command = parse_custom_command(user_input, config) - if custom_command is not None: - return custom_command - - if user_input.startswith("/edit"): - return EditHighlightedCodeStep(user_input=user_input[5:]) - - return SimpleChatStep() - - return None diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index be7008c0..bf22d696 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -53,7 +53,7 @@ class ContinueSDK(AbstractContinueSDK): formatted_err = '\n'.join(traceback.format_exception(e)) msg_step = MessageStep( name="Invalid Continue Config File", message=formatted_err) - msg_step.description = f"Falling back to default config settings.\n```\n{formatted_err}\n```" + msg_step.description = f"Falling back to default config settings.\n```\n{formatted_err}\n```\n\nIt's possible this error was caused by an update to the Continue config format. If you'd like to see the new recommended default `config.py`, check [here](https://github.com/continuedev/continue/blob/main/continuedev/src/continuedev/libs/constants/default_config.py.txt)." sdk.history.add_node(HistoryNode( step=msg_step, observation=None, |
