diff options
author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-07-25 23:52:12 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-25 23:52:12 -0700 |
commit | 2b69bf6f1fc2e06b16b718358ceed4911d6e87c3 (patch) | |
tree | c27c630c64f2890698512606e2dd8acac9c0c8b6 /continuedev/src | |
parent | f0df0fdc1fb7d8e65e27abe633da1831b8172b35 (diff) | |
parent | 03da423e5abdf32c5bf9755ffd2e1c7be98e6454 (diff) | |
download | sncontinue-2b69bf6f1fc2e06b16b718358ceed4911d6e87c3.tar.gz sncontinue-2b69bf6f1fc2e06b16b718358ceed4911d6e87c3.tar.bz2 sncontinue-2b69bf6f1fc2e06b16b718358ceed4911d6e87c3.zip |
Merge pull request #296 from continuedev/config-py-2
Config py 2
Diffstat (limited to 'continuedev/src')
8 files changed, 51 insertions, 42 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index d018c29e..42a58423 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -69,7 +69,7 @@ class Autopilot(ContinueBaseModel): autopilot.continue_sdk = await ContinueSDK.create(autopilot) # Load documents into the search index - autopilot.context_manager = ContextManager( + autopilot.context_manager = await ContextManager.create( autopilot.continue_sdk.config.context_providers + [ HighlightedCodeContextProvider(ide=ide), FileContextProvider(workspace_dir=ide.workspace_directory) diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index 67bba651..7d302656 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -1,7 +1,7 @@ from abc import abstractmethod from typing import Dict, List -import meilisearch +from meilisearch_python_async import Client from pydantic import BaseModel @@ -50,21 +50,21 @@ class ContextProvider(BaseModel): """ return [ChatMessage(role="user", content=f"{item.description.name}: {item.description.description}\n\n{item.content}", summary=item.description.description) for item in await self.get_selected_items()] - async def get_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client) -> ContextItem: + async def get_item(self, id: ContextItemId, query: str, search_client: Client) -> ContextItem: """ Returns the ContextItem with the given id. Default implementation uses the search index to get the item. """ - result = search_client.index( + result = await search_client.index( SEARCH_INDEX_NAME).get_document(id.to_string()) return ContextItem( description=ContextItemDescription( - name=result.name, - description=result.description, + name=result["name"], + description=result["description"], id=id ), - content=result.content + content=result["content"] ) async def delete_context_with_ids(self, ids: List[ContextItemId]): @@ -85,7 +85,7 @@ class ContextProvider(BaseModel): """ self.selected_items = [] - async def add_context_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client): + async def add_context_item(self, id: ContextItemId, query: str, search_client: Client): """ Adds the given ContextItem to the list of ContextItems. @@ -126,21 +126,26 @@ class ContextManager: """ return sum([await provider.get_chat_messages() for provider in self.context_providers.values()], []) - search_client: meilisearch.Client - - def __init__(self, context_providers: List[ContextProvider]): - self.search_client = meilisearch.Client('http://localhost:7700') - - # If meilisearch isn't running, don't use any ContextProviders that might depend on it - if not check_meilisearch_running(): - context_providers = list( - filter(lambda cp: cp.title == "code", context_providers)) + search_client: Client + def __init__(self, context_providers: List[ContextProvider], search_client: Client): + self.search_client = search_client self.context_providers = { prov.title: prov for prov in context_providers} self.provider_titles = { provider.title for provider in context_providers} + @classmethod + async def create(cls, context_providers: List[ContextProvider]): + search_client = Client('http://localhost:7700') + health = await search_client.health() + if not health.status == "available": + print("MeiliSearch not running, avoiding any dependent context providers") + context_providers = list( + filter(lambda cp: cp.title == "code", context_providers)) + + return cls(context_providers, search_client) + async def load_index(self): for _, provider in self.context_providers.items(): context_items = await provider.provide_context_items() @@ -154,8 +159,7 @@ class ContextManager: for item in context_items ] if len(documents) > 0: - self.search_client.index( - SEARCH_INDEX_NAME).add_documents(documents) + await self.search_client.index(SEARCH_INDEX_NAME).add_documents(documents) # def compile_chat_messages(self, max_tokens: int) -> List[Dict]: # """ diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index fc0af7ba..6222ec6a 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -21,7 +21,7 @@ class FileContextProvider(ContextProvider): title = "file" workspace_dir: str - ignore_patterns: List[str] = list(map(lambda folder: f"**/{folder}", [ + ignore_patterns: List[str] = [ ".git", ".vscode", ".idea", @@ -35,7 +35,10 @@ class FileContextProvider(ContextProvider): "target", "out", "bin", - ])) + ".pytest_cache", + ".vscode-test", + ".continue", + ] async def provide_context_items(self) -> List[ContextItem]: filepaths = [] diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py index 23d4fc86..426c0804 100644 --- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py +++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py @@ -1,7 +1,7 @@ import os from typing import Any, Dict, List -import meilisearch +from meilisearch_python_async import Client from ...core.main import ChatMessage from ...models.filesystem import RangeInFile, RangeInFileWithContents from ...core.context import ContextItem, ContextItemDescription, ContextItemId @@ -187,5 +187,5 @@ class HighlightedCodeContextProvider(BaseModel): for hr in self.highlighted_ranges: hr.item.editing = hr.item.description.id.to_string() in ids - async def add_context_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client, prev: List[ContextItem] = None) -> List[ContextItem]: + async def add_context_item(self, id: ContextItemId, query: str, search_client: Client, prev: List[ContextItem] = None) -> List[ContextItem]: raise NotImplementedError() diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index fa203c28..c0957395 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -61,12 +61,12 @@ class GUIProtocolServer(AbstractGUIProtocolServer): "data": data }) - async def _receive_json(self, message_type: str, timeout: int = 5) -> Any: + async def _receive_json(self, message_type: str, timeout: int = 20) -> Any: try: return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout) except asyncio.TimeoutError: raise Exception( - "GUI Protocol _receive_json timed out after 5 seconds") + "GUI Protocol _receive_json timed out after 20 seconds") async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: await self._send_json(message_type, data) diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index d6a28c92..cf8b32a1 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -10,6 +10,7 @@ from pydantic import BaseModel import traceback import asyncio +from .meilisearch_server import start_meilisearch from ..libs.util.telemetry import posthog_logger from ..libs.util.queue import AsyncSubscriptionQueue from ..models.filesystem import FileSystem, RangeInFile, EditDiff, RangeInFileWithContents, RealFileSystem @@ -139,6 +140,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): continue message_type = message["messageType"] data = message["data"] + print("Received message while initializing", message_type) if message_type == "workspaceDirectory": self.workspace_directory = data["workspaceDirectory"] elif message_type == "uniqueId": @@ -153,17 +155,18 @@ class IdeProtocolServer(AbstractIdeProtocolServer): async def _send_json(self, message_type: str, data: Any): if self.websocket.application_state == WebSocketState.DISCONNECTED: return + print("Sending IDE message: ", message_type) await self.websocket.send_json({ "messageType": message_type, "data": data }) - async def _receive_json(self, message_type: str, timeout: int = 5) -> Any: + async def _receive_json(self, message_type: str, timeout: int = 20) -> Any: try: return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout) except asyncio.TimeoutError: raise Exception( - "IDE Protocol _receive_json timed out after 5 seconds") + "IDE Protocol _receive_json timed out after 20 seconds", message_type) async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: await self._send_json(message_type, data) @@ -432,6 +435,13 @@ class IdeProtocolServer(AbstractIdeProtocolServer): @router.websocket("/ws") async def websocket_endpoint(websocket: WebSocket, session_id: str = None): try: + # Start meilisearch + try: + await start_meilisearch() + except Exception as e: + print("Failed to start MeiliSearch") + print(e) + await websocket.accept() print("Accepted websocket connection from, ", websocket.client) await websocket.send_json({"messageType": "connected", "data": {}}) diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index 7ee64041..0b59d4fe 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,7 +1,5 @@ import asyncio -import subprocess import time -import meilisearch import psutil import os from fastapi import FastAPI @@ -13,7 +11,6 @@ import argparse from .ide import router as ide_router from .gui import router as gui_router from .session_manager import session_manager -from .meilisearch_server import start_meilisearch app = FastAPI() @@ -87,13 +84,8 @@ if __name__ == "__main__": # cpu_thread = threading.Thread(target=cpu_usage_loop) # cpu_thread.start() - try: - start_meilisearch() - except Exception as e: - print("Failed to start MeiliSearch") - print(e) - run_server() except Exception as e: + print("Error starting Continue server: ", e) cleanup() raise e diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index 232b6243..286019e1 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -2,7 +2,7 @@ import os import shutil import subprocess -import meilisearch +from meilisearch_python_async import Client from ..libs.util.paths import getServerFolderPath @@ -41,14 +41,14 @@ def ensure_meilisearch_installed(): f"curl -L https://install.meilisearch.com | sh", shell=True, check=True, cwd=serverPath) -def check_meilisearch_running() -> bool: +async def check_meilisearch_running() -> bool: """ Checks if MeiliSearch is running. """ try: - client = meilisearch.Client('http://localhost:7700') - resp = client.health() + client = Client('http://localhost:7700') + resp = await client.health() if resp["status"] != "available": return False return True @@ -56,7 +56,7 @@ def check_meilisearch_running() -> bool: return False -def start_meilisearch(): +async def start_meilisearch(): """ Starts the MeiliSearch server, wait for it. """ @@ -71,7 +71,7 @@ def start_meilisearch(): ensure_meilisearch_installed() # Check if MeiliSearch is running - if not check_meilisearch_running(): + if not await check_meilisearch_running(): print("Starting MeiliSearch...") subprocess.Popen(["./meilisearch"], cwd=serverPath, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT, close_fds=True, start_new_session=True) |