summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--continuedev/src/continuedev/core/autopilot.py2
-rw-r--r--continuedev/src/continuedev/core/context.py69
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/file.py96
-rw-r--r--continuedev/src/continuedev/server/ide.py53
-rw-r--r--continuedev/src/continuedev/server/ide_protocol.py34
-rw-r--r--extension/src/continueIdeClient.ts29
6 files changed, 251 insertions, 32 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 05a8a8f2..7b0661a5 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -108,7 +108,7 @@ class Autopilot(ContinueBaseModel):
HighlightedCodeContextProvider(ide=self.ide),
FileContextProvider(workspace_dir=self.ide.workspace_directory),
],
- self.ide.workspace_directory,
+ self.continue_sdk,
)
if full_state is not None:
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index 7172883f..4a5e64cc 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -1,7 +1,7 @@
import asyncio
import time
from abc import abstractmethod
-from typing import Dict, List
+from typing import Awaitable, Callable, Dict, List
from meilisearch_python_async import Client
from pydantic import BaseModel
@@ -12,6 +12,13 @@ from ..libs.util.telemetry import posthog_logger
from ..server.meilisearch_server import poll_meilisearch_running
from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId
+
+class ContinueSDK(BaseModel):
+ """To avoid circular imports"""
+
+ ...
+
+
SEARCH_INDEX_NAME = "continue_context_items"
@@ -24,9 +31,22 @@ class ContextProvider(BaseModel):
"""
title: str
+ sdk: ContinueSDK = None
+ delete_documents: Callable[[List[str]], Awaitable] = None
+ update_documents: Callable[[List[ContextItem], str], Awaitable] = None
selected_items: List[ContextItem] = []
+ async def start(self, sdk: ContinueSDK, delete_documents, update_documents):
+ """
+ Starts the context provider.
+
+ Default implementation sets the sdk.
+ """
+ self.sdk = sdk
+ self.delete_documents = delete_documents
+ self.update_documents = update_documents
+
async def get_selected_items(self) -> List[ContextItem]:
"""
Returns all of the selected ContextItems.
@@ -168,9 +188,7 @@ class ContextManager:
self.context_providers = {}
self.provider_titles = set()
- async def start(
- self, context_providers: List[ContextProvider], workspace_directory: str
- ):
+ async def start(self, context_providers: List[ContextProvider], sdk: ContinueSDK):
"""
Starts the context manager.
"""
@@ -189,17 +207,56 @@ class ContextManager:
self.context_providers = {
prov.title: prov for prov in context_providers
}
+ for provider in context_providers:
+ await provider.start(
+ sdk,
+ ContextManager.delete_documents,
+ ContextManager.update_documents,
+ )
+
logger.debug("Loading Meilisearch index...")
- await self.load_index(workspace_directory)
+ await self.load_index(sdk.ide.workspace_directory)
logger.debug("Loaded Meilisearch index")
except asyncio.TimeoutError:
- logger.warning("MeiliSearch did not start within 5 seconds")
+ logger.warning("MeiliSearch did not start within 20 seconds")
logger.warning(
"MeiliSearch not running, avoiding any dependent context providers"
)
create_async_task(start_meilisearch(context_providers))
+ @staticmethod
+ async def update_documents(context_items: List[ContextItem], workspace_dir: str):
+ """
+ Updates the documents in the search index.
+ """
+ documents = [
+ {
+ "id": item.description.id.to_string(),
+ "name": item.description.name,
+ "description": item.description.description,
+ "content": item.content,
+ "workspace_dir": workspace_dir,
+ }
+ for item in context_items
+ ]
+ async with Client("http://localhost:7700") as search_client:
+ await asyncio.wait_for(
+ search_client.index(SEARCH_INDEX_NAME).add_documents(documents),
+ timeout=5,
+ )
+
+ @staticmethod
+ async def delete_documents(ids):
+ """
+ Deletes the documents in the search index.
+ """
+ async with Client("http://localhost:7700") as search_client:
+ await asyncio.wait_for(
+ search_client.index(SEARCH_INDEX_NAME).delete_documents(ids),
+ timeout=5,
+ )
+
async def load_index(self, workspace_dir: str):
try:
async with Client("http://localhost:7700") as search_client:
diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py
index 33e20662..3cb63ca3 100644
--- a/continuedev/src/continuedev/plugins/context_providers/file.py
+++ b/continuedev/src/continuedev/plugins/context_providers/file.py
@@ -1,3 +1,4 @@
+import asyncio
import os
from fnmatch import fnmatch
from typing import List
@@ -53,6 +54,74 @@ class FileContextProvider(ContextProvider):
filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS)
)
+ async def start(self, *args):
+ await super().start(*args)
+
+ async def on_file_saved(filepath: str, contents: str):
+ item = await self.get_context_item_for_filepath(filepath)
+ await self.update_documents([item], self.sdk.ide.workspace_directory)
+
+ async def on_files_created(filepaths: List[str]):
+ items = await asyncio.gather(
+ *[
+ self.get_context_item_for_filepath(filepath)
+ for filepath in filepaths
+ ]
+ )
+ items = [item for item in items if item is not None]
+ await self.update_documents(items, self.sdk.ide.workspace_directory)
+
+ async def on_files_deleted(filepaths: List[str]):
+ ids = [self.get_id_for_filepath(filepath) for filepath in filepaths]
+
+ await self.delete_documents(ids)
+
+ async def on_files_renamed(old_filepaths: List[str], new_filepaths: List[str]):
+ old_ids = [self.get_id_for_filepath(filepath) for filepath in old_filepaths]
+ new_docs = await asyncio.gather(
+ *[
+ self.get_context_item_for_filepath(filepath)
+ for filepath in new_filepaths
+ ]
+ )
+ new_docs = [doc for doc in new_docs if doc is not None]
+
+ await self.delete_documents(old_ids)
+ await self.update_documents(new_docs, self.sdk.ide.workspace_directory)
+
+ self.sdk.ide.subscribeToFileSaved(on_file_saved)
+ self.sdk.ide.subscribeToFilesCreated(on_files_created)
+ self.sdk.ide.subscribeToFilesDeleted(on_files_deleted)
+ self.sdk.ide.subscribeToFilesRenamed(on_files_renamed)
+
+ def get_id_for_filepath(self, absolute_filepath: str) -> str:
+ return remove_meilisearch_disallowed_chars(absolute_filepath)
+
+ async def get_context_item_for_filepath(
+ self, absolute_filepath: str
+ ) -> ContextItem:
+ content = get_file_contents(absolute_filepath)
+ if content is None:
+ return None
+
+ relative_to_workspace = os.path.relpath(
+ absolute_filepath, self.sdk.ide.workspace_directory
+ )
+
+ return ContextItem(
+ content=content[: min(2000, len(content))],
+ description=ContextItemDescription(
+ name=os.path.basename(absolute_filepath),
+ # We should add the full path to the ContextItem
+ # It warrants a data modeling discussion and has no immediate use case
+ description=relative_to_workspace,
+ id=ContextItemId(
+ provider_title=self.title,
+ item_id=self.get_id_for_filepath(absolute_filepath),
+ ),
+ ),
+ )
+
async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
absolute_filepaths: List[str] = []
for root, dir_names, file_names in os.walk(workspace_dir):
@@ -72,27 +141,8 @@ class FileContextProvider(ContextProvider):
items = []
for absolute_filepath in absolute_filepaths:
- content = get_file_contents(absolute_filepath)
- if content is None:
- continue # no pun intended
-
- relative_to_workspace = os.path.relpath(absolute_filepath, workspace_dir)
-
- items.append(
- ContextItem(
- content=content[: min(2000, len(content))],
- description=ContextItemDescription(
- name=os.path.basename(absolute_filepath),
- # We should add the full path to the ContextItem
- # It warrants a data modeling discussion and has no immediate use case
- description=relative_to_workspace,
- id=ContextItemId(
- provider_title=self.title,
- item_id=remove_meilisearch_disallowed_chars(
- absolute_filepath
- ),
- ),
- ),
- )
- )
+ item = await self.get_context_item_for_filepath(absolute_filepath)
+ if item is not None:
+ items.append(item)
+
return items
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index 5d85d57d..610a1a48 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -4,7 +4,7 @@ import json
import os
import traceback
import uuid
-from typing import Any, Coroutine, List, Type, TypeVar, Union
+from typing import Any, Callable, Coroutine, List, Type, TypeVar, Union
import nest_asyncio
from fastapi import APIRouter, WebSocket
@@ -247,6 +247,14 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
self.workspace_directory = data["workspaceDirectory"]
elif message_type == "uniqueId":
self.unique_id = data["uniqueId"]
+ elif message_type == "filesCreated":
+ self.onFilesCreated(data["filepaths"])
+ elif message_type == "filesDeleted":
+ self.onFilesDeleted(data["filepaths"])
+ elif message_type == "filesRenamed":
+ self.onFilesRenamed(data["old_filepaths"], data["new_filepaths"])
+ elif message_type == "fileSaved":
+ self.onFileSaved(data["filepath"], data["contents"])
else:
raise ValueError("Unknown message type", message_type)
@@ -365,6 +373,49 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
autopilot.handle_highlighted_code(range_in_files), self.on_error
)
+ ## Subscriptions ##
+
+ _files_created_callbacks = []
+ _files_deleted_callbacks = []
+ _files_renamed_callbacks = []
+ _file_saved_callbacks = []
+
+ def call_callback(self, callback, *args, **kwargs):
+ if asyncio.iscoroutinefunction(callback):
+ create_async_task(callback(*args, **kwargs), self.on_error)
+ else:
+ callback(*args, **kwargs)
+
+ def subscribeToFilesCreated(self, callback: Callable[[List[str]], None]):
+ self._files_created_callbacks.append(callback)
+
+ def subscribeToFilesDeleted(self, callback: Callable[[List[str]], None]):
+ self._files_deleted_callbacks.append(callback)
+
+ def subscribeToFilesRenamed(self, callback: Callable[[List[str], List[str]], None]):
+ self._files_renamed_callbacks.append(callback)
+
+ def subscribeToFileSaved(self, callback: Callable[[str, str], None]):
+ self._file_saved_callbacks.append(callback)
+
+ def onFilesCreated(self, filepaths: List[str]):
+ for callback in self._files_created_callbacks:
+ self.call_callback(callback, filepaths)
+
+ def onFilesDeleted(self, filepaths: List[str]):
+ for callback in self._files_deleted_callbacks:
+ self.call_callback(callback, filepaths)
+
+ def onFilesRenamed(self, old_filepaths: List[str], new_filepaths: List[str]):
+ for callback in self._files_renamed_callbacks:
+ self.call_callback(callback, old_filepaths, new_filepaths)
+
+ def onFileSaved(self, filepath: str, contents: str):
+ for callback in self._file_saved_callbacks:
+ self.call_callback(callback, filepath, contents)
+
+ ## END Subscriptions ##
+
def onMainUserInput(self, input: str):
if autopilot := self.__get_autopilot():
create_async_task(autopilot.accept_user_input(input), self.on_error)
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
index f63fecf8..435c82e2 100644
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ b/continuedev/src/continuedev/server/ide_protocol.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Any, List, Union
+from typing import Any, Callable, List, Union
from fastapi import WebSocket
@@ -115,5 +115,37 @@ class AbstractIdeProtocolServer(ABC):
async def showDiff(self, filepath: str, replacement: str, step_index: int):
"""Show a diff"""
+ @abstractmethod
+ def subscribeToFilesCreated(self, callback: Callable[[List[str]], None]):
+ """Subscribe to files created event"""
+
+ @abstractmethod
+ def subscribeToFilesDeleted(self, callback: Callable[[List[str]], None]):
+ """Subscribe to files deleted event"""
+
+ @abstractmethod
+ def subscribeToFilesRenamed(self, callback: Callable[[List[str], List[str]], None]):
+ """Subscribe to files renamed event"""
+
+ @abstractmethod
+ def subscribeToFileSaved(self, callback: Callable[[str, str], None]):
+ """Subscribe to file saved event"""
+
+ @abstractmethod
+ def onFilesCreated(self, filepaths: List[str]):
+ """Called when files are created"""
+
+ @abstractmethod
+ def onFilesDeleted(self, filepaths: List[str]):
+ """Called when files are deleted"""
+
+ @abstractmethod
+ def onFilesRenamed(self, old_filepaths: List[str], new_filepaths: List[str]):
+ """Called when files are renamed"""
+
+ @abstractmethod
+ def onFileSaved(self, filepath: str, contents: str):
+ """Called when a file is saved"""
+
workspace_directory: str
unique_id: str
diff --git a/extension/src/continueIdeClient.ts b/extension/src/continueIdeClient.ts
index 666b8ba0..430bb9dd 100644
--- a/extension/src/continueIdeClient.ts
+++ b/extension/src/continueIdeClient.ts
@@ -115,6 +115,35 @@ class IdeProtocolClient {
// }
// });
+ // Listen for new file creation
+ vscode.workspace.onDidCreateFiles((event) => {
+ const filepaths = event.files.map((file) => file.fsPath);
+ this.messenger?.send("filesCreated", { filepaths });
+ });
+
+ // Listen for file deletion
+ vscode.workspace.onDidDeleteFiles((event) => {
+ const filepaths = event.files.map((file) => file.fsPath);
+ this.messenger?.send("filesDeleted", { filepaths });
+ });
+
+ // Listen for file renaming
+ vscode.workspace.onDidRenameFiles((event) => {
+ const oldFilepaths = event.files.map((file) => file.oldUri.fsPath);
+ const newFilepaths = event.files.map((file) => file.newUri.fsPath);
+ this.messenger?.send("filesRenamed", {
+ old_filepaths: oldFilepaths,
+ new_filepaths: newFilepaths,
+ });
+ });
+
+ // Listen for file saving
+ vscode.workspace.onDidSaveTextDocument((event) => {
+ const filepath = event.uri.fsPath;
+ const contents = event.getText();
+ this.messenger?.send("fileSaved", { filepath, contents });
+ });
+
// Setup listeners for any selection changes in open editors
vscode.window.onDidChangeTextEditorSelection((event) => {
if (!this.editorIsCode(event.textEditor)) {