summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/core
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-08-01 23:16:13 -0700
committerNate Sesti <sestinj@gmail.com>2023-08-01 23:16:13 -0700
commitdc2b90d848f5fc53a18ad481ba196ac9708de8ec (patch)
treeb8da4f4e7d31e362b39c6f9db38dec3a70afb235 /continuedev/src/continuedev/core
parentf02b5c10876e3c0ee40d3c095564f675c9296bdf (diff)
parent01ed2c7eb2d3417b2c190eea105008372f49a7c6 (diff)
downloadsncontinue-dc2b90d848f5fc53a18ad481ba196ac9708de8ec.tar.gz
sncontinue-dc2b90d848f5fc53a18ad481ba196ac9708de8ec.tar.bz2
sncontinue-dc2b90d848f5fc53a18ad481ba196ac9708de8ec.zip
Merge branch 'main' into package-python
Diffstat (limited to 'continuedev/src/continuedev/core')
-rw-r--r--continuedev/src/continuedev/core/autopilot.py17
-rw-r--r--continuedev/src/continuedev/core/context.py17
2 files changed, 20 insertions, 14 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index d92c51cd..a3dd854e 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -1,7 +1,7 @@
from functools import cached_property
import traceback
import time
-from typing import Any, Callable, Coroutine, Dict, List, Union
+from typing import Callable, Coroutine, Dict, List, Union
from aiohttp import ClientPayloadError
from pydantic import root_validator
@@ -54,7 +54,7 @@ class Autopilot(ContinueBaseModel):
history: History = History.from_empty()
context: Context = Context()
full_state: Union[FullState, None] = None
- context_manager: Union[ContextManager, None] = None
+ context_manager: ContextManager = ContextManager()
continue_sdk: ContinueSDK = None
_on_update_callbacks: List[Callable[[FullState], None]] = []
@@ -66,19 +66,22 @@ class Autopilot(ContinueBaseModel):
_user_input_queue = AsyncSubscriptionQueue()
_retry_queue = AsyncSubscriptionQueue()
+ started: bool = False
+
async def start(self):
self.continue_sdk = await ContinueSDK.create(self)
if override_policy := self.continue_sdk.config.policy_override:
self.policy = override_policy
# Load documents into the search index
- self.context_manager = await ContextManager.create(
+ await self.context_manager.start(
self.continue_sdk.config.context_providers + [
HighlightedCodeContextProvider(ide=self.ide),
FileContextProvider(workspace_dir=self.ide.workspace_directory)
])
await self.context_manager.load_index(self.ide.workspace_directory)
+ self.started = True
class Config:
arbitrary_types_allowed = True
@@ -98,7 +101,7 @@ class Autopilot(ContinueBaseModel):
user_input_queue=self._main_user_input_queue,
slash_commands=self.get_available_slash_commands(),
adding_highlighted_code=self.context_manager.context_providers[
- "code"].adding_highlighted_code if self.context_manager is not None else False,
+ "code"].adding_highlighted_code if "code" in self.context_manager.context_providers else False,
selected_context_items=await self.context_manager.get_selected_items() if self.context_manager is not None else [],
)
self.full_state = full_state
@@ -201,7 +204,7 @@ class Autopilot(ContinueBaseModel):
await self.update_subscribers()
async def set_editing_at_ids(self, ids: List[str]):
- self.context_manager.context_providers["code"].set_editing_at_ids(ids)
+ await self.context_manager.context_providers["code"].set_editing_at_ids(ids)
await self.update_subscribers()
async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]:
@@ -244,7 +247,7 @@ class Autopilot(ContinueBaseModel):
try:
observation = await step(self.continue_sdk)
except Exception as e:
- if self.history.timeline[index_of_history_node].deleted:
+ if index_of_history_node >= len(self.history.timeline) or self.history.timeline[index_of_history_node].deleted:
# If step was deleted/cancelled, don't show error or allow retry
return None
@@ -301,7 +304,7 @@ class Autopilot(ContinueBaseModel):
self._step_depth -= 1
# Add observation to history, unless already attached error observation
- if not caught_error:
+ if not caught_error and index_of_history_node < len(self.history.timeline):
self.history.timeline[index_of_history_node].observation = observation
self.history.timeline[index_of_history_node].active = False
await self.update_subscribers()
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index e968c35c..3f5f6fd3 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -133,14 +133,19 @@ class ContextManager:
"""
return sum([await provider.get_chat_messages() for provider in self.context_providers.values()], [])
- def __init__(self, context_providers: List[ContextProvider]):
+ def __init__(self):
+ self.context_providers = {}
+ self.provider_titles = set()
+
+ async def start(self, context_providers: List[ContextProvider]):
+ """
+ Starts the context manager.
+ """
self.context_providers = {
prov.title: prov for prov in context_providers}
self.provider_titles = {
provider.title for provider in context_providers}
- @classmethod
- async def create(cls, context_providers: List[ContextProvider]):
async with Client('http://localhost:7700') as search_client:
meilisearch_running = True
try:
@@ -154,10 +159,8 @@ class ContextManager:
if not meilisearch_running:
logger.warning(
"MeiliSearch not running, avoiding any dependent context providers")
- context_providers = list(
- filter(lambda cp: cp.title == "code", context_providers))
-
- return cls(context_providers)
+ self.context_providers = list(
+ filter(lambda cp: cp.title == "code", self.context_providers))
async def load_index(self, workspace_dir: str):
for _, provider in self.context_providers.items():