diff options
Diffstat (limited to 'continuedev')
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 2 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/context.py | 69 | ||||
| -rw-r--r-- | continuedev/src/continuedev/plugins/context_providers/file.py | 96 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/ide.py | 53 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/ide_protocol.py | 34 | 
5 files changed, 222 insertions, 32 deletions
| diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 05a8a8f2..7b0661a5 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -108,7 +108,7 @@ class Autopilot(ContinueBaseModel):                  HighlightedCodeContextProvider(ide=self.ide),                  FileContextProvider(workspace_dir=self.ide.workspace_directory),              ], -            self.ide.workspace_directory, +            self.continue_sdk,          )          if full_state is not None: diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index 7172883f..4a5e64cc 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -1,7 +1,7 @@  import asyncio  import time  from abc import abstractmethod -from typing import Dict, List +from typing import Awaitable, Callable, Dict, List  from meilisearch_python_async import Client  from pydantic import BaseModel @@ -12,6 +12,13 @@ from ..libs.util.telemetry import posthog_logger  from ..server.meilisearch_server import poll_meilisearch_running  from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId + +class ContinueSDK(BaseModel): +    """To avoid circular imports""" + +    ... + +  SEARCH_INDEX_NAME = "continue_context_items" @@ -24,9 +31,22 @@ class ContextProvider(BaseModel):      """      title: str +    sdk: ContinueSDK = None +    delete_documents: Callable[[List[str]], Awaitable] = None +    update_documents: Callable[[List[ContextItem], str], Awaitable] = None      selected_items: List[ContextItem] = [] +    async def start(self, sdk: ContinueSDK, delete_documents, update_documents): +        """ +        Starts the context provider. + +        Default implementation sets the sdk. +        """ +        self.sdk = sdk +        self.delete_documents = delete_documents +        self.update_documents = update_documents +      async def get_selected_items(self) -> List[ContextItem]:          """          Returns all of the selected ContextItems. @@ -168,9 +188,7 @@ class ContextManager:          self.context_providers = {}          self.provider_titles = set() -    async def start( -        self, context_providers: List[ContextProvider], workspace_directory: str -    ): +    async def start(self, context_providers: List[ContextProvider], sdk: ContinueSDK):          """          Starts the context manager.          """ @@ -189,17 +207,56 @@ class ContextManager:                  self.context_providers = {                      prov.title: prov for prov in context_providers                  } +                for provider in context_providers: +                    await provider.start( +                        sdk, +                        ContextManager.delete_documents, +                        ContextManager.update_documents, +                    ) +                  logger.debug("Loading Meilisearch index...") -                await self.load_index(workspace_directory) +                await self.load_index(sdk.ide.workspace_directory)                  logger.debug("Loaded Meilisearch index")              except asyncio.TimeoutError: -                logger.warning("MeiliSearch did not start within 5 seconds") +                logger.warning("MeiliSearch did not start within 20 seconds")                  logger.warning(                      "MeiliSearch not running, avoiding any dependent context providers"                  )          create_async_task(start_meilisearch(context_providers)) +    @staticmethod +    async def update_documents(context_items: List[ContextItem], workspace_dir: str): +        """ +        Updates the documents in the search index. +        """ +        documents = [ +            { +                "id": item.description.id.to_string(), +                "name": item.description.name, +                "description": item.description.description, +                "content": item.content, +                "workspace_dir": workspace_dir, +            } +            for item in context_items +        ] +        async with Client("http://localhost:7700") as search_client: +            await asyncio.wait_for( +                search_client.index(SEARCH_INDEX_NAME).add_documents(documents), +                timeout=5, +            ) + +    @staticmethod +    async def delete_documents(ids): +        """ +        Deletes the documents in the search index. +        """ +        async with Client("http://localhost:7700") as search_client: +            await asyncio.wait_for( +                search_client.index(SEARCH_INDEX_NAME).delete_documents(ids), +                timeout=5, +            ) +      async def load_index(self, workspace_dir: str):          try:              async with Client("http://localhost:7700") as search_client: diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index 33e20662..3cb63ca3 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -1,3 +1,4 @@ +import asyncio  import os  from fnmatch import fnmatch  from typing import List @@ -53,6 +54,74 @@ class FileContextProvider(ContextProvider):          filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS)      ) +    async def start(self, *args): +        await super().start(*args) + +        async def on_file_saved(filepath: str, contents: str): +            item = await self.get_context_item_for_filepath(filepath) +            await self.update_documents([item], self.sdk.ide.workspace_directory) + +        async def on_files_created(filepaths: List[str]): +            items = await asyncio.gather( +                *[ +                    self.get_context_item_for_filepath(filepath) +                    for filepath in filepaths +                ] +            ) +            items = [item for item in items if item is not None] +            await self.update_documents(items, self.sdk.ide.workspace_directory) + +        async def on_files_deleted(filepaths: List[str]): +            ids = [self.get_id_for_filepath(filepath) for filepath in filepaths] + +            await self.delete_documents(ids) + +        async def on_files_renamed(old_filepaths: List[str], new_filepaths: List[str]): +            old_ids = [self.get_id_for_filepath(filepath) for filepath in old_filepaths] +            new_docs = await asyncio.gather( +                *[ +                    self.get_context_item_for_filepath(filepath) +                    for filepath in new_filepaths +                ] +            ) +            new_docs = [doc for doc in new_docs if doc is not None] + +            await self.delete_documents(old_ids) +            await self.update_documents(new_docs, self.sdk.ide.workspace_directory) + +        self.sdk.ide.subscribeToFileSaved(on_file_saved) +        self.sdk.ide.subscribeToFilesCreated(on_files_created) +        self.sdk.ide.subscribeToFilesDeleted(on_files_deleted) +        self.sdk.ide.subscribeToFilesRenamed(on_files_renamed) + +    def get_id_for_filepath(self, absolute_filepath: str) -> str: +        return remove_meilisearch_disallowed_chars(absolute_filepath) + +    async def get_context_item_for_filepath( +        self, absolute_filepath: str +    ) -> ContextItem: +        content = get_file_contents(absolute_filepath) +        if content is None: +            return None + +        relative_to_workspace = os.path.relpath( +            absolute_filepath, self.sdk.ide.workspace_directory +        ) + +        return ContextItem( +            content=content[: min(2000, len(content))], +            description=ContextItemDescription( +                name=os.path.basename(absolute_filepath), +                # We should add the full path to the ContextItem +                # It warrants a data modeling discussion and has no immediate use case +                description=relative_to_workspace, +                id=ContextItemId( +                    provider_title=self.title, +                    item_id=self.get_id_for_filepath(absolute_filepath), +                ), +            ), +        ) +      async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:          absolute_filepaths: List[str] = []          for root, dir_names, file_names in os.walk(workspace_dir): @@ -72,27 +141,8 @@ class FileContextProvider(ContextProvider):          items = []          for absolute_filepath in absolute_filepaths: -            content = get_file_contents(absolute_filepath) -            if content is None: -                continue  # no pun intended - -            relative_to_workspace = os.path.relpath(absolute_filepath, workspace_dir) - -            items.append( -                ContextItem( -                    content=content[: min(2000, len(content))], -                    description=ContextItemDescription( -                        name=os.path.basename(absolute_filepath), -                        # We should add the full path to the ContextItem -                        # It warrants a data modeling discussion and has no immediate use case -                        description=relative_to_workspace, -                        id=ContextItemId( -                            provider_title=self.title, -                            item_id=remove_meilisearch_disallowed_chars( -                                absolute_filepath -                            ), -                        ), -                    ), -                ) -            ) +            item = await self.get_context_item_for_filepath(absolute_filepath) +            if item is not None: +                items.append(item) +          return items diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 5d85d57d..610a1a48 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -4,7 +4,7 @@ import json  import os  import traceback  import uuid -from typing import Any, Coroutine, List, Type, TypeVar, Union +from typing import Any, Callable, Coroutine, List, Type, TypeVar, Union  import nest_asyncio  from fastapi import APIRouter, WebSocket @@ -247,6 +247,14 @@ class IdeProtocolServer(AbstractIdeProtocolServer):              self.workspace_directory = data["workspaceDirectory"]          elif message_type == "uniqueId":              self.unique_id = data["uniqueId"] +        elif message_type == "filesCreated": +            self.onFilesCreated(data["filepaths"]) +        elif message_type == "filesDeleted": +            self.onFilesDeleted(data["filepaths"]) +        elif message_type == "filesRenamed": +            self.onFilesRenamed(data["old_filepaths"], data["new_filepaths"]) +        elif message_type == "fileSaved": +            self.onFileSaved(data["filepath"], data["contents"])          else:              raise ValueError("Unknown message type", message_type) @@ -365,6 +373,49 @@ class IdeProtocolServer(AbstractIdeProtocolServer):                  autopilot.handle_highlighted_code(range_in_files), self.on_error              ) +    ## Subscriptions ## + +    _files_created_callbacks = [] +    _files_deleted_callbacks = [] +    _files_renamed_callbacks = [] +    _file_saved_callbacks = [] + +    def call_callback(self, callback, *args, **kwargs): +        if asyncio.iscoroutinefunction(callback): +            create_async_task(callback(*args, **kwargs), self.on_error) +        else: +            callback(*args, **kwargs) + +    def subscribeToFilesCreated(self, callback: Callable[[List[str]], None]): +        self._files_created_callbacks.append(callback) + +    def subscribeToFilesDeleted(self, callback: Callable[[List[str]], None]): +        self._files_deleted_callbacks.append(callback) + +    def subscribeToFilesRenamed(self, callback: Callable[[List[str], List[str]], None]): +        self._files_renamed_callbacks.append(callback) + +    def subscribeToFileSaved(self, callback: Callable[[str, str], None]): +        self._file_saved_callbacks.append(callback) + +    def onFilesCreated(self, filepaths: List[str]): +        for callback in self._files_created_callbacks: +            self.call_callback(callback, filepaths) + +    def onFilesDeleted(self, filepaths: List[str]): +        for callback in self._files_deleted_callbacks: +            self.call_callback(callback, filepaths) + +    def onFilesRenamed(self, old_filepaths: List[str], new_filepaths: List[str]): +        for callback in self._files_renamed_callbacks: +            self.call_callback(callback, old_filepaths, new_filepaths) + +    def onFileSaved(self, filepath: str, contents: str): +        for callback in self._file_saved_callbacks: +            self.call_callback(callback, filepath, contents) + +    ## END Subscriptions ## +      def onMainUserInput(self, input: str):          if autopilot := self.__get_autopilot():              create_async_task(autopilot.accept_user_input(input), self.on_error) diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index f63fecf8..435c82e2 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -1,5 +1,5 @@  from abc import ABC, abstractmethod -from typing import Any, List, Union +from typing import Any, Callable, List, Union  from fastapi import WebSocket @@ -115,5 +115,37 @@ class AbstractIdeProtocolServer(ABC):      async def showDiff(self, filepath: str, replacement: str, step_index: int):          """Show a diff""" +    @abstractmethod +    def subscribeToFilesCreated(self, callback: Callable[[List[str]], None]): +        """Subscribe to files created event""" + +    @abstractmethod +    def subscribeToFilesDeleted(self, callback: Callable[[List[str]], None]): +        """Subscribe to files deleted event""" + +    @abstractmethod +    def subscribeToFilesRenamed(self, callback: Callable[[List[str], List[str]], None]): +        """Subscribe to files renamed event""" + +    @abstractmethod +    def subscribeToFileSaved(self, callback: Callable[[str, str], None]): +        """Subscribe to file saved event""" + +    @abstractmethod +    def onFilesCreated(self, filepaths: List[str]): +        """Called when files are created""" + +    @abstractmethod +    def onFilesDeleted(self, filepaths: List[str]): +        """Called when files are deleted""" + +    @abstractmethod +    def onFilesRenamed(self, old_filepaths: List[str], new_filepaths: List[str]): +        """Called when files are renamed""" + +    @abstractmethod +    def onFileSaved(self, filepath: str, contents: str): +        """Called when a file is saved""" +      workspace_directory: str      unique_id: str | 
