summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/server
diff options
context:
space:
mode:
authorNate Sesti <33237525+sestinj@users.noreply.github.com>2023-09-28 01:02:52 -0700
committerGitHub <noreply@github.com>2023-09-28 01:02:52 -0700
commit95363a5b52f3bf73531ac76b00178fa79ca97661 (patch)
tree9b9c1614556f1f0d21f363e6a9fe950069affb5d /continuedev/src/continuedev/server
parentd4acf4bb11dbd7d3d6210e2949d21143d721e81e (diff)
downloadsncontinue-95363a5b52f3bf73531ac76b00178fa79ca97661.tar.gz
sncontinue-95363a5b52f3bf73531ac76b00178fa79ca97661.tar.bz2
sncontinue-95363a5b52f3bf73531ac76b00178fa79ca97661.zip
Past input (#513)
* feat: :construction: use ComboBox in place of UserInputContainer * feat: :construction: adding context to previous inputs steps * feat: :sparkles: preview context items on click * feat: :construction: more work on context items ui * style: :construction: working out the details of ctx item buttons * feat: :sparkles: getting the final details * fix: :bug: fix height of ctx items bar * fix: :bug: last couple of details * fix: :bug: pass model param through to hf inference api * fix: :loud_sound: better logging for timeout * feat: :sparkles: option to set the meilisearch url * fix: :bug: fix height of past inputs
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r--continuedev/src/continuedev/server/gui.py35
-rw-r--r--continuedev/src/continuedev/server/ide.py24
-rw-r--r--continuedev/src/continuedev/server/main.py23
-rw-r--r--continuedev/src/continuedev/server/meilisearch_server.py30
4 files changed, 75 insertions, 37 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 9d2ea47a..10f6974f 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -104,7 +104,7 @@ class GUIProtocolServer:
elif message_type == "delete_at_index":
self.on_delete_at_index(data["index"])
elif message_type == "delete_context_with_ids":
- self.on_delete_context_with_ids(data["ids"])
+ self.on_delete_context_with_ids(data["ids"], data.get("index", None))
elif message_type == "toggle_adding_highlighted_code":
self.on_toggle_adding_highlighted_code()
elif message_type == "set_editing_at_ids":
@@ -112,9 +112,11 @@ class GUIProtocolServer:
elif message_type == "show_logs_at_index":
self.on_show_logs_at_index(data["index"])
elif message_type == "show_context_virtual_file":
- self.show_context_virtual_file()
+ self.show_context_virtual_file(data.get("index", None))
elif message_type == "select_context_item":
self.select_context_item(data["id"], data["query"])
+ elif message_type == "select_context_item_at_index":
+ self.select_context_item_at_index(data["id"], data["query"], data["index"])
elif message_type == "load_session":
self.load_session(data.get("session_id", None))
elif message_type == "edit_step_at_index":
@@ -171,9 +173,9 @@ class GUIProtocolServer:
self.on_error,
)
- def on_delete_context_with_ids(self, ids: List[str]):
+ def on_delete_context_with_ids(self, ids: List[str], index: Optional[int] = None):
create_async_task(
- self.session.autopilot.delete_context_with_ids(ids), self.on_error
+ self.session.autopilot.delete_context_with_ids(ids, index), self.on_error
)
def on_toggle_adding_highlighted_code(self):
@@ -188,7 +190,7 @@ class GUIProtocolServer:
def on_show_logs_at_index(self, index: int):
name = "Continue Context"
logs = "\n\n############################################\n\n".join(
- ["This is the prompt sent to the LLM during this step"]
+ ["This is the prompt that was sent to the LLM during this step"]
+ self.session.autopilot.continue_sdk.history.timeline[index].logs
)
create_async_task(
@@ -196,12 +198,20 @@ class GUIProtocolServer:
)
posthog_logger.capture_event("show_logs_at_index", {})
- def show_context_virtual_file(self):
+ def show_context_virtual_file(self, index: Optional[int] = None):
async def async_stuff():
- msgs = await self.session.autopilot.continue_sdk.get_chat_context()
+ if index is None:
+ context_items = (
+ await self.session.autopilot.context_manager.get_selected_items()
+ )
+ elif index < len(self.session.autopilot.continue_sdk.history.timeline):
+ context_items = self.session.autopilot.continue_sdk.history.timeline[
+ index
+ ].context_used
+
ctx = "\n\n-----------------------------------\n\n".join(
- ["This is the exact context that will be passed to the LLM"]
- + list(map(lambda x: x.content, msgs))
+ ["These are the context items that will be passed to the LLM"]
+ + list(map(lambda x: x.content, context_items))
)
await self.session.autopilot.ide.showVirtualFile(
"Continue - Selected Context", ctx
@@ -218,6 +228,13 @@ class GUIProtocolServer:
self.session.autopilot.select_context_item(id, query), self.on_error
)
+ def select_context_item_at_index(self, id: str, query: str, index: int):
+ """Called when user selects an item from the dropdown for prev UserInputStep"""
+ create_async_task(
+ self.session.autopilot.select_context_item_at_index(id, query, index),
+ self.on_error,
+ )
+
def load_session(self, session_id: Optional[str] = None):
async def load_and_tell_to_reconnect():
new_session_id = await session_manager.load_session(
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index d4f0690b..32bd0f0c 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -12,6 +12,7 @@ from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect, WebSocketState
from uvicorn.main import Server
+from ..core.main import ContinueCustomException
from ..libs.util.create_async_task import create_async_task
from ..libs.util.devdata import dev_data_logger
from ..libs.util.logging import logger
@@ -39,7 +40,6 @@ from ..models.filesystem_edit import (
from ..plugins.steps.core.core import DisplayErrorStep
from .gui import session_manager
from .ide_protocol import AbstractIdeProtocolServer
-from .meilisearch_server import start_meilisearch
from .session_manager import SessionManager
nest_asyncio.apply()
@@ -201,21 +201,24 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
except RuntimeError as e:
logger.warning(f"Error sending IDE message, websocket probably closed: {e}")
- async def _receive_json(self, message_type: str, timeout: int = 20) -> Any:
+ async def _receive_json(
+ self, message_type: str, timeout: int = 20, message=None
+ ) -> Any:
try:
return await asyncio.wait_for(
self.sub_queue.get(message_type), timeout=timeout
)
except asyncio.TimeoutError:
- raise Exception(
- f"IDE Protocol _receive_json timed out after 20 seconds: {message_type}"
+ raise ContinueCustomException(
+ title=f"IDE Protocol _receive_json timed out after 20 seconds: {message_type}",
+ message=f"IDE Protocol _receive_json timed out after 20 seconds. The message sent was: {message or ''}",
)
async def _send_and_receive_json(
self, data: Any, resp_model: Type[T], message_type: str
) -> T:
await self._send_json(message_type, data)
- resp = await self._receive_json(message_type)
+ resp = await self._receive_json(message_type, message=data)
return resp_model.parse_obj(resp)
async def handle_json(self, message_type: str, data: Any):
@@ -597,17 +600,6 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
logger.debug(f"Accepted websocket connection from {websocket.client}")
await websocket.send_json({"messageType": "connected", "data": {}})
- # Start meilisearch
- try:
-
- async def on_err(e):
- logger.debug(f"Failed to start MeiliSearch: {e}")
-
- create_async_task(start_meilisearch(), on_err)
- except Exception as e:
- logger.debug("Failed to start MeiliSearch")
- logger.debug(e)
-
# Message handler
def handle_msg(msg):
message = json.loads(msg)
diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py
index bbae2bb2..aa6c8944 100644
--- a/continuedev/src/continuedev/server/main.py
+++ b/continuedev/src/continuedev/server/main.py
@@ -1,6 +1,8 @@
import argparse
import asyncio
import atexit
+from contextlib import asynccontextmanager
+from typing import Optional
import uvicorn
from fastapi import FastAPI
@@ -9,10 +11,21 @@ from fastapi.middleware.cors import CORSMiddleware
from ..libs.util.logging import logger
from .gui import router as gui_router
from .ide import router as ide_router
+from .meilisearch_server import start_meilisearch, stop_meilisearch
from .session_manager import router as sessions_router
from .session_manager import session_manager
-app = FastAPI()
+meilisearch_url_global = None
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ await start_meilisearch(url=meilisearch_url_global)
+ yield
+ stop_meilisearch()
+
+
+app = FastAPI(lifespan=lifespan)
app.include_router(ide_router)
app.include_router(gui_router)
@@ -34,7 +47,13 @@ def health():
return {"status": "ok"}
-def run_server(port: int = 65432, host: str = "127.0.0.1"):
+def run_server(
+ port: int = 65432, host: str = "127.0.0.1", meilisearch_url: Optional[str] = None
+):
+ global meilisearch_url_global
+
+ meilisearch_url_global = meilisearch_url
+
config = uvicorn.Config(app, host=host, port=port)
server = uvicorn.Server(config)
server.run()
diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py
index 5e6cdd53..8929b69d 100644
--- a/continuedev/src/continuedev/server/meilisearch_server.py
+++ b/continuedev/src/continuedev/server/meilisearch_server.py
@@ -2,9 +2,11 @@ import asyncio
import os
import shutil
import subprocess
+from typing import Optional
import aiofiles
import aiohttp
+import psutil
from meilisearch_python_async import Client
from ..libs.util.logging import logger
@@ -89,13 +91,22 @@ async def ensure_meilisearch_installed() -> bool:
return True
+meilisearch_process = None
+DEFAULT_MEILISEARCH_URL = "http://localhost:7700"
+meilisearch_url = DEFAULT_MEILISEARCH_URL
+
+
+def get_meilisearch_url():
+ return meilisearch_url
+
+
async def check_meilisearch_running() -> bool:
"""
Checks if MeiliSearch is running.
"""
try:
- async with Client("http://localhost:7700") as client:
+ async with Client(meilisearch_url) as client:
try:
resp = await client.health()
if resp.status != "available":
@@ -117,14 +128,16 @@ async def poll_meilisearch_running(frequency: int = 0.1) -> bool:
await asyncio.sleep(frequency)
-meilisearch_process = None
-
-
-async def start_meilisearch():
+async def start_meilisearch(url: Optional[str] = None):
"""
Starts the MeiliSearch server, wait for it.
"""
- global meilisearch_process
+ global meilisearch_process, meilisearch_url
+
+ if url is not None:
+ logger.debug("Using MeiliSearch at URL: " + url)
+ meilisearch_url = url
+ return
serverPath = getServerFolderPath()
@@ -157,9 +170,6 @@ def stop_meilisearch():
meilisearch_process = None
-import psutil
-
-
def kill_proc(port):
for proc in psutil.process_iter():
try:
@@ -180,4 +190,4 @@ def kill_proc(port):
async def restart_meilisearch():
stop_meilisearch()
kill_proc(7700)
- await start_meilisearch()
+ await start_meilisearch(url=meilisearch_url)