summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/server
diff options
context:
space:
mode:
authorNate Sesti <33237525+sestinj@users.noreply.github.com>2023-07-26 00:56:29 -0700
committerGitHub <noreply@github.com>2023-07-26 00:56:29 -0700
commitdef8e5612cd4c889a2e26d4152fffcf3d694abdf (patch)
treeca423625619b9d628651bcc9a395ba8f47fa03a6 /continuedev/src/continuedev/server
parentb759e2dbfe36b3e8873527b9736d64866da9b604 (diff)
parentd9a4ed993aad36464776c093333af1a310e5a492 (diff)
downloadsncontinue-def8e5612cd4c889a2e26d4152fffcf3d694abdf.tar.gz
sncontinue-def8e5612cd4c889a2e26d4152fffcf3d694abdf.tar.bz2
sncontinue-def8e5612cd4c889a2e26d4152fffcf3d694abdf.zip
Merge pull request #297 from continuedev/merge-config-py-TO-main
Merge config py to main
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r--continuedev/src/continuedev/server/gui.py59
-rw-r--r--continuedev/src/continuedev/server/gui_protocol.py14
-rw-r--r--continuedev/src/continuedev/server/ide.py49
-rw-r--r--continuedev/src/continuedev/server/main.py20
-rw-r--r--continuedev/src/continuedev/server/meilisearch_server.py77
-rw-r--r--continuedev/src/continuedev/server/session_manager.py6
6 files changed, 154 insertions, 71 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index ae57c0b6..c0957395 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -2,15 +2,15 @@ import asyncio
import json
from fastapi import Depends, Header, WebSocket, APIRouter
from starlette.websockets import WebSocketState, WebSocketDisconnect
-from typing import Any, List, Type, TypeVar, Union
+from typing import Any, List, Type, TypeVar
from pydantic import BaseModel
import traceback
from uvicorn.main import Server
-from .session_manager import SessionManager, session_manager, Session
+from .session_manager import session_manager, Session
from .gui_protocol import AbstractGUIProtocolServer
from ..libs.util.queue import AsyncSubscriptionQueue
-from ..libs.util.telemetry import capture_event
+from ..libs.util.telemetry import posthog_logger
from ..libs.util.create_async_task import create_async_task
router = APIRouter(prefix="/gui", tags=["gui"])
@@ -61,12 +61,12 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
"data": data
})
- async def _receive_json(self, message_type: str, timeout: int = 5) -> Any:
+ async def _receive_json(self, message_type: str, timeout: int = 20) -> Any:
try:
return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)
except asyncio.TimeoutError:
raise Exception(
- "GUI Protocol _receive_json timed out after 5 seconds")
+ "GUI Protocol _receive_json timed out after 20 seconds")
async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
await self._send_json(message_type, data)
@@ -85,31 +85,23 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.on_reverse_to_index(data["index"])
elif message_type == "retry_at_index":
self.on_retry_at_index(data["index"])
- elif message_type == "change_default_model":
- self.on_change_default_model(data["model"])
elif message_type == "clear_history":
self.on_clear_history()
elif message_type == "delete_at_index":
self.on_delete_at_index(data["index"])
- elif message_type == "delete_context_at_indices":
- self.on_delete_context_at_indices(data["indices"])
+ elif message_type == "delete_context_with_ids":
+ self.on_delete_context_with_ids(data["ids"])
elif message_type == "toggle_adding_highlighted_code":
self.on_toggle_adding_highlighted_code()
elif message_type == "set_editing_at_indices":
self.on_set_editing_at_indices(data["indices"])
- elif message_type == "set_pinned_at_indices":
- self.on_set_pinned_at_indices(data["indices"])
elif message_type == "show_logs_at_index":
self.on_show_logs_at_index(data["index"])
+ elif message_type == "select_context_item":
+ self.select_context_item(data["id"], data["query"])
except Exception as e:
print(e)
- async def send_state_update(self):
- state = self.session.autopilot.get_full_state().dict()
- await self._send_json("state_update", {
- "state": state
- })
-
def on_main_input(self, input: str):
# Do something with user input
create_async_task(self.session.autopilot.accept_user_input(
@@ -132,10 +124,6 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
create_async_task(
self.session.autopilot.retry_at_index(index), self.session.autopilot.continue_sdk.ide.unique_id)
- def on_change_default_model(self, model: str):
- create_async_task(self.session.autopilot.change_default_model(
- model), self.session.autopilot.continue_sdk.ide.unique_id)
-
def on_clear_history(self):
create_async_task(self.session.autopilot.clear_history(
), self.session.autopilot.continue_sdk.ide.unique_id)
@@ -144,10 +132,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
create_async_task(self.session.autopilot.delete_at_index(
index), self.session.autopilot.continue_sdk.ide.unique_id)
- def on_delete_context_at_indices(self, indices: List[int]):
+ def on_delete_context_with_ids(self, ids: List[str]):
create_async_task(
- self.session.autopilot.delete_context_at_indices(
- indices), self.session.autopilot.continue_sdk.ide.unique_id
+ self.session.autopilot.delete_context_with_ids(
+ ids), self.session.autopilot.continue_sdk.ide.unique_id
)
def on_toggle_adding_highlighted_code(self):
@@ -162,18 +150,17 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
indices), self.session.autopilot.continue_sdk.ide.unique_id
)
- def on_set_pinned_at_indices(self, indices: List[int]):
- create_async_task(
- self.session.autopilot.set_pinned_at_indices(
- indices), self.session.autopilot.continue_sdk.ide.unique_id
- )
-
def on_show_logs_at_index(self, index: int):
name = f"continue_logs.txt"
logs = "\n\n############################################\n\n".join(
["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs)
create_async_task(
- self.session.autopilot.ide.showVirtualFile(name, logs))
+ self.session.autopilot.ide.showVirtualFile(name, logs), self.session.autopilot.continue_sdk.ide.unique_id)
+
+ def select_context_item(self, id: str, query: str):
+ """Called when user selects an item from the dropdown"""
+ create_async_task(
+ self.session.autopilot.select_context_item(id, query), self.session.autopilot.continue_sdk.ide.unique_id)
@router.websocket("/ws")
@@ -188,11 +175,11 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
protocol.websocket = websocket
# Update any history that may have happened before connection
- await protocol.send_state_update()
+ await protocol.session.autopilot.update_subscribers()
while AppStatus.should_exit is False:
message = await websocket.receive_text()
- print("Received message", message)
+ print("Received GUI message", message)
if type(message) is str:
message = json.loads(message)
@@ -206,13 +193,13 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
print("GUI websocket disconnected")
except Exception as e:
print("ERROR in gui websocket: ", e)
- capture_event(session.autopilot.continue_sdk.ide.unique_id, "gui_error", {
- "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))})
+ posthog_logger.capture_event("gui_error", {
+ "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))})
raise e
finally:
print("Closing gui websocket")
if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close()
- session_manager.persist_session(session.session_id)
+ await session_manager.persist_session(session.session_id)
session_manager.remove_session(session.session_id)
diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py
index 9766fcd0..990833be 100644
--- a/continuedev/src/continuedev/server/gui_protocol.py
+++ b/continuedev/src/continuedev/server/gui_protocol.py
@@ -1,6 +1,8 @@
from typing import Any, Dict, List
from abc import ABC, abstractmethod
+from ..core.context import ContextItem
+
class AbstractGUIProtocolServer(ABC):
@abstractmethod
@@ -24,21 +26,17 @@ class AbstractGUIProtocolServer(ABC):
"""Called when the user inputs a step"""
@abstractmethod
- async def send_state_update(self, state: dict):
- """Send a state update to the client"""
-
- @abstractmethod
def on_retry_at_index(self, index: int):
"""Called when the user requests a retry at a previous index"""
@abstractmethod
- def on_change_default_model(self):
- """Called when the user requests to change the default model"""
-
- @abstractmethod
def on_clear_history(self):
"""Called when the user requests to clear the history"""
@abstractmethod
def on_delete_at_index(self, index: int):
"""Called when the user requests to delete a step at a given index"""
+
+ @abstractmethod
+ def select_context_item(self, id: str, query: str):
+ """Called when user selects an item from the dropdown"""
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index aeff5623..cf8b32a1 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -1,23 +1,25 @@
# This is a separate server from server/main.py
-from functools import cached_property
import json
import os
-from typing import Any, Dict, List, Type, TypeVar, Union
+from typing import Any, List, Type, TypeVar, Union
import uuid
-from fastapi import WebSocket, Body, APIRouter
+from fastapi import WebSocket, APIRouter
from starlette.websockets import WebSocketState, WebSocketDisconnect
from uvicorn.main import Server
+from pydantic import BaseModel
import traceback
+import asyncio
-from ..libs.util.telemetry import capture_event
+from .meilisearch_server import start_meilisearch
+from ..libs.util.telemetry import posthog_logger
from ..libs.util.queue import AsyncSubscriptionQueue
from ..models.filesystem import FileSystem, RangeInFile, EditDiff, RangeInFileWithContents, RealFileSystem
from ..models.filesystem_edit import AddDirectory, AddFile, DeleteDirectory, DeleteFile, FileSystemEdit, FileEdit, FileEditWithFullContents, RenameDirectory, RenameFile, SequentialFileSystemEdit
-from pydantic import BaseModel
-from .gui import SessionManager, session_manager
+from .gui import session_manager
from .ide_protocol import AbstractIdeProtocolServer
-import asyncio
from ..libs.util.create_async_task import create_async_task
+from .session_manager import SessionManager
+
import nest_asyncio
nest_asyncio.apply()
@@ -138,6 +140,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
continue
message_type = message["messageType"]
data = message["data"]
+ print("Received message while initializing", message_type)
if message_type == "workspaceDirectory":
self.workspace_directory = data["workspaceDirectory"]
elif message_type == "uniqueId":
@@ -152,17 +155,18 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
async def _send_json(self, message_type: str, data: Any):
if self.websocket.application_state == WebSocketState.DISCONNECTED:
return
+ print("Sending IDE message: ", message_type)
await self.websocket.send_json({
"messageType": message_type,
"data": data
})
- async def _receive_json(self, message_type: str, timeout: int = 5) -> Any:
+ async def _receive_json(self, message_type: str, timeout: int = 20) -> Any:
try:
return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)
except asyncio.TimeoutError:
raise Exception(
- "IDE Protocol _receive_json timed out after 5 seconds")
+ "IDE Protocol _receive_json timed out after 20 seconds", message_type)
async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
await self._send_json(message_type, data)
@@ -273,12 +277,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
# like file changes, tracebacks, etc...
def onAcceptRejectSuggestion(self, accepted: bool):
- capture_event(self.unique_id, "accept_reject_suggestion", {
+ posthog_logger.capture_event("accept_reject_suggestion", {
"accepted": accepted
})
def onAcceptRejectDiff(self, accepted: bool):
- capture_event(self.unique_id, "accept_reject_diff", {
+ posthog_logger.capture_event("accept_reject_diff", {
"accepted": accepted
})
@@ -431,6 +435,13 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
try:
+ # Start meilisearch
+ try:
+ await start_meilisearch()
+ except Exception as e:
+ print("Failed to start MeiliSearch")
+ print(e)
+
await websocket.accept()
print("Accepted websocket connection from, ", websocket.client)
await websocket.send_json({"messageType": "connected", "data": {}})
@@ -443,6 +454,7 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
message_type = message["messageType"]
data = message["data"]
+ print("Received IDE message: ", message_type)
create_async_task(
ideProtocolServer.handle_json(message_type, data))
@@ -450,8 +462,8 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
if session_id is not None:
session_manager.registered_ides[session_id] = ideProtocolServer
other_msgs = await ideProtocolServer.initialize(session_id)
- capture_event(ideProtocolServer.unique_id, "session_started", {
- "session_id": ideProtocolServer.session_id})
+ posthog_logger.capture_event("session_started", {
+ "session_id": ideProtocolServer.session_id})
for other_msg in other_msgs:
handle_msg(other_msg)
@@ -465,13 +477,14 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
print("IDE wbsocket disconnected")
except Exception as e:
print("Error in ide websocket: ", e)
- capture_event(ideProtocolServer.unique_id, "gui_error", {
- "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))})
+ posthog_logger.capture_event("gui_error", {
+ "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))})
raise e
finally:
if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close()
- capture_event(ideProtocolServer.unique_id, "session_ended", {
- "session_id": ideProtocolServer.session_id})
- session_manager.registered_ides.pop(ideProtocolServer.session_id)
+ posthog_logger.capture_event("session_ended", {
+ "session_id": ideProtocolServer.session_id})
+ if ideProtocolServer.session_id in session_manager.registered_ides:
+ session_manager.registered_ides.pop(ideProtocolServer.session_id)
diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py
index 42dc0cc1..0b59d4fe 100644
--- a/continuedev/src/continuedev/server/main.py
+++ b/continuedev/src/continuedev/server/main.py
@@ -1,15 +1,17 @@
+import asyncio
import time
import psutil
import os
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
-from .ide import router as ide_router
-from .gui import router as gui_router
-from .session_manager import session_manager
import atexit
import uvicorn
import argparse
+from .ide import router as ide_router
+from .gui import router as gui_router
+from .session_manager import session_manager
+
app = FastAPI()
app.include_router(ide_router)
@@ -41,15 +43,20 @@ args = parser.parse_args()
# log_file = open('output.log', 'a')
# sys.stdout = log_file
-
def run_server():
uvicorn.run(app, host="0.0.0.0", port=args.port)
-def cleanup():
+async def cleanup_coroutine():
print("Cleaning up sessions")
for session_id in session_manager.sessions:
- session_manager.persist_session(session_id)
+ await session_manager.persist_session(session_id)
+
+
+def cleanup():
+ loop = asyncio.new_event_loop()
+ loop.run_until_complete(cleanup_coroutine())
+ loop.close()
def cpu_usage_report():
@@ -79,5 +86,6 @@ if __name__ == "__main__":
run_server()
except Exception as e:
+ print("Error starting Continue server: ", e)
cleanup()
raise e
diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py
new file mode 100644
index 00000000..286019e1
--- /dev/null
+++ b/continuedev/src/continuedev/server/meilisearch_server.py
@@ -0,0 +1,77 @@
+import os
+import shutil
+import subprocess
+
+from meilisearch_python_async import Client
+from ..libs.util.paths import getServerFolderPath
+
+
+def ensure_meilisearch_installed():
+ """
+ Checks if MeiliSearch is installed.
+ """
+ serverPath = getServerFolderPath()
+ meilisearchPath = os.path.join(serverPath, "meilisearch")
+ dumpsPath = os.path.join(serverPath, "dumps")
+ dataMsPath = os.path.join(serverPath, "data.ms")
+
+ paths = [meilisearchPath, dumpsPath, dataMsPath]
+
+ existing_paths = set()
+ non_existing_paths = set()
+ for path in paths:
+ if os.path.exists(path):
+ existing_paths.add(path)
+ else:
+ non_existing_paths.add(path)
+
+ if len(non_existing_paths) > 0:
+ # Clear the meilisearch binary
+ if meilisearchPath in existing_paths:
+ os.remove(meilisearchPath)
+ non_existing_paths.remove(meilisearchPath)
+
+ # Clear the existing directories
+ for p in existing_paths:
+ shutil.rmtree(p, ignore_errors=True)
+
+ # Download MeiliSearch
+ print("Downloading MeiliSearch...")
+ subprocess.run(
+ f"curl -L https://install.meilisearch.com | sh", shell=True, check=True, cwd=serverPath)
+
+
+async def check_meilisearch_running() -> bool:
+ """
+ Checks if MeiliSearch is running.
+ """
+
+ try:
+ client = Client('http://localhost:7700')
+ resp = await client.health()
+ if resp["status"] != "available":
+ return False
+ return True
+ except Exception:
+ return False
+
+
+async def start_meilisearch():
+ """
+ Starts the MeiliSearch server, wait for it.
+ """
+
+ # Doesn't work on windows for now
+ if not os.name == "posix":
+ return
+
+ serverPath = getServerFolderPath()
+
+ # Check if MeiliSearch is installed, if not download
+ ensure_meilisearch_installed()
+
+ # Check if MeiliSearch is running
+ if not await check_meilisearch_running():
+ print("Starting MeiliSearch...")
+ subprocess.Popen(["./meilisearch"], cwd=serverPath, stdout=subprocess.DEVNULL,
+ stderr=subprocess.STDOUT, close_fds=True, start_new_session=True)
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index 20219273..3136f1bf 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -74,7 +74,7 @@ class SessionManager:
async def on_update(state: FullState):
await session_manager.send_ws_data(session_id, "state_update", {
- "state": autopilot.get_full_state().dict()
+ "state": state.dict()
})
autopilot.on_update(on_update)
@@ -84,9 +84,9 @@ class SessionManager:
def remove_session(self, session_id: str):
del self.sessions[session_id]
- def persist_session(self, session_id: str):
+ async def persist_session(self, session_id: str):
"""Save the session's FullState as a json file"""
- full_state = self.sessions[session_id].autopilot.get_full_state()
+ full_state = await self.sessions[session_id].autopilot.get_full_state()
if not os.path.exists(getSessionsFolderPath()):
os.mkdir(getSessionsFolderPath())
with open(getSessionFilePath(session_id), "w") as f: