summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/core
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-31 00:59:42 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-31 00:59:42 -0700
commit078b7d9a40d9cd0dd93a89184b3a00e2ff651d2a (patch)
tree8d047405164b9157f4e948f68d8b000a1083efdf /continuedev/src/continuedev/core
parent72e83325a8eb5032c448a5e891c157987921ced2 (diff)
parentc51ad538deff06af6c9e5498b23e3536e18bfc4c (diff)
downloadsncontinue-078b7d9a40d9cd0dd93a89184b3a00e2ff651d2a.tar.gz
sncontinue-078b7d9a40d9cd0dd93a89184b3a00e2ff651d2a.tar.bz2
sncontinue-078b7d9a40d9cd0dd93a89184b3a00e2ff651d2a.zip
Merge branch 'main' into llm-object-config
Diffstat (limited to 'continuedev/src/continuedev/core')
-rw-r--r--continuedev/src/continuedev/core/autopilot.py23
-rw-r--r--continuedev/src/continuedev/core/config.py7
-rw-r--r--continuedev/src/continuedev/core/policy.py78
-rw-r--r--continuedev/src/continuedev/core/sdk.py2
4 files changed, 18 insertions, 92 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index b4c951b8..de95a259 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -9,6 +9,7 @@ from ..models.filesystem import RangeInFileWithContents
from ..models.filesystem_edit import FileEditWithFullContents
from .observation import Observation, InternalErrorObservation
from .context import ContextManager
+from ..plugins.policies.default import DefaultPolicy
from ..plugins.context_providers.file import FileContextProvider
from ..plugins.context_providers.highlighted_code import HighlightedCodeContextProvider
from ..server.ide_protocol import AbstractIdeProtocolServer
@@ -47,8 +48,9 @@ def get_error_title(e: Exception) -> str:
class Autopilot(ContinueBaseModel):
- policy: Policy
ide: AbstractIdeProtocolServer
+
+ policy: Policy = DefaultPolicy()
history: History = History.from_empty()
context: Context = Context()
full_state: Union[FullState, None] = None
@@ -64,20 +66,19 @@ class Autopilot(ContinueBaseModel):
_user_input_queue = AsyncSubscriptionQueue()
_retry_queue = AsyncSubscriptionQueue()
- @classmethod
- async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot":
- autopilot = cls(ide=ide, policy=policy)
- autopilot.continue_sdk = await ContinueSDK.create(autopilot)
+ async def start(self):
+ self.continue_sdk = await ContinueSDK.create(self)
+ if override_policy := self.continue_sdk.config.policy_override:
+ self.policy = override_policy
# Load documents into the search index
- autopilot.context_manager = await ContextManager.create(
- autopilot.continue_sdk.config.context_providers + [
- HighlightedCodeContextProvider(ide=ide),
- FileContextProvider(workspace_dir=ide.workspace_directory)
+ self.context_manager = await ContextManager.create(
+ self.continue_sdk.config.context_providers + [
+ HighlightedCodeContextProvider(ide=self.ide),
+ FileContextProvider(workspace_dir=self.ide.workspace_directory)
])
- await autopilot.context_manager.load_index(ide.workspace_directory)
- return autopilot
+ await self.context_manager.load_index(self.ide.workspace_directory)
class Config:
arbitrary_types_allowed = True
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 4fcab588..84b6b10b 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -5,8 +5,10 @@ from .context import ContextProvider
from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
from .models import Models
from pydantic import BaseModel, validator
-from typing import List, Literal, Optional, Dict, Type, Union
-import yaml
+from typing import List, Literal, Optional, Dict, Type
+
+from .main import Policy, Step
+from .context import ContextProvider
class SlashCommand(BaseModel):
@@ -46,6 +48,7 @@ class ContinueConfig(BaseModel):
slash_commands: Optional[List[SlashCommand]] = []
on_traceback: Optional[List[OnTracebackSteps]] = []
system_message: Optional[str] = None
+ policy_override: Optional[Policy] = None
context_providers: List[ContextProvider] = []
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
deleted file mode 100644
index 7c2a8ce0..00000000
--- a/continuedev/src/continuedev/core/policy.py
+++ /dev/null
@@ -1,78 +0,0 @@
-from textwrap import dedent
-from typing import Union
-
-from ..plugins.steps.chat import SimpleChatStep
-from ..plugins.steps.welcome import WelcomeStep
-from .config import ContinueConfig
-from ..plugins.steps.steps_on_startup import StepsOnStartupStep
-from .main import Step, History, Policy
-from .observation import UserInputObservation
-from ..plugins.steps.core.core import MessageStep
-from ..plugins.steps.custom_command import CustomCommandStep
-from ..plugins.steps.main import EditHighlightedCodeStep
-
-
-def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]:
- """
- Parses a slash command, returning the command name and the rest of the input.
- """
- if inp.startswith("/"):
- command_name = inp.split(" ")[0]
- after_command = " ".join(inp.split(" ")[1:])
-
- for slash_command in config.slash_commands:
- if slash_command.name == command_name[1:]:
- params = slash_command.params
- params["user_input"] = after_command
- try:
- return slash_command.step(**params)
- except TypeError as e:
- raise Exception(
- f"Incorrect params used for slash command '{command_name}': {e}")
- return None
-
-
-def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]:
- command_name = inp.split(" ")[0]
- after_command = " ".join(inp.split(" ")[1:])
- for custom_cmd in config.custom_commands:
- if custom_cmd.name == command_name[1:]:
- slash_command = parse_slash_command(custom_cmd.prompt, config)
- if slash_command is not None:
- return slash_command
- return CustomCommandStep(name=custom_cmd.name, description=custom_cmd.description, prompt=custom_cmd.prompt, user_input=after_command, slash_command=command_name)
- return None
-
-
-class DefaultPolicy(Policy):
- def next(self, config: ContinueConfig, history: History) -> Step:
- # At the very start, run initial Steps spcecified in the config
- if history.get_current() is None:
- return (
- MessageStep(name="Welcome to Continue", message=dedent("""\
- - Highlight code section and ask a question or give instructions
- - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue
- - Use `/help` to ask questions about how to use Continue""")) >>
- WelcomeStep() >>
- # CreateCodebaseIndexChroma() >>
- StepsOnStartupStep())
-
- observation = history.get_current().observation
- if observation is not None and isinstance(observation, UserInputObservation):
- # This could be defined with ObservationTypePolicy. Ergonomics not right though.
- user_input = observation.user_input
-
- slash_command = parse_slash_command(user_input, config)
- if slash_command is not None:
- return slash_command
-
- custom_command = parse_custom_command(user_input, config)
- if custom_command is not None:
- return custom_command
-
- if user_input.startswith("/edit"):
- return EditHighlightedCodeStep(user_input=user_input[5:])
-
- return SimpleChatStep()
-
- return None
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index be7008c0..bf22d696 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -53,7 +53,7 @@ class ContinueSDK(AbstractContinueSDK):
formatted_err = '\n'.join(traceback.format_exception(e))
msg_step = MessageStep(
name="Invalid Continue Config File", message=formatted_err)
- msg_step.description = f"Falling back to default config settings.\n```\n{formatted_err}\n```"
+ msg_step.description = f"Falling back to default config settings.\n```\n{formatted_err}\n```\n\nIt's possible this error was caused by an update to the Continue config format. If you'd like to see the new recommended default `config.py`, check [here](https://github.com/continuedev/continue/blob/main/continuedev/src/continuedev/libs/constants/default_config.py.txt)."
sdk.history.add_node(HistoryNode(
step=msg_step,
observation=None,