summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/core
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-29 22:40:04 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-29 22:40:04 -0700
commitc6a12550ffca1ffe35630e7aa9af6913ddbe0675 (patch)
tree6450574104c42c76ef168c54a92af9ecbb7337c5 /continuedev/src/continuedev/core
parent17566c66e0a01ad3c38ece974e44c1c71a9188de (diff)
downloadsncontinue-c6a12550ffca1ffe35630e7aa9af6913ddbe0675.tar.gz
sncontinue-c6a12550ffca1ffe35630e7aa9af6913ddbe0675.tar.bz2
sncontinue-c6a12550ffca1ffe35630e7aa9af6913ddbe0675.zip
feat: :sparkles: EmbeddingContextProvider
Diffstat (limited to 'continuedev/src/continuedev/core')
-rw-r--r--continuedev/src/continuedev/core/autopilot.py24
-rw-r--r--continuedev/src/continuedev/core/config.py11
-rw-r--r--continuedev/src/continuedev/core/policy.py80
3 files changed, 17 insertions, 98 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 3f25e64e..12339f9b 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -9,6 +9,7 @@ from ..models.filesystem import RangeInFileWithContents
from ..models.filesystem_edit import FileEditWithFullContents
from .observation import Observation, InternalErrorObservation
from .context import ContextManager
+from ..plugins.policies.default import DefaultPolicy
from ..plugins.context_providers.file import FileContextProvider
from ..plugins.context_providers.highlighted_code import HighlightedCodeContextProvider
from ..server.ide_protocol import AbstractIdeProtocolServer
@@ -47,8 +48,9 @@ def get_error_title(e: Exception) -> str:
class Autopilot(ContinueBaseModel):
- policy: Policy
ide: AbstractIdeProtocolServer
+
+ policy: Policy = DefaultPolicy()
history: History = History.from_empty()
context: Context = Context()
full_state: Union[FullState, None] = None
@@ -64,20 +66,18 @@ class Autopilot(ContinueBaseModel):
_user_input_queue = AsyncSubscriptionQueue()
_retry_queue = AsyncSubscriptionQueue()
- @classmethod
- async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot":
- autopilot = cls(ide=ide, policy=policy)
- autopilot.continue_sdk = await ContinueSDK.create(autopilot)
+ async def start(self):
+ self.continue_sdk = await ContinueSDK.create(self)
+ if override_policy := self.continue_sdk.config.policy_override:
+ self.policy = override_policy
# Load documents into the search index
- autopilot.context_manager = await ContextManager.create(
- autopilot.continue_sdk.config.context_providers + [
- HighlightedCodeContextProvider(ide=ide),
- FileContextProvider(workspace_dir=ide.workspace_directory)
+ self.context_manager = await ContextManager.create(
+ self.continue_sdk.config.context_providers + [
+ HighlightedCodeContextProvider(ide=self.ide),
+ FileContextProvider(workspace_dir=self.ide.workspace_directory)
])
- await autopilot.context_manager.load_index()
-
- return autopilot
+ await self.context_manager.load_index()
class Config:
arbitrary_types_allowed = True
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 9fbda824..fe0946cd 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -1,10 +1,8 @@
-import json
-import os
-from .main import Step
-from .context import ContextProvider
from pydantic import BaseModel, validator
-from typing import List, Literal, Optional, Dict, Type, Union
-import yaml
+from typing import List, Literal, Optional, Dict, Type
+
+from .main import Policy, Step
+from .context import ContextProvider
class SlashCommand(BaseModel):
@@ -51,6 +49,7 @@ class ContinueConfig(BaseModel):
on_traceback: Optional[List[OnTracebackSteps]] = []
system_message: Optional[str] = None
openai_server_info: Optional[OpenAIServerInfo] = None
+ policy_override: Optional[Policy] = None
context_providers: List[ContextProvider] = []
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
deleted file mode 100644
index d90177b5..00000000
--- a/continuedev/src/continuedev/core/policy.py
+++ /dev/null
@@ -1,80 +0,0 @@
-from textwrap import dedent
-from typing import Union
-
-from ..plugins.steps.chat import SimpleChatStep
-from ..plugins.steps.welcome import WelcomeStep
-from .config import ContinueConfig
-from ..plugins.steps.steps_on_startup import StepsOnStartupStep
-from .main import Step, History, Policy
-from .observation import UserInputObservation
-from ..plugins.steps.core.core import MessageStep
-from ..plugins.steps.custom_command import CustomCommandStep
-from ..plugins.steps.main import EditHighlightedCodeStep
-
-
-def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]:
- """
- Parses a slash command, returning the command name and the rest of the input.
- """
- if inp.startswith("/"):
- command_name = inp.split(" ")[0]
- after_command = " ".join(inp.split(" ")[1:])
-
- for slash_command in config.slash_commands:
- if slash_command.name == command_name[1:]:
- params = slash_command.params
- params["user_input"] = after_command
- try:
- return slash_command.step(**params)
- except TypeError as e:
- raise Exception(
- f"Incorrect params used for slash command '{command_name}': {e}")
- return None
-
-
-def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]:
- command_name = inp.split(" ")[0]
- after_command = " ".join(inp.split(" ")[1:])
- for custom_cmd in config.custom_commands:
- if custom_cmd.name == command_name[1:]:
- slash_command = parse_slash_command(custom_cmd.prompt, config)
- if slash_command is not None:
- return slash_command
- return CustomCommandStep(name=custom_cmd.name, description=custom_cmd.description, prompt=custom_cmd.prompt, user_input=after_command, slash_command=command_name)
- return None
-
-
-class DefaultPolicy(Policy):
- ran_code_last: bool = False
-
- def next(self, config: ContinueConfig, history: History) -> Step:
- # At the very start, run initial Steps spcecified in the config
- if history.get_current() is None:
- return (
- MessageStep(name="Welcome to Continue", message=dedent("""\
- - Highlight code section and ask a question or give instructions
- - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue
- - Use `/help` to ask questions about how to use Continue""")) >>
- WelcomeStep() >>
- # CreateCodebaseIndexChroma() >>
- StepsOnStartupStep())
-
- observation = history.get_current().observation
- if observation is not None and isinstance(observation, UserInputObservation):
- # This could be defined with ObservationTypePolicy. Ergonomics not right though.
- user_input = observation.user_input
-
- slash_command = parse_slash_command(user_input, config)
- if slash_command is not None:
- return slash_command
-
- custom_command = parse_custom_command(user_input, config)
- if custom_command is not None:
- return custom_command
-
- if user_input.startswith("/edit"):
- return EditHighlightedCodeStep(user_input=user_input[5:])
-
- return SimpleChatStep()
-
- return None