diff options
author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-08-22 13:12:40 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-22 13:12:40 -0700 |
commit | b6435e1e479edb1e4f049098dc8522e944317f2a (patch) | |
tree | 29935c22e503ed1a64ac6db1d899dff75915c6ed /continuedev | |
parent | 7ed6b61a5b629b5e16fbd5c90ff0ad78300a77c2 (diff) | |
download | sncontinue-b6435e1e479edb1e4f049098dc8522e944317f2a.tar.gz sncontinue-b6435e1e479edb1e4f049098dc8522e944317f2a.tar.bz2 sncontinue-b6435e1e479edb1e4f049098dc8522e944317f2a.zip |
feat: :children_crossing: keep file context up to date by listening for filesystem events (#396)
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 2 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/context.py | 69 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/context_providers/file.py | 96 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide.py | 53 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide_protocol.py | 34 |
5 files changed, 222 insertions, 32 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 05a8a8f2..7b0661a5 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -108,7 +108,7 @@ class Autopilot(ContinueBaseModel): HighlightedCodeContextProvider(ide=self.ide), FileContextProvider(workspace_dir=self.ide.workspace_directory), ], - self.ide.workspace_directory, + self.continue_sdk, ) if full_state is not None: diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index 7172883f..4a5e64cc 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -1,7 +1,7 @@ import asyncio import time from abc import abstractmethod -from typing import Dict, List +from typing import Awaitable, Callable, Dict, List from meilisearch_python_async import Client from pydantic import BaseModel @@ -12,6 +12,13 @@ from ..libs.util.telemetry import posthog_logger from ..server.meilisearch_server import poll_meilisearch_running from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId + +class ContinueSDK(BaseModel): + """To avoid circular imports""" + + ... + + SEARCH_INDEX_NAME = "continue_context_items" @@ -24,9 +31,22 @@ class ContextProvider(BaseModel): """ title: str + sdk: ContinueSDK = None + delete_documents: Callable[[List[str]], Awaitable] = None + update_documents: Callable[[List[ContextItem], str], Awaitable] = None selected_items: List[ContextItem] = [] + async def start(self, sdk: ContinueSDK, delete_documents, update_documents): + """ + Starts the context provider. + + Default implementation sets the sdk. + """ + self.sdk = sdk + self.delete_documents = delete_documents + self.update_documents = update_documents + async def get_selected_items(self) -> List[ContextItem]: """ Returns all of the selected ContextItems. 
@@ -168,9 +188,7 @@ class ContextManager: self.context_providers = {} self.provider_titles = set() - async def start( - self, context_providers: List[ContextProvider], workspace_directory: str - ): + async def start(self, context_providers: List[ContextProvider], sdk: ContinueSDK): """ Starts the context manager. """ @@ -189,17 +207,56 @@ class ContextManager: self.context_providers = { prov.title: prov for prov in context_providers } + for provider in context_providers: + await provider.start( + sdk, + ContextManager.delete_documents, + ContextManager.update_documents, + ) + logger.debug("Loading Meilisearch index...") - await self.load_index(workspace_directory) + await self.load_index(sdk.ide.workspace_directory) logger.debug("Loaded Meilisearch index") except asyncio.TimeoutError: - logger.warning("MeiliSearch did not start within 5 seconds") + logger.warning("MeiliSearch did not start within 20 seconds") logger.warning( "MeiliSearch not running, avoiding any dependent context providers" ) create_async_task(start_meilisearch(context_providers)) + @staticmethod + async def update_documents(context_items: List[ContextItem], workspace_dir: str): + """ + Updates the documents in the search index. + """ + documents = [ + { + "id": item.description.id.to_string(), + "name": item.description.name, + "description": item.description.description, + "content": item.content, + "workspace_dir": workspace_dir, + } + for item in context_items + ] + async with Client("http://localhost:7700") as search_client: + await asyncio.wait_for( + search_client.index(SEARCH_INDEX_NAME).add_documents(documents), + timeout=5, + ) + + @staticmethod + async def delete_documents(ids): + """ + Deletes the documents in the search index. 
+ """ + async with Client("http://localhost:7700") as search_client: + await asyncio.wait_for( + search_client.index(SEARCH_INDEX_NAME).delete_documents(ids), + timeout=5, + ) + async def load_index(self, workspace_dir: str): try: async with Client("http://localhost:7700") as search_client: diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index 33e20662..3cb63ca3 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -1,3 +1,4 @@ +import asyncio import os from fnmatch import fnmatch from typing import List @@ -53,6 +54,74 @@ class FileContextProvider(ContextProvider): filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS) ) + async def start(self, *args): + await super().start(*args) + + async def on_file_saved(filepath: str, contents: str): + item = await self.get_context_item_for_filepath(filepath) + await self.update_documents([item], self.sdk.ide.workspace_directory) + + async def on_files_created(filepaths: List[str]): + items = await asyncio.gather( + *[ + self.get_context_item_for_filepath(filepath) + for filepath in filepaths + ] + ) + items = [item for item in items if item is not None] + await self.update_documents(items, self.sdk.ide.workspace_directory) + + async def on_files_deleted(filepaths: List[str]): + ids = [self.get_id_for_filepath(filepath) for filepath in filepaths] + + await self.delete_documents(ids) + + async def on_files_renamed(old_filepaths: List[str], new_filepaths: List[str]): + old_ids = [self.get_id_for_filepath(filepath) for filepath in old_filepaths] + new_docs = await asyncio.gather( + *[ + self.get_context_item_for_filepath(filepath) + for filepath in new_filepaths + ] + ) + new_docs = [doc for doc in new_docs if doc is not None] + + await self.delete_documents(old_ids) + await self.update_documents(new_docs, self.sdk.ide.workspace_directory) + + 
self.sdk.ide.subscribeToFileSaved(on_file_saved) + self.sdk.ide.subscribeToFilesCreated(on_files_created) + self.sdk.ide.subscribeToFilesDeleted(on_files_deleted) + self.sdk.ide.subscribeToFilesRenamed(on_files_renamed) + + def get_id_for_filepath(self, absolute_filepath: str) -> str: + return remove_meilisearch_disallowed_chars(absolute_filepath) + + async def get_context_item_for_filepath( + self, absolute_filepath: str + ) -> ContextItem: + content = get_file_contents(absolute_filepath) + if content is None: + return None + + relative_to_workspace = os.path.relpath( + absolute_filepath, self.sdk.ide.workspace_directory + ) + + return ContextItem( + content=content[: min(2000, len(content))], + description=ContextItemDescription( + name=os.path.basename(absolute_filepath), + # We should add the full path to the ContextItem + # It warrants a data modeling discussion and has no immediate use case + description=relative_to_workspace, + id=ContextItemId( + provider_title=self.title, + item_id=self.get_id_for_filepath(absolute_filepath), + ), + ), + ) + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: absolute_filepaths: List[str] = [] for root, dir_names, file_names in os.walk(workspace_dir): @@ -72,27 +141,8 @@ class FileContextProvider(ContextProvider): items = [] for absolute_filepath in absolute_filepaths: - content = get_file_contents(absolute_filepath) - if content is None: - continue # no pun intended - - relative_to_workspace = os.path.relpath(absolute_filepath, workspace_dir) - - items.append( - ContextItem( - content=content[: min(2000, len(content))], - description=ContextItemDescription( - name=os.path.basename(absolute_filepath), - # We should add the full path to the ContextItem - # It warrants a data modeling discussion and has no immediate use case - description=relative_to_workspace, - id=ContextItemId( - provider_title=self.title, - item_id=remove_meilisearch_disallowed_chars( - absolute_filepath - ), - ), - ), - ) - ) 
+ item = await self.get_context_item_for_filepath(absolute_filepath) + if item is not None: + items.append(item) + return items diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 5d85d57d..610a1a48 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -4,7 +4,7 @@ import json import os import traceback import uuid -from typing import Any, Coroutine, List, Type, TypeVar, Union +from typing import Any, Callable, Coroutine, List, Type, TypeVar, Union import nest_asyncio from fastapi import APIRouter, WebSocket @@ -247,6 +247,14 @@ class IdeProtocolServer(AbstractIdeProtocolServer): self.workspace_directory = data["workspaceDirectory"] elif message_type == "uniqueId": self.unique_id = data["uniqueId"] + elif message_type == "filesCreated": + self.onFilesCreated(data["filepaths"]) + elif message_type == "filesDeleted": + self.onFilesDeleted(data["filepaths"]) + elif message_type == "filesRenamed": + self.onFilesRenamed(data["old_filepaths"], data["new_filepaths"]) + elif message_type == "fileSaved": + self.onFileSaved(data["filepath"], data["contents"]) else: raise ValueError("Unknown message type", message_type) @@ -365,6 +373,49 @@ class IdeProtocolServer(AbstractIdeProtocolServer): autopilot.handle_highlighted_code(range_in_files), self.on_error ) + ## Subscriptions ## + + _files_created_callbacks = [] + _files_deleted_callbacks = [] + _files_renamed_callbacks = [] + _file_saved_callbacks = [] + + def call_callback(self, callback, *args, **kwargs): + if asyncio.iscoroutinefunction(callback): + create_async_task(callback(*args, **kwargs), self.on_error) + else: + callback(*args, **kwargs) + + def subscribeToFilesCreated(self, callback: Callable[[List[str]], None]): + self._files_created_callbacks.append(callback) + + def subscribeToFilesDeleted(self, callback: Callable[[List[str]], None]): + self._files_deleted_callbacks.append(callback) + + def 
subscribeToFilesRenamed(self, callback: Callable[[List[str], List[str]], None]): + self._files_renamed_callbacks.append(callback) + + def subscribeToFileSaved(self, callback: Callable[[str, str], None]): + self._file_saved_callbacks.append(callback) + + def onFilesCreated(self, filepaths: List[str]): + for callback in self._files_created_callbacks: + self.call_callback(callback, filepaths) + + def onFilesDeleted(self, filepaths: List[str]): + for callback in self._files_deleted_callbacks: + self.call_callback(callback, filepaths) + + def onFilesRenamed(self, old_filepaths: List[str], new_filepaths: List[str]): + for callback in self._files_renamed_callbacks: + self.call_callback(callback, old_filepaths, new_filepaths) + + def onFileSaved(self, filepath: str, contents: str): + for callback in self._file_saved_callbacks: + self.call_callback(callback, filepath, contents) + + ## END Subscriptions ## + def onMainUserInput(self, input: str): if autopilot := self.__get_autopilot(): create_async_task(autopilot.accept_user_input(input), self.on_error) diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index f63fecf8..435c82e2 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, List, Union +from typing import Any, Callable, List, Union from fastapi import WebSocket @@ -115,5 +115,37 @@ class AbstractIdeProtocolServer(ABC): async def showDiff(self, filepath: str, replacement: str, step_index: int): """Show a diff""" + @abstractmethod + def subscribeToFilesCreated(self, callback: Callable[[List[str]], None]): + """Subscribe to files created event""" + + @abstractmethod + def subscribeToFilesDeleted(self, callback: Callable[[List[str]], None]): + """Subscribe to files deleted event""" + + @abstractmethod + def subscribeToFilesRenamed(self, callback: 
Callable[[List[str], List[str]], None]): + """Subscribe to files renamed event""" + + @abstractmethod + def subscribeToFileSaved(self, callback: Callable[[str, str], None]): + """Subscribe to file saved event""" + + @abstractmethod + def onFilesCreated(self, filepaths: List[str]): + """Called when files are created""" + + @abstractmethod + def onFilesDeleted(self, filepaths: List[str]): + """Called when files are deleted""" + + @abstractmethod + def onFilesRenamed(self, old_filepaths: List[str], new_filepaths: List[str]): + """Called when files are renamed""" + + @abstractmethod + def onFileSaved(self, filepath: str, contents: str): + """Called when a file is saved""" + workspace_directory: str unique_id: str |