diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-07-24 01:00:42 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-07-24 01:00:42 -0700 |
commit | 85ce06beb9b2d587b0b572117a98318d226bed61 (patch) | |
tree | 76fc6232b5219d0bd61b547b26624641a99e7b9b /continuedev/src/continuedev/server | |
parent | 699a74250fd4cf91af930ff63077aeb81f74856f (diff) | |
parent | 885f88af1d7b35e03b1de4df3e74a60da1a777ed (diff) | |
download | sncontinue-85ce06beb9b2d587b0b572117a98318d226bed61.tar.gz sncontinue-85ce06beb9b2d587b0b572117a98318d226bed61.tar.bz2 sncontinue-85ce06beb9b2d587b0b572117a98318d226bed61.zip |
Merge branch 'main' into show-react-immediately
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 20 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide.py | 84 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide_protocol.py | 9 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/main.py | 25 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/session_manager.py | 12 |
5 files changed, 101 insertions, 49 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 238273b2..ae57c0b6 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -1,3 +1,4 @@ +import asyncio import json from fastapi import Depends, Header, WebSocket, APIRouter from starlette.websockets import WebSocketState, WebSocketDisconnect @@ -53,15 +54,19 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.session = session async def _send_json(self, message_type: str, data: Any): - if self.websocket.client_state == WebSocketState.DISCONNECTED: + if self.websocket.application_state == WebSocketState.DISCONNECTED: return await self.websocket.send_json({ "messageType": message_type, "data": data }) - async def _receive_json(self, message_type: str) -> Any: - return await self.sub_queue.get(message_type) + async def _receive_json(self, message_type: str, timeout: int = 5) -> Any: + try: + return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout) + except asyncio.TimeoutError: + raise Exception( + "GUI Protocol _receive_json timed out after 5 seconds") async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: await self._send_json(message_type, data) @@ -94,6 +99,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.on_set_editing_at_indices(data["indices"]) elif message_type == "set_pinned_at_indices": self.on_set_pinned_at_indices(data["indices"]) + elif message_type == "show_logs_at_index": + self.on_show_logs_at_index(data["index"]) except Exception as e: print(e) @@ -161,6 +168,13 @@ class GUIProtocolServer(AbstractGUIProtocolServer): indices), self.session.autopilot.continue_sdk.ide.unique_id ) + def on_show_logs_at_index(self, index: int): + name = f"continue_logs.txt" + logs = "\n\n############################################\n\n".join( + ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs) + create_async_task( + self.session.autopilot.ide.showVirtualFile(name, logs)) + @router.websocket("/ws") async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 12a21f19..aeff5623 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -123,10 +123,13 @@ class IdeProtocolServer(AbstractIdeProtocolServer): self.websocket = websocket self.session_manager = session_manager - workspace_directory: str + workspace_directory: str = None + unique_id: str = None - async def initialize(self) -> List[str]: + async def initialize(self, session_id: str) -> List[str]: + self.session_id = session_id await self._send_json("workspaceDirectory", {}) + await self._send_json("uniqueId", {}) other_msgs = [] while True: msg_string = await self.websocket.receive_text() @@ -137,21 +140,29 @@ class IdeProtocolServer(AbstractIdeProtocolServer): data = message["data"] if message_type == "workspaceDirectory": self.workspace_directory = data["workspaceDirectory"] - break + elif message_type == "uniqueId": + self.unique_id = data["uniqueId"] else: other_msgs.append(msg_string) + + if self.workspace_directory is not None and self.unique_id is not None: + break return other_msgs async def _send_json(self, message_type: str, data: Any): - if self.websocket.client_state == WebSocketState.DISCONNECTED: + if self.websocket.application_state == WebSocketState.DISCONNECTED: return await self.websocket.send_json({ "messageType": message_type, "data": data }) - async def _receive_json(self, message_type: str) -> Any: - return await self.sub_queue.get(message_type) + async def _receive_json(self, message_type: str, timeout: int = 5) -> Any: + try: + return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout) + except asyncio.TimeoutError: + raise Exception( + "IDE Protocol _receive_json timed out after 5 seconds") async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: await self._send_json(message_type, data) @@ -183,10 +194,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer): self.onMainUserInput(data["input"]) elif message_type == "deleteAtIndex": self.onDeleteAtIndex(data["index"]) - elif message_type in ["highlightedCode", "openFiles", "visibleFiles", "readFile", "editFile", "getUserSecret", "runCommand", "uniqueId"]: + elif message_type in ["highlightedCode", "openFiles", "visibleFiles", "readFile", "editFile", "getUserSecret", "runCommand"]: self.sub_queue.post(message_type, data) elif message_type == "workspaceDirectory": self.workspace_directory = data["workspaceDirectory"] + elif message_type == "uniqueId": + self.unique_id = data["uniqueId"] else: raise ValueError("Unknown message type", message_type) @@ -211,6 +224,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer): "open": open }) + async def showVirtualFile(self, name: str, contents: str): + await self._send_json("showVirtualFile", { + "name": name, + "contents": contents + }) + async def setSuggestionsLocked(self, filepath: str, locked: bool = True): # Lock suggestions in the file so they don't ruin the offset before others are inserted await self._send_json("setSuggestionsLocked", { @@ -219,8 +238,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer): }) async def getSessionId(self): - session_id = self.session_manager.new_session( - self, self.session_id).session_id + session_id = (await self.session_manager.new_session( + self, self.session_id)).session_id await self._send_json("getSessionId", { "sessionId": session_id }) @@ -274,33 +293,33 @@ class IdeProtocolServer(AbstractIdeProtocolServer): def onOpenGUIRequest(self): pass + def __get_autopilot(self): + if self.session_id not in self.session_manager.sessions: + return None + return self.session_manager.sessions[self.session_id].autopilot + def onFileEdits(self, edits: List[FileEditWithFullContents]): - # Send the file edits to ALL autopilots. - # Maybe not ideal behavior - for _, session in self.session_manager.sessions.items(): - session.autopilot.handle_manual_edits(edits) + if autopilot := self.__get_autopilot(): + autopilot.handle_manual_edits(edits) def onDeleteAtIndex(self, index: int): - for _, session in self.session_manager.sessions.items(): - create_async_task( - session.autopilot.delete_at_index(index), self.unique_id) + if autopilot := self.__get_autopilot(): + create_async_task(autopilot.delete_at_index(index), self.unique_id) def onCommandOutput(self, output: str): - # Send the output to ALL autopilots. - # Maybe not ideal behavior - for _, session in self.session_manager.sessions.items(): + if autopilot := self.__get_autopilot(): create_async_task( - session.autopilot.handle_command_output(output), self.unique_id) + autopilot.handle_command_output(output), self.unique_id) def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]): - for _, session in self.session_manager.sessions.items(): - create_async_task( - session.autopilot.handle_highlighted_code(range_in_files), self.unique_id) + if autopilot := self.__get_autopilot(): + create_async_task(autopilot.handle_highlighted_code( + range_in_files), self.unique_id) def onMainUserInput(self, input: str): - for _, session in self.session_manager.sessions.items(): + if autopilot := self.__get_autopilot(): create_async_task( - session.autopilot.accept_user_input(input), self.unique_id) + autopilot.accept_user_input(input), self.unique_id) # Request information. Session doesn't matter. async def getOpenFiles(self) -> List[str]: @@ -311,14 +330,6 @@ class IdeProtocolServer(AbstractIdeProtocolServer): resp = await self._send_and_receive_json({}, VisibleFilesResponse, "visibleFiles") return resp.visibleFiles - async def get_unique_id(self) -> str: - resp = await self._send_and_receive_json({}, UniqueIdResponse, "uniqueId") - return resp.uniqueId - - @cached_property_no_none - def unique_id(self) -> str: - return asyncio.run(self.get_unique_id()) - async def getHighlightedCode(self) -> List[RangeInFile]: resp = await self._send_and_receive_json({}, HighlightedCodeResponse, "highlightedCode") return resp.highlightedCode @@ -436,10 +447,11 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None): ideProtocolServer.handle_json(message_type, data)) ideProtocolServer = IdeProtocolServer(session_manager, websocket) - ideProtocolServer.session_id = session_id if session_id is not None: session_manager.registered_ides[session_id] = ideProtocolServer - other_msgs = await ideProtocolServer.initialize() + other_msgs = await ideProtocolServer.initialize(session_id) + capture_event(ideProtocolServer.unique_id, "session_started", { + "session_id": ideProtocolServer.session_id}) for other_msg in other_msgs: handle_msg(other_msg) @@ -460,4 +472,6 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None): if websocket.client_state != WebSocketState.DISCONNECTED: await websocket.close() + capture_event(ideProtocolServer.unique_id, "session_ended", { + "session_id": ideProtocolServer.session_id}) session_manager.registered_ides.pop(ideProtocolServer.session_id) diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index 2f78cf0e..0ae7e7fa 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -24,6 +24,10 @@ class AbstractIdeProtocolServer(ABC): """Set whether a file is open""" @abstractmethod + async def showVirtualFile(self, name: str, contents: str): + """Show a virtual file""" + + @abstractmethod async def setSuggestionsLocked(self, filepath: str, locked: bool = True): """Set whether suggestions are locked""" @@ -108,7 +112,4 @@ class AbstractIdeProtocolServer(ABC): """Show a diff""" workspace_directory: str - - @abstractproperty - def unique_id(self) -> str: - """Get a unique ID for this IDE""" + unique_id: str diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index aa093853..42dc0cc1 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,5 +1,6 @@ +import time +import psutil import os -import sys from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from .ide import router as ide_router @@ -51,9 +52,31 @@ def cleanup(): session_manager.persist_session(session_id) +def cpu_usage_report(): + process = psutil.Process(os.getpid()) + # Call cpu_percent once to start measurement, but ignore the result + process.cpu_percent(interval=None) + # Wait for a short period of time + time.sleep(1) + # Call cpu_percent again to get the CPU usage over the interval + cpu_usage = process.cpu_percent(interval=None) + print(f"CPU usage: {cpu_usage}%") + + atexit.register(cleanup) + if __name__ == "__main__": try: + # import threading + + # def cpu_usage_loop(): + # while True: + # cpu_usage_report() + # time.sleep(2) + + # cpu_thread = threading.Thread(target=cpu_usage_loop) + # cpu_thread.start() + run_server() except Exception as e: cleanup() diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index fb8ac386..20219273 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -7,7 +7,7 @@ import json from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath from ..models.filesystem_edit import FileEditWithFullContents from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER -from ..core.policy import DemoPolicy +from ..core.policy import DefaultPolicy from ..core.main import FullState from ..core.autopilot import Autopilot from .ide_protocol import AbstractIdeProtocolServer @@ -53,19 +53,19 @@ class SessionManager: session_files = os.listdir(sessions_folder) if f"{session_id}.json" in session_files and session_id in self.registered_ides: if self.registered_ides[session_id].session_id is not None: - return self.new_session(self.registered_ides[session_id], session_id=session_id) + return await self.new_session(self.registered_ides[session_id], session_id=session_id) raise KeyError("Session ID not recognized", session_id) return self.sessions[session_id] - def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: + async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: full_state = None if session_id is not None and os.path.exists(getSessionFilePath(session_id)): with open(getSessionFilePath(session_id), "r") as f: full_state = FullState(**json.load(f)) - autopilot = DemoAutopilot( - policy=DemoPolicy(), ide=ide, full_state=full_state) + autopilot = await DemoAutopilot.create( + policy=DefaultPolicy(), ide=ide, full_state=full_state) session_id = session_id or str(uuid4()) ide.session_id = session_id session = Session(session_id=session_id, autopilot=autopilot) @@ -100,7 +100,7 @@ class SessionManager: if session_id not in self.sessions: raise SessionNotFound(f"Session {session_id} not found") if self.sessions[session_id].ws is None: - print(f"Session {session_id} has no websocket") + # print(f"Session {session_id} has no websocket") return await self.sessions[session_id].ws.send_json({ |