diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-07-30 22:30:00 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-07-30 22:30:00 -0700 |
commit | 57a572a420e16b08301f0c6738a1b414c59bce85 (patch) | |
tree | 2bdbc7831d66aafefe30a9e236ecc150d80024cc /continuedev/src/continuedev/server | |
parent | 1bc5777ed168e47e2ef2ab1b33eecf6cbd170a61 (diff) | |
parent | 8bd76be6c0925e0d5e5f6d239e9c6907df3cfd23 (diff) | |
download | sncontinue-57a572a420e16b08301f0c6738a1b414c59bce85.tar.gz sncontinue-57a572a420e16b08301f0c6738a1b414c59bce85.tar.bz2 sncontinue-57a572a420e16b08301f0c6738a1b414c59bce85.zip |
Merge remote-tracking branch 'continuedev/main' into llm-object-config-merge-main
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 121 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide.py | 49 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/main.py | 64 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/meilisearch_server.py | 5 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/session_manager.py | 26 |
5 files changed, 146 insertions, 119 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index c0957395..98a5aea0 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -8,10 +8,12 @@ import traceback from uvicorn.main import Server from .session_manager import session_manager, Session +from ..plugins.steps.core.core import DisplayErrorStep, MessageStep from .gui_protocol import AbstractGUIProtocolServer from ..libs.util.queue import AsyncSubscriptionQueue from ..libs.util.telemetry import posthog_logger from ..libs.util.create_async_task import create_async_task +from ..libs.util.logging import logger router = APIRouter(prefix="/gui", tags=["gui"]) @@ -25,17 +27,13 @@ class AppStatus: @staticmethod def handle_exit(*args, **kwargs): AppStatus.should_exit = True - print("Shutting down") + logger.debug("Shutting down") original_handler(*args, **kwargs) Server.handle_exit = AppStatus.handle_exit -async def session(x_continue_session_id: str = Header("anonymous")) -> Session: - return await session_manager.get_session(x_continue_session_id) - - async def websocket_session(session_id: str) -> Session: return await session_manager.get_session(session_id) @@ -73,103 +71,97 @@ class GUIProtocolServer(AbstractGUIProtocolServer): resp = await self._receive_json(message_type) return resp_model.parse_obj(resp) + def on_error(self, e: Exception): + return self.session.autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) + def handle_json(self, message_type: str, data: Any): - try: - if message_type == "main_input": - self.on_main_input(data["input"]) - elif message_type == "step_user_input": - self.on_step_user_input(data["input"], data["index"]) - elif message_type == "refinement_input": - self.on_refinement_input(data["input"], data["index"]) - elif message_type == "reverse_to_index": - self.on_reverse_to_index(data["index"]) - elif message_type == "retry_at_index": - self.on_retry_at_index(data["index"]) - elif message_type == "clear_history": - self.on_clear_history() - elif message_type == "delete_at_index": - self.on_delete_at_index(data["index"]) - elif message_type == "delete_context_with_ids": - self.on_delete_context_with_ids(data["ids"]) - elif message_type == "toggle_adding_highlighted_code": - self.on_toggle_adding_highlighted_code() - elif message_type == "set_editing_at_indices": - self.on_set_editing_at_indices(data["indices"]) - elif message_type == "show_logs_at_index": - self.on_show_logs_at_index(data["index"]) - elif message_type == "select_context_item": - self.select_context_item(data["id"], data["query"]) - except Exception as e: - print(e) + if message_type == "main_input": + self.on_main_input(data["input"]) + elif message_type == "step_user_input": + self.on_step_user_input(data["input"], data["index"]) + elif message_type == "refinement_input": + self.on_refinement_input(data["input"], data["index"]) + elif message_type == "reverse_to_index": + self.on_reverse_to_index(data["index"]) + elif message_type == "retry_at_index": + self.on_retry_at_index(data["index"]) + elif message_type == "clear_history": + self.on_clear_history() + elif message_type == "delete_at_index": + self.on_delete_at_index(data["index"]) + elif message_type == "delete_context_with_ids": + self.on_delete_context_with_ids(data["ids"]) + elif message_type == "toggle_adding_highlighted_code": + self.on_toggle_adding_highlighted_code() + elif message_type == "set_editing_at_indices": + self.on_set_editing_at_indices(data["indices"]) + elif message_type == "show_logs_at_index": + self.on_show_logs_at_index(data["index"]) + elif message_type == "select_context_item": + self.select_context_item(data["id"], data["query"]) def on_main_input(self, input: str): # Do something with user input - create_async_task(self.session.autopilot.accept_user_input( - input), self.session.autopilot.continue_sdk.ide.unique_id) + create_async_task( + self.session.autopilot.accept_user_input(input), self.on_error) def on_reverse_to_index(self, index: int): # Reverse the history to the given index - create_async_task(self.session.autopilot.reverse_to_index( - index), self.session.autopilot.continue_sdk.ide.unique_id) + create_async_task( + self.session.autopilot.reverse_to_index(index), self.on_error) def on_step_user_input(self, input: str, index: int): create_async_task( - self.session.autopilot.give_user_input(input, index), self.session.autopilot.continue_sdk.ide.unique_id) + self.session.autopilot.give_user_input(input, index), self.on_error) def on_refinement_input(self, input: str, index: int): create_async_task( - self.session.autopilot.accept_refinement_input(input, index), self.session.autopilot.continue_sdk.ide.unique_id) + self.session.autopilot.accept_refinement_input(input, index), self.on_error) def on_retry_at_index(self, index: int): create_async_task( - self.session.autopilot.retry_at_index(index), self.session.autopilot.continue_sdk.ide.unique_id) + self.session.autopilot.retry_at_index(index), self.on_error) def on_clear_history(self): - create_async_task(self.session.autopilot.clear_history( - ), self.session.autopilot.continue_sdk.ide.unique_id) + create_async_task( + self.session.autopilot.clear_history(), self.on_error) def on_delete_at_index(self, index: int): - create_async_task(self.session.autopilot.delete_at_index( - index), self.session.autopilot.continue_sdk.ide.unique_id) + create_async_task( + self.session.autopilot.delete_at_index(index), self.on_error) def on_delete_context_with_ids(self, ids: List[str]): create_async_task( - self.session.autopilot.delete_context_with_ids( - ids), self.session.autopilot.continue_sdk.ide.unique_id - ) + self.session.autopilot.delete_context_with_ids(ids), self.on_error) def on_toggle_adding_highlighted_code(self): create_async_task( - self.session.autopilot.toggle_adding_highlighted_code( - ), self.session.autopilot.continue_sdk.ide.unique_id - ) + self.session.autopilot.toggle_adding_highlighted_code(), self.on_error) def on_set_editing_at_indices(self, indices: List[int]): create_async_task( - self.session.autopilot.set_editing_at_indices( - indices), self.session.autopilot.continue_sdk.ide.unique_id - ) + self.session.autopilot.set_editing_at_indices(indices), self.on_error) def on_show_logs_at_index(self, index: int): name = f"continue_logs.txt" logs = "\n\n############################################\n\n".join( ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs) create_async_task( - self.session.autopilot.ide.showVirtualFile(name, logs), self.session.autopilot.continue_sdk.ide.unique_id) + self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error) def select_context_item(self, id: str, query: str): """Called when user selects an item from the dropdown""" create_async_task( - self.session.autopilot.select_context_item(id, query), self.session.autopilot.continue_sdk.ide.unique_id) + self.session.autopilot.select_context_item(id, query), self.on_error) @router.websocket("/ws") async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): try: - print("Received websocket connection at url: ", websocket.url) + logger.debug(f"Received websocket connection at url: {websocket.url}") await websocket.accept() - print("Session started") + logger.debug("Session started") session_manager.register_websocket(session.session_id, websocket) protocol = GUIProtocolServer(session) protocol.websocket = websocket @@ -179,7 +171,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we while AppStatus.should_exit is False: message = await websocket.receive_text() - print("Received GUI message", message) + logger.debug(f"Received GUI message {message}") if type(message) is str: message = json.loads(message) @@ -190,16 +182,21 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we protocol.handle_json(message_type, data) except WebSocketDisconnect as e: - print("GUI websocket disconnected") + logger.debug("GUI websocket disconnected") except Exception as e: - print("ERROR in gui websocket: ", e) + # Log, send to PostHog, and send to GUI + logger.debug(f"ERROR in gui websocket: {e}") + err_msg = '\n'.join(traceback.format_exception(e)) posthog_logger.capture_event("gui_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))}) + "error_title": e.__str__() or e.__repr__(), "error_message": err_msg}) + + await protocol.session.autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) + raise e finally: - print("Closing gui websocket") + logger.debug("Closing gui websocket") if websocket.client_state != WebSocketState.DISCONNECTED: await websocket.close() await session_manager.persist_session(session.session_id) - session_manager.remove_session(session.session_id) + await session_manager.remove_session(session.session_id) diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 87374928..e4c07029 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -10,6 +10,7 @@ from pydantic import BaseModel import traceback import asyncio +from ..plugins.steps.core.core import DisplayErrorStep, MessageStep from .meilisearch_server import start_meilisearch from ..libs.util.telemetry import posthog_logger from ..libs.util.queue import AsyncSubscriptionQueue @@ -19,6 +20,7 @@ from .gui import session_manager from .ide_protocol import AbstractIdeProtocolServer from ..libs.util.create_async_task import create_async_task from .session_manager import SessionManager +from ..libs.util.logging import logger import nest_asyncio nest_asyncio.apply() @@ -37,7 +39,7 @@ class AppStatus: @staticmethod def handle_exit(*args, **kwargs): AppStatus.should_exit = True - print("Shutting down") + logger.debug("Shutting down") original_handler(*args, **kwargs) @@ -140,7 +142,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): continue message_type = message["messageType"] data = message["data"] - print("Received message while initializing", message_type) + logger.debug(f"Received message while initializing {message_type}") if message_type == "workspaceDirectory": self.workspace_directory = data["workspaceDirectory"] elif message_type == "uniqueId": @@ -154,9 +156,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer): async def _send_json(self, message_type: str, data: Any): if self.websocket.application_state == WebSocketState.DISCONNECTED: - print("Tried to send message, but websocket is disconnected", message_type) + logger.debug( + f"Tried to send message, but websocket is disconnected: {message_type}") return - print("Sending IDE message: ", message_type) + logger.debug(f"Sending IDE message: {message_type}") await self.websocket.send_json({ "messageType": message_type, "data": data @@ -167,7 +170,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout) except asyncio.TimeoutError: raise Exception( - "IDE Protocol _receive_json timed out after 20 seconds", message_type) + f"IDE Protocol _receive_json timed out after 20 seconds: {message_type}") async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: await self._send_json(message_type, data) @@ -277,6 +280,9 @@ class IdeProtocolServer(AbstractIdeProtocolServer): # This is where you might have triggers: plugins can subscribe to certian events # like file changes, tracebacks, etc... + def on_error(self, e: Exception): + return self.session_manager.sessions[self.session_id].autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) + def onAcceptRejectSuggestion(self, accepted: bool): posthog_logger.capture_event("accept_reject_suggestion", { "accepted": accepted @@ -309,22 +315,22 @@ class IdeProtocolServer(AbstractIdeProtocolServer): def onDeleteAtIndex(self, index: int): if autopilot := self.__get_autopilot(): - create_async_task(autopilot.delete_at_index(index), self.unique_id) + create_async_task(autopilot.delete_at_index(index), self.on_error) def onCommandOutput(self, output: str): if autopilot := self.__get_autopilot(): create_async_task( - autopilot.handle_command_output(output), self.unique_id) + autopilot.handle_command_output(output), self.on_error) def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]): if autopilot := self.__get_autopilot(): create_async_task(autopilot.handle_highlighted_code( - range_in_files), self.unique_id) + range_in_files), self.on_error) def onMainUserInput(self, input: str): if autopilot := self.__get_autopilot(): create_async_task( - autopilot.accept_user_input(input), self.unique_id) + autopilot.accept_user_input(input), self.on_error) # Request information. Session doesn't matter. async def getOpenFiles(self) -> List[str]: @@ -354,7 +360,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): }, GetUserSecretResponse, "getUserSecret") return resp.value except Exception as e: - print("Error getting user secret", e) + logger.debug(f"Error getting user secret: {e}") return "" async def saveFile(self, filepath: str): @@ -437,15 +443,15 @@ class IdeProtocolServer(AbstractIdeProtocolServer): async def websocket_endpoint(websocket: WebSocket, session_id: str = None): try: await websocket.accept() - print("Accepted websocket connection from, ", websocket.client) + logger.debug(f"Accepted websocket connection from {websocket.client}") await websocket.send_json({"messageType": "connected", "data": {}}) # Start meilisearch try: await start_meilisearch() except Exception as e: - print("Failed to start MeiliSearch") - print(e) + logger.debug("Failed to start MeiliSearch") + logger.debug(e) def handle_msg(msg): message = json.loads(msg) @@ -455,9 +461,9 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None): message_type = message["messageType"] data = message["data"] - print("Received IDE message: ", message_type) + logger.debug(f"Received IDE message: {message_type}") create_async_task( - ideProtocolServer.handle_json(message_type, data)) + ideProtocolServer.handle_json(message_type, data), ideProtocolServer.on_error) ideProtocolServer = IdeProtocolServer(session_manager, websocket) if session_id is not None: @@ -473,15 +479,20 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None): message = await websocket.receive_text() handle_msg(message) - print("Closing ide websocket") + logger.debug("Closing ide websocket") except WebSocketDisconnect as e: - print("IDE wbsocket disconnected") + logger.debug("IDE wbsocket disconnected") except Exception as e: - print("Error in ide websocket: ", e) + logger.debug(f"Error in ide websocket: {e}") + err_msg = '\n'.join(traceback.format_exception(e)) posthog_logger.capture_event("gui_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))}) + "error_title": e.__str__() or e.__repr__(), "error_message": err_msg}) + + await session_manager.sessions[session_id].autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) + raise e finally: + logger.debug("Closing ide websocket") if websocket.client_state != WebSocketState.DISCONNECTED: await websocket.close() diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index b92c9fa3..468bc855 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -15,6 +15,7 @@ from .ide import router as ide_router from .gui import router as gui_router from .session_manager import session_manager from ..libs.util.paths import getLogFilePath +from ..libs.util.logging import logger app = FastAPI() @@ -33,45 +34,49 @@ app.add_middleware( @app.get("/health") def health(): - print("Testing") + logger.debug("Health check") return {"status": "ok"} -# add cli arg for server port -parser = argparse.ArgumentParser() -parser.add_argument("-p", "--port", help="server port", - type=int, default=65432) -args = parser.parse_args() - -log_path = getLogFilePath() -LOG_CONFIG = { - 'version': 1, - 'disable_existing_loggers': False, - 'handlers': { - 'file': { - 'level': 'DEBUG', - 'class': 'logging.FileHandler', - 'filename': log_path, - }, - }, - 'root': { - 'level': 'DEBUG', - 'handlers': ['file'] - } -} -print(f"Log path: {log_path}") +class Logger(object): + def __init__(self, log_file: str): + self.terminal = sys.stdout + self.log = open(log_file, "a") + + def write(self, message): + self.terminal.write(message) + self.log.write(message) + + def flush(self): + # this flush method is needed for python 3 compatibility. + # this handles the flush command by doing nothing. + # you might want to specify some extra behavior here. + pass + + def isatty(self): + return False + + +try: + # add cli arg for server port + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--port", help="server port", + type=int, default=65432) + args = parser.parse_args() +except Exception as e: + logger.debug(f"Error parsing command line arguments: {e}") + raise e def run_server(): - config = uvicorn.Config(app, host="0.0.0.0", - port=args.port, log_config=LOG_CONFIG) + config = uvicorn.Config(app, host="127.0.0.1", port=args.port) server = uvicorn.Server(config) server.run() async def cleanup_coroutine(): - print("Cleaning up sessions") + logger.debug("Cleaning up sessions") for session_id in session_manager.sessions: await session_manager.persist_session(session_id) @@ -90,13 +95,14 @@ def cpu_usage_report(): time.sleep(1) # Call cpu_percent again to get the CPU usage over the interval cpu_usage = process.cpu_percent(interval=None) - print(f"CPU usage: {cpu_usage}%") + logger.debug(f"CPU usage: {cpu_usage}%") atexit.register(cleanup) if __name__ == "__main__": try: + # Uncomment to get CPU usage reports # import threading # def cpu_usage_loop(): @@ -109,6 +115,6 @@ if __name__ == "__main__": run_server() except Exception as e: - print("Error starting Continue server: ", e) + logger.debug(f"Error starting Continue server: {e}") cleanup() raise e diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index 00f692f5..7f460afc 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -5,6 +5,7 @@ import subprocess from meilisearch_python_async import Client from ..libs.util.paths import getServerFolderPath +from ..libs.util.logging import logger def ensure_meilisearch_installed() -> bool: @@ -39,7 +40,7 @@ def ensure_meilisearch_installed() -> bool: shutil.rmtree(p, ignore_errors=True) # Download MeiliSearch - print("Downloading MeiliSearch...") + logger.debug("Downloading MeiliSearch...") subprocess.run( f"curl -L https://install.meilisearch.com | sh", shell=True, check=True, cwd=serverPath) @@ -82,6 +83,6 @@ async def start_meilisearch(): # Check if MeiliSearch is running if not await check_meilisearch_running() or not was_already_installed: - print("Starting MeiliSearch...") + logger.debug("Starting MeiliSearch...") subprocess.Popen(["./meilisearch", "--no-analytics"], cwd=serverPath, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT, close_fds=True, start_new_session=True) diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 3136f1bf..cf46028f 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -4,6 +4,9 @@ from typing import Any, Dict, List, Union from uuid import uuid4 import json +from fastapi.websockets import WebSocketState + +from ..plugins.steps.core.core import DisplayErrorStep from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath from ..models.filesystem_edit import FileEditWithFullContents from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER @@ -13,6 +16,7 @@ from ..core.autopilot import Autopilot from .ide_protocol import AbstractIdeProtocolServer from ..libs.util.create_async_task import create_async_task from ..libs.util.errors import SessionNotFound +from ..libs.util.logging import logger class Session: @@ -59,6 +63,8 @@ class SessionManager: return self.sessions[session_id] async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: + logger.debug(f"New session: {session_id}") + full_state = None if session_id is not None and os.path.exists(getSessionFilePath(session_id)): with open(getSessionFilePath(session_id), "r") as f: @@ -78,29 +84,35 @@ class SessionManager: }) autopilot.on_update(on_update) - create_async_task(autopilot.run_policy()) + create_async_task(autopilot.run_policy( + ), lambda e: autopilot.continue_sdk.run_step(DisplayErrorStep(e=e))) return session - def remove_session(self, session_id: str): - del self.sessions[session_id] + async def remove_session(self, session_id: str): + logger.debug(f"Removing session: {session_id}") + if session_id in self.sessions: + if session_id in self.registered_ides: + ws_to_close = self.registered_ides[session_id].websocket + if ws_to_close is not None and ws_to_close.client_state != WebSocketState.DISCONNECTED: + await self.sessions[session_id].autopilot.ide.websocket.close() + + del self.sessions[session_id] async def persist_session(self, session_id: str): """Save the session's FullState as a json file""" full_state = await self.sessions[session_id].autopilot.get_full_state() - if not os.path.exists(getSessionsFolderPath()): - os.mkdir(getSessionsFolderPath()) with open(getSessionFilePath(session_id), "w") as f: json.dump(full_state.dict(), f) def register_websocket(self, session_id: str, ws: WebSocket): self.sessions[session_id].ws = ws - print("Registered websocket for session", session_id) + logger.debug(f"Registered websocket for session {session_id}") async def send_ws_data(self, session_id: str, message_type: str, data: Any): if session_id not in self.sessions: raise SessionNotFound(f"Session {session_id} not found") if self.sessions[session_id].ws is None: - # print(f"Session {session_id} has no websocket") + # logger.debug(f"Session {session_id} has no websocket") return await self.sessions[session_id].ws.send_json({ |