diff options
author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-09-28 01:02:52 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-28 01:02:52 -0700 |
commit | 95363a5b52f3bf73531ac76b00178fa79ca97661 (patch) | |
tree | 9b9c1614556f1f0d21f363e6a9fe950069affb5d /continuedev/src/continuedev/server | |
parent | d4acf4bb11dbd7d3d6210e2949d21143d721e81e (diff) | |
download | sncontinue-95363a5b52f3bf73531ac76b00178fa79ca97661.tar.gz sncontinue-95363a5b52f3bf73531ac76b00178fa79ca97661.tar.bz2 sncontinue-95363a5b52f3bf73531ac76b00178fa79ca97661.zip |
Past input (#513)
* feat: :construction: use ComboBox in place of UserInputContainer
* feat: :construction: adding context to previous inputs steps
* feat: :sparkles: preview context items on click
* feat: :construction: more work on context items ui
* style: :construction: working out the details of ctx item buttons
* feat: :sparkles: getting the final details
* fix: :bug: fix height of ctx items bar
* fix: :bug: last couple of details
* fix: :bug: pass model param through to hf inference api
* fix: :loud_sound: better logging for timeout
* feat: :sparkles: option to set the meilisearch url
* fix: :bug: fix height of past inputs
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 35 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide.py | 24 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/main.py | 23 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/meilisearch_server.py | 30 |
4 files changed, 75 insertions, 37 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 9d2ea47a..10f6974f 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -104,7 +104,7 @@ class GUIProtocolServer: elif message_type == "delete_at_index": self.on_delete_at_index(data["index"]) elif message_type == "delete_context_with_ids": - self.on_delete_context_with_ids(data["ids"]) + self.on_delete_context_with_ids(data["ids"], data.get("index", None)) elif message_type == "toggle_adding_highlighted_code": self.on_toggle_adding_highlighted_code() elif message_type == "set_editing_at_ids": @@ -112,9 +112,11 @@ class GUIProtocolServer: elif message_type == "show_logs_at_index": self.on_show_logs_at_index(data["index"]) elif message_type == "show_context_virtual_file": - self.show_context_virtual_file() + self.show_context_virtual_file(data.get("index", None)) elif message_type == "select_context_item": self.select_context_item(data["id"], data["query"]) + elif message_type == "select_context_item_at_index": + self.select_context_item_at_index(data["id"], data["query"], data["index"]) elif message_type == "load_session": self.load_session(data.get("session_id", None)) elif message_type == "edit_step_at_index": @@ -171,9 +173,9 @@ class GUIProtocolServer: self.on_error, ) - def on_delete_context_with_ids(self, ids: List[str]): + def on_delete_context_with_ids(self, ids: List[str], index: Optional[int] = None): create_async_task( - self.session.autopilot.delete_context_with_ids(ids), self.on_error + self.session.autopilot.delete_context_with_ids(ids, index), self.on_error ) def on_toggle_adding_highlighted_code(self): @@ -188,7 +190,7 @@ class GUIProtocolServer: def on_show_logs_at_index(self, index: int): name = "Continue Context" logs = "\n\n############################################\n\n".join( - ["This is the prompt sent to the LLM during this step"] + ["This is the prompt that was sent to the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs ) create_async_task( @@ -196,12 +198,20 @@ class GUIProtocolServer: ) posthog_logger.capture_event("show_logs_at_index", {}) - def show_context_virtual_file(self): + def show_context_virtual_file(self, index: Optional[int] = None): async def async_stuff(): - msgs = await self.session.autopilot.continue_sdk.get_chat_context() + if index is None: + context_items = ( + await self.session.autopilot.context_manager.get_selected_items() + ) + elif index < len(self.session.autopilot.continue_sdk.history.timeline): + context_items = self.session.autopilot.continue_sdk.history.timeline[ + index + ].context_used + ctx = "\n\n-----------------------------------\n\n".join( - ["This is the exact context that will be passed to the LLM"] - + list(map(lambda x: x.content, msgs)) + ["These are the context items that will be passed to the LLM"] + + list(map(lambda x: x.content, context_items)) ) await self.session.autopilot.ide.showVirtualFile( "Continue - Selected Context", ctx @@ -218,6 +228,13 @@ class GUIProtocolServer: self.session.autopilot.select_context_item(id, query), self.on_error ) + def select_context_item_at_index(self, id: str, query: str, index: int): + """Called when user selects an item from the dropdown for prev UserInputStep""" + create_async_task( + self.session.autopilot.select_context_item_at_index(id, query, index), + self.on_error, + ) + def load_session(self, session_id: Optional[str] = None): async def load_and_tell_to_reconnect(): new_session_id = await session_manager.load_session( diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index d4f0690b..32bd0f0c 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -12,6 +12,7 @@ from pydantic import BaseModel from starlette.websockets import WebSocketDisconnect, WebSocketState from uvicorn.main import Server +from ..core.main import ContinueCustomException from ..libs.util.create_async_task import create_async_task from ..libs.util.devdata import dev_data_logger from ..libs.util.logging import logger @@ -39,7 +40,6 @@ from ..models.filesystem_edit import ( from ..plugins.steps.core.core import DisplayErrorStep from .gui import session_manager from .ide_protocol import AbstractIdeProtocolServer -from .meilisearch_server import start_meilisearch from .session_manager import SessionManager nest_asyncio.apply() @@ -201,21 +201,24 @@ class IdeProtocolServer(AbstractIdeProtocolServer): except RuntimeError as e: logger.warning(f"Error sending IDE message, websocket probably closed: {e}") - async def _receive_json(self, message_type: str, timeout: int = 20) -> Any: + async def _receive_json( + self, message_type: str, timeout: int = 20, message=None + ) -> Any: try: return await asyncio.wait_for( self.sub_queue.get(message_type), timeout=timeout ) except asyncio.TimeoutError: - raise Exception( - f"IDE Protocol _receive_json timed out after 20 seconds: {message_type}" + raise ContinueCustomException( + title=f"IDE Protocol _receive_json timed out after 20 seconds: {message_type}", + message=f"IDE Protocol _receive_json timed out after 20 seconds. The message sent was: {message or ''}", ) async def _send_and_receive_json( self, data: Any, resp_model: Type[T], message_type: str ) -> T: await self._send_json(message_type, data) - resp = await self._receive_json(message_type) + resp = await self._receive_json(message_type, message=data) return resp_model.parse_obj(resp) async def handle_json(self, message_type: str, data: Any): @@ -597,17 +600,6 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None): logger.debug(f"Accepted websocket connection from {websocket.client}") await websocket.send_json({"messageType": "connected", "data": {}}) - # Start meilisearch - try: - - async def on_err(e): - logger.debug(f"Failed to start MeiliSearch: {e}") - - create_async_task(start_meilisearch(), on_err) - except Exception as e: - logger.debug("Failed to start MeiliSearch") - logger.debug(e) - # Message handler def handle_msg(msg): message = json.loads(msg) diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index bbae2bb2..aa6c8944 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,6 +1,8 @@ import argparse import asyncio import atexit +from contextlib import asynccontextmanager +from typing import Optional import uvicorn from fastapi import FastAPI @@ -9,10 +11,21 @@ from fastapi.middleware.cors import CORSMiddleware from ..libs.util.logging import logger from .gui import router as gui_router from .ide import router as ide_router +from .meilisearch_server import start_meilisearch, stop_meilisearch from .session_manager import router as sessions_router from .session_manager import session_manager -app = FastAPI() +meilisearch_url_global = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + await start_meilisearch(url=meilisearch_url_global) + yield + stop_meilisearch() + + +app = FastAPI(lifespan=lifespan) app.include_router(ide_router) app.include_router(gui_router) @@ -34,7 +47,13 @@ def health(): return {"status": "ok"} -def run_server(port: int = 65432, host: str = "127.0.0.1"): +def run_server( + port: int = 65432, host: str = "127.0.0.1", meilisearch_url: Optional[str] = None +): + global meilisearch_url_global + + meilisearch_url_global = meilisearch_url + config = uvicorn.Config(app, host=host, port=port) server = uvicorn.Server(config) server.run() diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index 5e6cdd53..8929b69d 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -2,9 +2,11 @@ import asyncio import os import shutil import subprocess +from typing import Optional import aiofiles import aiohttp +import psutil from meilisearch_python_async import Client from ..libs.util.logging import logger @@ -89,13 +91,22 @@ async def ensure_meilisearch_installed() -> bool: return True +meilisearch_process = None +DEFAULT_MEILISEARCH_URL = "http://localhost:7700" +meilisearch_url = DEFAULT_MEILISEARCH_URL + + +def get_meilisearch_url(): + return meilisearch_url + + async def check_meilisearch_running() -> bool: """ Checks if MeiliSearch is running. """ try: - async with Client("http://localhost:7700") as client: + async with Client(meilisearch_url) as client: try: resp = await client.health() if resp.status != "available": @@ -117,14 +128,16 @@ async def poll_meilisearch_running(frequency: int = 0.1) -> bool: await asyncio.sleep(frequency) -meilisearch_process = None - - -async def start_meilisearch(): +async def start_meilisearch(url: Optional[str] = None): """ Starts the MeiliSearch server, wait for it. """ - global meilisearch_process + global meilisearch_process, meilisearch_url + + if url is not None: + logger.debug("Using MeiliSearch at URL: " + url) + meilisearch_url = url + return serverPath = getServerFolderPath() @@ -157,9 +170,6 @@ def stop_meilisearch(): meilisearch_process = None -import psutil - - def kill_proc(port): for proc in psutil.process_iter(): try: @@ -180,4 +190,4 @@ def kill_proc(port): async def restart_meilisearch(): stop_meilisearch() kill_proc(7700) - await start_meilisearch() + await start_meilisearch(url=meilisearch_url) |