summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/server
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-24 01:00:42 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-24 01:00:42 -0700
commit85ce06beb9b2d587b0b572117a98318d226bed61 (patch)
tree76fc6232b5219d0bd61b547b26624641a99e7b9b /continuedev/src/continuedev/server
parent699a74250fd4cf91af930ff63077aeb81f74856f (diff)
parent885f88af1d7b35e03b1de4df3e74a60da1a777ed (diff)
downloadsncontinue-85ce06beb9b2d587b0b572117a98318d226bed61.tar.gz
sncontinue-85ce06beb9b2d587b0b572117a98318d226bed61.tar.bz2
sncontinue-85ce06beb9b2d587b0b572117a98318d226bed61.zip
Merge branch 'main' into show-react-immediately
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r--continuedev/src/continuedev/server/gui.py20
-rw-r--r--continuedev/src/continuedev/server/ide.py84
-rw-r--r--continuedev/src/continuedev/server/ide_protocol.py9
-rw-r--r--continuedev/src/continuedev/server/main.py25
-rw-r--r--continuedev/src/continuedev/server/session_manager.py12
5 files changed, 101 insertions, 49 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 238273b2..ae57c0b6 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -1,3 +1,4 @@
+import asyncio
import json
from fastapi import Depends, Header, WebSocket, APIRouter
from starlette.websockets import WebSocketState, WebSocketDisconnect
@@ -53,15 +54,19 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.session = session
async def _send_json(self, message_type: str, data: Any):
- if self.websocket.client_state == WebSocketState.DISCONNECTED:
+ if self.websocket.application_state == WebSocketState.DISCONNECTED:
return
await self.websocket.send_json({
"messageType": message_type,
"data": data
})
- async def _receive_json(self, message_type: str) -> Any:
- return await self.sub_queue.get(message_type)
+ async def _receive_json(self, message_type: str, timeout: int = 5) -> Any:
+ try:
+ return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)
+ except asyncio.TimeoutError:
+ raise Exception(
+ "GUI Protocol _receive_json timed out after 5 seconds")
async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
await self._send_json(message_type, data)
@@ -94,6 +99,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.on_set_editing_at_indices(data["indices"])
elif message_type == "set_pinned_at_indices":
self.on_set_pinned_at_indices(data["indices"])
+ elif message_type == "show_logs_at_index":
+ self.on_show_logs_at_index(data["index"])
except Exception as e:
print(e)
@@ -161,6 +168,13 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
indices), self.session.autopilot.continue_sdk.ide.unique_id
)
+ def on_show_logs_at_index(self, index: int):
+ name = f"continue_logs.txt"
+ logs = "\n\n############################################\n\n".join(
+ ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs)
+ create_async_task(
+ self.session.autopilot.ide.showVirtualFile(name, logs))
+
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index 12a21f19..aeff5623 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -123,10 +123,13 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
self.websocket = websocket
self.session_manager = session_manager
- workspace_directory: str
+ workspace_directory: str = None
+ unique_id: str = None
- async def initialize(self) -> List[str]:
+ async def initialize(self, session_id: str) -> List[str]:
+ self.session_id = session_id
await self._send_json("workspaceDirectory", {})
+ await self._send_json("uniqueId", {})
other_msgs = []
while True:
msg_string = await self.websocket.receive_text()
@@ -137,21 +140,29 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
data = message["data"]
if message_type == "workspaceDirectory":
self.workspace_directory = data["workspaceDirectory"]
- break
+ elif message_type == "uniqueId":
+ self.unique_id = data["uniqueId"]
else:
other_msgs.append(msg_string)
+
+ if self.workspace_directory is not None and self.unique_id is not None:
+ break
return other_msgs
async def _send_json(self, message_type: str, data: Any):
- if self.websocket.client_state == WebSocketState.DISCONNECTED:
+ if self.websocket.application_state == WebSocketState.DISCONNECTED:
return
await self.websocket.send_json({
"messageType": message_type,
"data": data
})
- async def _receive_json(self, message_type: str) -> Any:
- return await self.sub_queue.get(message_type)
+ async def _receive_json(self, message_type: str, timeout: int = 5) -> Any:
+ try:
+ return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)
+ except asyncio.TimeoutError:
+ raise Exception(
+ "IDE Protocol _receive_json timed out after 5 seconds")
async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
await self._send_json(message_type, data)
@@ -183,10 +194,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
self.onMainUserInput(data["input"])
elif message_type == "deleteAtIndex":
self.onDeleteAtIndex(data["index"])
- elif message_type in ["highlightedCode", "openFiles", "visibleFiles", "readFile", "editFile", "getUserSecret", "runCommand", "uniqueId"]:
+ elif message_type in ["highlightedCode", "openFiles", "visibleFiles", "readFile", "editFile", "getUserSecret", "runCommand"]:
self.sub_queue.post(message_type, data)
elif message_type == "workspaceDirectory":
self.workspace_directory = data["workspaceDirectory"]
+ elif message_type == "uniqueId":
+ self.unique_id = data["uniqueId"]
else:
raise ValueError("Unknown message type", message_type)
@@ -211,6 +224,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
"open": open
})
+ async def showVirtualFile(self, name: str, contents: str):
+ await self._send_json("showVirtualFile", {
+ "name": name,
+ "contents": contents
+ })
+
async def setSuggestionsLocked(self, filepath: str, locked: bool = True):
# Lock suggestions in the file so they don't ruin the offset before others are inserted
await self._send_json("setSuggestionsLocked", {
@@ -219,8 +238,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
})
async def getSessionId(self):
- session_id = self.session_manager.new_session(
- self, self.session_id).session_id
+ session_id = (await self.session_manager.new_session(
+ self, self.session_id)).session_id
await self._send_json("getSessionId", {
"sessionId": session_id
})
@@ -274,33 +293,33 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
def onOpenGUIRequest(self):
pass
+ def __get_autopilot(self):
+ if self.session_id not in self.session_manager.sessions:
+ return None
+ return self.session_manager.sessions[self.session_id].autopilot
+
def onFileEdits(self, edits: List[FileEditWithFullContents]):
- # Send the file edits to ALL autopilots.
- # Maybe not ideal behavior
- for _, session in self.session_manager.sessions.items():
- session.autopilot.handle_manual_edits(edits)
+ if autopilot := self.__get_autopilot():
+ autopilot.handle_manual_edits(edits)
def onDeleteAtIndex(self, index: int):
- for _, session in self.session_manager.sessions.items():
- create_async_task(
- session.autopilot.delete_at_index(index), self.unique_id)
+ if autopilot := self.__get_autopilot():
+ create_async_task(autopilot.delete_at_index(index), self.unique_id)
def onCommandOutput(self, output: str):
- # Send the output to ALL autopilots.
- # Maybe not ideal behavior
- for _, session in self.session_manager.sessions.items():
+ if autopilot := self.__get_autopilot():
create_async_task(
- session.autopilot.handle_command_output(output), self.unique_id)
+ autopilot.handle_command_output(output), self.unique_id)
def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]):
- for _, session in self.session_manager.sessions.items():
- create_async_task(
- session.autopilot.handle_highlighted_code(range_in_files), self.unique_id)
+ if autopilot := self.__get_autopilot():
+ create_async_task(autopilot.handle_highlighted_code(
+ range_in_files), self.unique_id)
def onMainUserInput(self, input: str):
- for _, session in self.session_manager.sessions.items():
+ if autopilot := self.__get_autopilot():
create_async_task(
- session.autopilot.accept_user_input(input), self.unique_id)
+ autopilot.accept_user_input(input), self.unique_id)
# Request information. Session doesn't matter.
async def getOpenFiles(self) -> List[str]:
@@ -311,14 +330,6 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
resp = await self._send_and_receive_json({}, VisibleFilesResponse, "visibleFiles")
return resp.visibleFiles
- async def get_unique_id(self) -> str:
- resp = await self._send_and_receive_json({}, UniqueIdResponse, "uniqueId")
- return resp.uniqueId
-
- @cached_property_no_none
- def unique_id(self) -> str:
- return asyncio.run(self.get_unique_id())
-
async def getHighlightedCode(self) -> List[RangeInFile]:
resp = await self._send_and_receive_json({}, HighlightedCodeResponse, "highlightedCode")
return resp.highlightedCode
@@ -436,10 +447,11 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
ideProtocolServer.handle_json(message_type, data))
ideProtocolServer = IdeProtocolServer(session_manager, websocket)
- ideProtocolServer.session_id = session_id
if session_id is not None:
session_manager.registered_ides[session_id] = ideProtocolServer
- other_msgs = await ideProtocolServer.initialize()
+ other_msgs = await ideProtocolServer.initialize(session_id)
+ capture_event(ideProtocolServer.unique_id, "session_started", {
+ "session_id": ideProtocolServer.session_id})
for other_msg in other_msgs:
handle_msg(other_msg)
@@ -460,4 +472,6 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close()
+ capture_event(ideProtocolServer.unique_id, "session_ended", {
+ "session_id": ideProtocolServer.session_id})
session_manager.registered_ides.pop(ideProtocolServer.session_id)
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
index 2f78cf0e..0ae7e7fa 100644
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ b/continuedev/src/continuedev/server/ide_protocol.py
@@ -24,6 +24,10 @@ class AbstractIdeProtocolServer(ABC):
"""Set whether a file is open"""
@abstractmethod
+ async def showVirtualFile(self, name: str, contents: str):
+ """Show a virtual file"""
+
+ @abstractmethod
async def setSuggestionsLocked(self, filepath: str, locked: bool = True):
"""Set whether suggestions are locked"""
@@ -108,7 +112,4 @@ class AbstractIdeProtocolServer(ABC):
"""Show a diff"""
workspace_directory: str
-
- @abstractproperty
- def unique_id(self) -> str:
- """Get a unique ID for this IDE"""
+ unique_id: str
diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py
index aa093853..42dc0cc1 100644
--- a/continuedev/src/continuedev/server/main.py
+++ b/continuedev/src/continuedev/server/main.py
@@ -1,5 +1,6 @@
+import time
+import psutil
import os
-import sys
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .ide import router as ide_router
@@ -51,9 +52,31 @@ def cleanup():
session_manager.persist_session(session_id)
+def cpu_usage_report():
+ process = psutil.Process(os.getpid())
+ # Call cpu_percent once to start measurement, but ignore the result
+ process.cpu_percent(interval=None)
+ # Wait for a short period of time
+ time.sleep(1)
+ # Call cpu_percent again to get the CPU usage over the interval
+ cpu_usage = process.cpu_percent(interval=None)
+ print(f"CPU usage: {cpu_usage}%")
+
+
atexit.register(cleanup)
+
if __name__ == "__main__":
try:
+ # import threading
+
+ # def cpu_usage_loop():
+ # while True:
+ # cpu_usage_report()
+ # time.sleep(2)
+
+ # cpu_thread = threading.Thread(target=cpu_usage_loop)
+ # cpu_thread.start()
+
run_server()
except Exception as e:
cleanup()
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index fb8ac386..20219273 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -7,7 +7,7 @@ import json
from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath
from ..models.filesystem_edit import FileEditWithFullContents
from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER
-from ..core.policy import DemoPolicy
+from ..core.policy import DefaultPolicy
from ..core.main import FullState
from ..core.autopilot import Autopilot
from .ide_protocol import AbstractIdeProtocolServer
@@ -53,19 +53,19 @@ class SessionManager:
session_files = os.listdir(sessions_folder)
if f"{session_id}.json" in session_files and session_id in self.registered_ides:
if self.registered_ides[session_id].session_id is not None:
- return self.new_session(self.registered_ides[session_id], session_id=session_id)
+ return await self.new_session(self.registered_ides[session_id], session_id=session_id)
raise KeyError("Session ID not recognized", session_id)
return self.sessions[session_id]
- def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
+ async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
full_state = None
if session_id is not None and os.path.exists(getSessionFilePath(session_id)):
with open(getSessionFilePath(session_id), "r") as f:
full_state = FullState(**json.load(f))
- autopilot = DemoAutopilot(
- policy=DemoPolicy(), ide=ide, full_state=full_state)
+ autopilot = await DemoAutopilot.create(
+ policy=DefaultPolicy(), ide=ide, full_state=full_state)
session_id = session_id or str(uuid4())
ide.session_id = session_id
session = Session(session_id=session_id, autopilot=autopilot)
@@ -100,7 +100,7 @@ class SessionManager:
if session_id not in self.sessions:
raise SessionNotFound(f"Session {session_id} not found")
if self.sessions[session_id].ws is None:
- print(f"Session {session_id} has no websocket")
+ # print(f"Session {session_id} has no websocket")
return
await self.sessions[session_id].ws.send_json({