diff options
Diffstat (limited to 'continuedev/src')
8 files changed, 51 insertions, 42 deletions
| diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index d018c29e..42a58423 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -69,7 +69,7 @@ class Autopilot(ContinueBaseModel):          autopilot.continue_sdk = await ContinueSDK.create(autopilot)          # Load documents into the search index -        autopilot.context_manager = ContextManager( +        autopilot.context_manager = await ContextManager.create(              autopilot.continue_sdk.config.context_providers + [                  HighlightedCodeContextProvider(ide=ide),                  FileContextProvider(workspace_dir=ide.workspace_directory) diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index 67bba651..7d302656 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -1,7 +1,7 @@  from abc import abstractmethod  from typing import Dict, List -import meilisearch +from meilisearch_python_async import Client  from pydantic import BaseModel @@ -50,21 +50,21 @@ class ContextProvider(BaseModel):          """          return [ChatMessage(role="user", content=f"{item.description.name}: {item.description.description}\n\n{item.content}", summary=item.description.description) for item in await self.get_selected_items()] -    async def get_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client) -> ContextItem: +    async def get_item(self, id: ContextItemId, query: str, search_client: Client) -> ContextItem:          """          Returns the ContextItem with the given id.          Default implementation uses the search index to get the item.          """ -        result = search_client.index( +        result = await search_client.index(              SEARCH_INDEX_NAME).get_document(id.to_string())          return ContextItem(              description=ContextItemDescription( -                name=result.name, -                description=result.description, +                name=result["name"], +                description=result["description"],                  id=id              ), -            content=result.content +            content=result["content"]          )      async def delete_context_with_ids(self, ids: List[ContextItemId]): @@ -85,7 +85,7 @@ class ContextProvider(BaseModel):          """          self.selected_items = [] -    async def add_context_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client): +    async def add_context_item(self, id: ContextItemId, query: str, search_client: Client):          """          Adds the given ContextItem to the list of ContextItems. @@ -126,21 +126,26 @@ class ContextManager:          """          return sum([await provider.get_chat_messages() for provider in self.context_providers.values()], []) -    search_client: meilisearch.Client - -    def __init__(self, context_providers: List[ContextProvider]): -        self.search_client = meilisearch.Client('http://localhost:7700') - -        # If meilisearch isn't running, don't use any ContextProviders that might depend on it -        if not check_meilisearch_running(): -            context_providers = list( -                filter(lambda cp: cp.title == "code", context_providers)) +    search_client: Client +    def __init__(self, context_providers: List[ContextProvider], search_client: Client): +        self.search_client = search_client          self.context_providers = {              prov.title: prov for prov in context_providers}          self.provider_titles = {              provider.title for provider in context_providers} +    @classmethod +    async def create(cls, context_providers: List[ContextProvider]): +        search_client = Client('http://localhost:7700') +        health = await search_client.health() +        if not health.status == "available": +            print("MeiliSearch not running, avoiding any dependent context providers") +            context_providers = list( +                filter(lambda cp: cp.title == "code", context_providers)) + +        return cls(context_providers, search_client) +      async def load_index(self):          for _, provider in self.context_providers.items():              context_items = await provider.provide_context_items() @@ -154,8 +159,7 @@ class ContextManager:                  for item in context_items              ]              if len(documents) > 0: -                self.search_client.index( -                    SEARCH_INDEX_NAME).add_documents(documents) +                await self.search_client.index(SEARCH_INDEX_NAME).add_documents(documents)      # def compile_chat_messages(self, max_tokens: int) -> List[Dict]:      #     """ diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index fc0af7ba..6222ec6a 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -21,7 +21,7 @@ class FileContextProvider(ContextProvider):      title = "file"      workspace_dir: str -    ignore_patterns: List[str] = list(map(lambda folder: f"**/{folder}", [ +    ignore_patterns: List[str] = [          ".git",          ".vscode",          ".idea", @@ -35,7 +35,10 @@ class FileContextProvider(ContextProvider):          "target",          "out",          "bin", -    ])) +        ".pytest_cache", +        ".vscode-test", +        ".continue", +    ]      async def provide_context_items(self) -> List[ContextItem]:          filepaths = [] diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py index 23d4fc86..426c0804 100644 --- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py +++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py @@ -1,7 +1,7 @@  import os  from typing import Any, Dict, List -import meilisearch +from meilisearch_python_async import Client  from ...core.main import ChatMessage  from ...models.filesystem import RangeInFile, RangeInFileWithContents  from ...core.context import ContextItem, ContextItemDescription, ContextItemId @@ -187,5 +187,5 @@ class HighlightedCodeContextProvider(BaseModel):          for hr in self.highlighted_ranges:              hr.item.editing = hr.item.description.id.to_string() in ids -    async def add_context_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client, prev: List[ContextItem] = None) -> List[ContextItem]: +    async def add_context_item(self, id: ContextItemId, query: str, search_client: Client, prev: List[ContextItem] = None) -> List[ContextItem]:          raise NotImplementedError() diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index fa203c28..c0957395 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -61,12 +61,12 @@ class GUIProtocolServer(AbstractGUIProtocolServer):              "data": data          }) -    async def _receive_json(self, message_type: str, timeout: int = 5) -> Any: +    async def _receive_json(self, message_type: str, timeout: int = 20) -> Any:          try:              return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)          except asyncio.TimeoutError:              raise Exception( -                "GUI Protocol _receive_json timed out after 5 seconds") +                "GUI Protocol _receive_json timed out after 20 seconds")      async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:          await self._send_json(message_type, data) diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index d6a28c92..cf8b32a1 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -10,6 +10,7 @@ from pydantic import BaseModel  import traceback  import asyncio +from .meilisearch_server import start_meilisearch  from ..libs.util.telemetry import posthog_logger  from ..libs.util.queue import AsyncSubscriptionQueue  from ..models.filesystem import FileSystem, RangeInFile, EditDiff, RangeInFileWithContents, RealFileSystem @@ -139,6 +140,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):                  continue              message_type = message["messageType"]              data = message["data"] +            print("Received message while initializing", message_type)              if message_type == "workspaceDirectory":                  self.workspace_directory = data["workspaceDirectory"]              elif message_type == "uniqueId": @@ -153,17 +155,18 @@ class IdeProtocolServer(AbstractIdeProtocolServer):      async def _send_json(self, message_type: str, data: Any):          if self.websocket.application_state == WebSocketState.DISCONNECTED:              return +        print("Sending IDE message: ", message_type)          await self.websocket.send_json({              "messageType": message_type,              "data": data          }) -    async def _receive_json(self, message_type: str, timeout: int = 5) -> Any: +    async def _receive_json(self, message_type: str, timeout: int = 20) -> Any:          try:              return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)          except asyncio.TimeoutError:              raise Exception( -                "IDE Protocol _receive_json timed out after 5 seconds") +                "IDE Protocol _receive_json timed out after 20 seconds", message_type)      async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:          await self._send_json(message_type, data) @@ -432,6 +435,13 @@ class IdeProtocolServer(AbstractIdeProtocolServer):  @router.websocket("/ws")  async def websocket_endpoint(websocket: WebSocket, session_id: str = None):      try: +        # Start meilisearch +        try: +            await start_meilisearch() +        except Exception as e: +            print("Failed to start MeiliSearch") +            print(e) +          await websocket.accept()          print("Accepted websocket connection from, ", websocket.client)          await websocket.send_json({"messageType": "connected", "data": {}}) diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index 7ee64041..0b59d4fe 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,7 +1,5 @@  import asyncio -import subprocess  import time -import meilisearch  import psutil  import os  from fastapi import FastAPI @@ -13,7 +11,6 @@ import argparse  from .ide import router as ide_router  from .gui import router as gui_router  from .session_manager import session_manager -from .meilisearch_server import start_meilisearch  app = FastAPI() @@ -87,13 +84,8 @@ if __name__ == "__main__":          # cpu_thread = threading.Thread(target=cpu_usage_loop)          # cpu_thread.start() -        try: -            start_meilisearch() -        except Exception as e: -            print("Failed to start MeiliSearch") -            print(e) -          run_server()      except Exception as e: +        print("Error starting Continue server: ", e)          cleanup()          raise e diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index 232b6243..286019e1 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -2,7 +2,7 @@ import os  import shutil  import subprocess -import meilisearch +from meilisearch_python_async import Client  from ..libs.util.paths import getServerFolderPath @@ -41,14 +41,14 @@ def ensure_meilisearch_installed():              f"curl -L https://install.meilisearch.com | sh", shell=True, check=True, cwd=serverPath) -def check_meilisearch_running() -> bool: +async def check_meilisearch_running() -> bool:      """      Checks if MeiliSearch is running.      """      try: -        client = meilisearch.Client('http://localhost:7700') -        resp = client.health() +        client = Client('http://localhost:7700') +        resp = await client.health()          if resp["status"] != "available":              return False          return True @@ -56,7 +56,7 @@ def check_meilisearch_running() -> bool:          return False -def start_meilisearch(): +async def start_meilisearch():      """      Starts the MeiliSearch server, wait for it.      """ @@ -71,7 +71,7 @@ def start_meilisearch():      ensure_meilisearch_installed()      # Check if MeiliSearch is running -    if not check_meilisearch_running(): +    if not await check_meilisearch_running():          print("Starting MeiliSearch...")          subprocess.Popen(["./meilisearch"], cwd=serverPath, stdout=subprocess.DEVNULL,                           stderr=subprocess.STDOUT, close_fds=True, start_new_session=True) | 
