summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/core/autopilot.py6
-rw-r--r--continuedev/src/continuedev/core/context.py17
2 files changed, 13 insertions, 10 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index d92c51cd..99b9185f 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -54,7 +54,7 @@ class Autopilot(ContinueBaseModel):
history: History = History.from_empty()
context: Context = Context()
full_state: Union[FullState, None] = None
- context_manager: Union[ContextManager, None] = None
+ context_manager: ContextManager = ContextManager()
continue_sdk: ContinueSDK = None
_on_update_callbacks: List[Callable[[FullState], None]] = []
@@ -72,7 +72,7 @@ class Autopilot(ContinueBaseModel):
self.policy = override_policy
# Load documents into the search index
- self.context_manager = await ContextManager.create(
+ await self.context_manager.start(
self.continue_sdk.config.context_providers + [
HighlightedCodeContextProvider(ide=self.ide),
FileContextProvider(workspace_dir=self.ide.workspace_directory)
@@ -98,7 +98,7 @@ class Autopilot(ContinueBaseModel):
user_input_queue=self._main_user_input_queue,
slash_commands=self.get_available_slash_commands(),
adding_highlighted_code=self.context_manager.context_providers[
- "code"].adding_highlighted_code if self.context_manager is not None else False,
+ "code"].adding_highlighted_code if "code" in self.context_manager.context_providers else False,
selected_context_items=await self.context_manager.get_selected_items() if self.context_manager is not None else [],
)
self.full_state = full_state
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index e968c35c..3f5f6fd3 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -133,14 +133,19 @@ class ContextManager:
"""
return sum([await provider.get_chat_messages() for provider in self.context_providers.values()], [])
- def __init__(self, context_providers: List[ContextProvider]):
+ def __init__(self):
+ self.context_providers = {}
+ self.provider_titles = set()
+
+ async def start(self, context_providers: List[ContextProvider]):
+ """
+ Starts the context manager.
+ """
self.context_providers = {
prov.title: prov for prov in context_providers}
self.provider_titles = {
provider.title for provider in context_providers}
- @classmethod
- async def create(cls, context_providers: List[ContextProvider]):
async with Client('http://localhost:7700') as search_client:
meilisearch_running = True
try:
@@ -154,10 +159,8 @@ class ContextManager:
if not meilisearch_running:
logger.warning(
"MeiliSearch not running, avoiding any dependent context providers")
- context_providers = list(
- filter(lambda cp: cp.title == "code", context_providers))
-
- return cls(context_providers)
+ self.context_providers = list(
+ filter(lambda cp: cp.title == "code", self.context_providers))
async def load_index(self, workspace_dir: str):
for _, provider in self.context_providers.items():