diff options
| author | Nate Sesti <sestinj@gmail.com> | 2023-08-01 23:16:13 -0700 |
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-08-01 23:16:13 -0700 |
| commit | dc2b90d848f5fc53a18ad481ba196ac9708de8ec (patch) | |
| tree | b8da4f4e7d31e362b39c6f9db38dec3a70afb235 /continuedev/src/continuedev/core | |
| parent | f02b5c10876e3c0ee40d3c095564f675c9296bdf (diff) | |
| parent | 01ed2c7eb2d3417b2c190eea105008372f49a7c6 (diff) | |
| download | sncontinue-dc2b90d848f5fc53a18ad481ba196ac9708de8ec.tar.gz sncontinue-dc2b90d848f5fc53a18ad481ba196ac9708de8ec.tar.bz2 sncontinue-dc2b90d848f5fc53a18ad481ba196ac9708de8ec.zip | |
Merge branch 'main' into package-python
Diffstat (limited to 'continuedev/src/continuedev/core')
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 17 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/context.py | 17 |
2 files changed, 20 insertions, 14 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index d92c51cd..a3dd854e 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -1,7 +1,7 @@ from functools import cached_property import traceback import time -from typing import Any, Callable, Coroutine, Dict, List, Union +from typing import Callable, Coroutine, Dict, List, Union from aiohttp import ClientPayloadError from pydantic import root_validator @@ -54,7 +54,7 @@ class Autopilot(ContinueBaseModel): history: History = History.from_empty() context: Context = Context() full_state: Union[FullState, None] = None - context_manager: Union[ContextManager, None] = None + context_manager: ContextManager = ContextManager() continue_sdk: ContinueSDK = None _on_update_callbacks: List[Callable[[FullState], None]] = [] @@ -66,19 +66,22 @@ class Autopilot(ContinueBaseModel): _user_input_queue = AsyncSubscriptionQueue() _retry_queue = AsyncSubscriptionQueue() + started: bool = False + async def start(self): self.continue_sdk = await ContinueSDK.create(self) if override_policy := self.continue_sdk.config.policy_override: self.policy = override_policy # Load documents into the search index - self.context_manager = await ContextManager.create( + await self.context_manager.start( self.continue_sdk.config.context_providers + [ HighlightedCodeContextProvider(ide=self.ide), FileContextProvider(workspace_dir=self.ide.workspace_directory) ]) await self.context_manager.load_index(self.ide.workspace_directory) + self.started = True class Config: arbitrary_types_allowed = True @@ -98,7 +101,7 @@ class Autopilot(ContinueBaseModel): user_input_queue=self._main_user_input_queue, slash_commands=self.get_available_slash_commands(), adding_highlighted_code=self.context_manager.context_providers[ - "code"].adding_highlighted_code if self.context_manager is not None else False, + "code"].adding_highlighted_code if "code" in self.context_manager.context_providers else False, selected_context_items=await self.context_manager.get_selected_items() if self.context_manager is not None else [], ) self.full_state = full_state @@ -201,7 +204,7 @@ class Autopilot(ContinueBaseModel): await self.update_subscribers() async def set_editing_at_ids(self, ids: List[str]): - self.context_manager.context_providers["code"].set_editing_at_ids(ids) + await self.context_manager.context_providers["code"].set_editing_at_ids(ids) await self.update_subscribers() async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]: @@ -244,7 +247,7 @@ class Autopilot(ContinueBaseModel): try: observation = await step(self.continue_sdk) except Exception as e: - if self.history.timeline[index_of_history_node].deleted: + if index_of_history_node >= len(self.history.timeline) or self.history.timeline[index_of_history_node].deleted: # If step was deleted/cancelled, don't show error or allow retry return None @@ -301,7 +304,7 @@ class Autopilot(ContinueBaseModel): self._step_depth -= 1 # Add observation to history, unless already attached error observation - if not caught_error: + if not caught_error and index_of_history_node < len(self.history.timeline): self.history.timeline[index_of_history_node].observation = observation self.history.timeline[index_of_history_node].active = False await self.update_subscribers() diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index e968c35c..3f5f6fd3 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -133,14 +133,19 @@ class ContextManager: """ return sum([await provider.get_chat_messages() for provider in self.context_providers.values()], []) - def __init__(self, context_providers: List[ContextProvider]): + def __init__(self): + self.context_providers = {} + self.provider_titles = set() + + async def start(self, context_providers: List[ContextProvider]): + """ + Starts the context manager. + """ self.context_providers = { prov.title: prov for prov in context_providers} self.provider_titles = { provider.title for provider in context_providers} - @classmethod - async def create(cls, context_providers: List[ContextProvider]): async with Client('http://localhost:7700') as search_client: meilisearch_running = True try: @@ -154,10 +159,8 @@ class ContextManager: if not meilisearch_running: logger.warning( "MeiliSearch not running, avoiding any dependent context providers") - context_providers = list( - filter(lambda cp: cp.title == "code", context_providers)) - - return cls(context_providers) + self.context_providers = list( + filter(lambda cp: cp.title == "code", self.context_providers)) async def load_index(self, workspace_dir: str): for _, provider in self.context_providers.items(): |
