summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/server
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-19 00:33:50 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-19 00:33:50 -0700
commit1b92180d4b7720bf1cf36dd63142760d421dabf8 (patch)
tree26e25e005b06526267c2a140c1fbf1cbf822f066 /continuedev/src/continuedev/server
parent924a0c09259d25a4dfe62c0a626a9204df45daa9 (diff)
parenta7c57e1d1e4a0eff3e4b598f8bf0448ea6068353 (diff)
downloadsncontinue-1b92180d4b7720bf1cf36dd63142760d421dabf8.tar.gz
sncontinue-1b92180d4b7720bf1cf36dd63142760d421dabf8.tar.bz2
sncontinue-1b92180d4b7720bf1cf36dd63142760d421dabf8.zip
Merge branch 'main' into config-py
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r--continuedev/src/continuedev/server/gui.py92
-rw-r--r--continuedev/src/continuedev/server/ide.py161
-rw-r--r--continuedev/src/continuedev/server/ide_protocol.py31
-rw-r--r--continuedev/src/continuedev/server/main.py41
-rw-r--r--continuedev/src/continuedev/server/session_manager.py53
-rw-r--r--continuedev/src/continuedev/server/state_manager.py21
6 files changed, 270 insertions, 129 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 8e9b1fb9..ae57c0b6 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -1,15 +1,17 @@
+import asyncio
import json
from fastapi import Depends, Header, WebSocket, APIRouter
+from starlette.websockets import WebSocketState, WebSocketDisconnect
from typing import Any, List, Type, TypeVar, Union
from pydantic import BaseModel
+import traceback
from uvicorn.main import Server
from .session_manager import SessionManager, session_manager, Session
from .gui_protocol import AbstractGUIProtocolServer
from ..libs.util.queue import AsyncSubscriptionQueue
-import asyncio
-import nest_asyncio
-nest_asyncio.apply()
+from ..libs.util.telemetry import capture_event
+from ..libs.util.create_async_task import create_async_task
router = APIRouter(prefix="/gui", tags=["gui"])
@@ -30,12 +32,12 @@ class AppStatus:
Server.handle_exit = AppStatus.handle_exit
-def session(x_continue_session_id: str = Header("anonymous")) -> Session:
- return session_manager.get_session(x_continue_session_id)
+async def session(x_continue_session_id: str = Header("anonymous")) -> Session:
+ return await session_manager.get_session(x_continue_session_id)
-def websocket_session(session_id: str) -> Session:
- return session_manager.get_session(session_id)
+async def websocket_session(session_id: str) -> Session:
+ return await session_manager.get_session(session_id)
T = TypeVar("T", bound=BaseModel)
@@ -52,13 +54,19 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.session = session
async def _send_json(self, message_type: str, data: Any):
+ if self.websocket.application_state == WebSocketState.DISCONNECTED:
+ return
await self.websocket.send_json({
"messageType": message_type,
"data": data
})
- async def _receive_json(self, message_type: str) -> Any:
- return await self.sub_queue.get(message_type)
+ async def _receive_json(self, message_type: str, timeout: int = 5) -> Any:
+ try:
+ return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)
+ except asyncio.TimeoutError:
+ raise Exception(
+ "GUI Protocol _receive_json timed out after 5 seconds")
async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
await self._send_json(message_type, data)
@@ -91,6 +99,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.on_set_editing_at_indices(data["indices"])
elif message_type == "set_pinned_at_indices":
self.on_set_pinned_at_indices(data["indices"])
+ elif message_type == "show_logs_at_index":
+ self.on_show_logs_at_index(data["index"])
except Exception as e:
print(e)
@@ -102,53 +112,69 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
def on_main_input(self, input: str):
# Do something with user input
- asyncio.create_task(self.session.autopilot.accept_user_input(input))
+ create_async_task(self.session.autopilot.accept_user_input(
+ input), self.session.autopilot.continue_sdk.ide.unique_id)
def on_reverse_to_index(self, index: int):
# Reverse the history to the given index
- asyncio.create_task(self.session.autopilot.reverse_to_index(index))
+ create_async_task(self.session.autopilot.reverse_to_index(
+ index), self.session.autopilot.continue_sdk.ide.unique_id)
def on_step_user_input(self, input: str, index: int):
- asyncio.create_task(
- self.session.autopilot.give_user_input(input, index))
+ create_async_task(
+ self.session.autopilot.give_user_input(input, index), self.session.autopilot.continue_sdk.ide.unique_id)
def on_refinement_input(self, input: str, index: int):
- asyncio.create_task(
- self.session.autopilot.accept_refinement_input(input, index))
+ create_async_task(
+ self.session.autopilot.accept_refinement_input(input, index), self.session.autopilot.continue_sdk.ide.unique_id)
def on_retry_at_index(self, index: int):
- asyncio.create_task(
- self.session.autopilot.retry_at_index(index))
+ create_async_task(
+ self.session.autopilot.retry_at_index(index), self.session.autopilot.continue_sdk.ide.unique_id)
def on_change_default_model(self, model: str):
- asyncio.create_task(self.session.autopilot.change_default_model(model))
+ create_async_task(self.session.autopilot.change_default_model(
+ model), self.session.autopilot.continue_sdk.ide.unique_id)
def on_clear_history(self):
- asyncio.create_task(self.session.autopilot.clear_history())
+ create_async_task(self.session.autopilot.clear_history(
+ ), self.session.autopilot.continue_sdk.ide.unique_id)
def on_delete_at_index(self, index: int):
- asyncio.create_task(self.session.autopilot.delete_at_index(index))
+ create_async_task(self.session.autopilot.delete_at_index(
+ index), self.session.autopilot.continue_sdk.ide.unique_id)
def on_delete_context_at_indices(self, indices: List[int]):
- asyncio.create_task(
- self.session.autopilot.delete_context_at_indices(indices)
+ create_async_task(
+ self.session.autopilot.delete_context_at_indices(
+ indices), self.session.autopilot.continue_sdk.ide.unique_id
)
def on_toggle_adding_highlighted_code(self):
- asyncio.create_task(
- self.session.autopilot.toggle_adding_highlighted_code()
+ create_async_task(
+ self.session.autopilot.toggle_adding_highlighted_code(
+ ), self.session.autopilot.continue_sdk.ide.unique_id
)
def on_set_editing_at_indices(self, indices: List[int]):
- asyncio.create_task(
- self.session.autopilot.set_editing_at_indices(indices)
+ create_async_task(
+ self.session.autopilot.set_editing_at_indices(
+ indices), self.session.autopilot.continue_sdk.ide.unique_id
)
def on_set_pinned_at_indices(self, indices: List[int]):
- asyncio.create_task(
- self.session.autopilot.set_pinned_at_indices(indices)
+ create_async_task(
+ self.session.autopilot.set_pinned_at_indices(
+ indices), self.session.autopilot.continue_sdk.ide.unique_id
)
+ def on_show_logs_at_index(self, index: int):
+ name = f"continue_logs.txt"
+ logs = "\n\n############################################\n\n".join(
+ ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs)
+ create_async_task(
+ self.session.autopilot.ide.showVirtualFile(name, logs))
+
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
@@ -176,11 +202,17 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
data = message["data"]
protocol.handle_json(message_type, data)
-
+ except WebSocketDisconnect as e:
+ print("GUI websocket disconnected")
except Exception as e:
print("ERROR in gui websocket: ", e)
+ capture_event(session.autopilot.continue_sdk.ide.unique_id, "gui_error", {
+ "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))})
raise e
finally:
print("Closing gui websocket")
- await websocket.close()
+ if websocket.client_state != WebSocketState.DISCONNECTED:
+ await websocket.close()
+
+ session_manager.persist_session(session.session_id)
session_manager.remove_session(session.session_id)
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index e4a6266a..aeff5623 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -5,7 +5,9 @@ import os
from typing import Any, Dict, List, Type, TypeVar, Union
import uuid
from fastapi import WebSocket, Body, APIRouter
+from starlette.websockets import WebSocketState, WebSocketDisconnect
from uvicorn.main import Server
+import traceback
from ..libs.util.telemetry import capture_event
from ..libs.util.queue import AsyncSubscriptionQueue
@@ -15,6 +17,7 @@ from pydantic import BaseModel
from .gui import SessionManager, session_manager
from .ide_protocol import AbstractIdeProtocolServer
import asyncio
+from ..libs.util.create_async_task import create_async_task
import nest_asyncio
nest_asyncio.apply()
@@ -50,6 +53,10 @@ class OpenFilesResponse(BaseModel):
openFiles: List[str]
+class VisibleFilesResponse(BaseModel):
+ visibleFiles: List[str]
+
+
class HighlightedCodeResponse(BaseModel):
highlightedCode: List[RangeInFile]
@@ -110,19 +117,52 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
websocket: WebSocket
session_manager: SessionManager
sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue()
+ session_id: Union[str, None] = None
def __init__(self, session_manager: SessionManager, websocket: WebSocket):
self.websocket = websocket
self.session_manager = session_manager
+ workspace_directory: str = None
+ unique_id: str = None
+
+ async def initialize(self, session_id: str) -> List[str]:
+ self.session_id = session_id
+ await self._send_json("workspaceDirectory", {})
+ await self._send_json("uniqueId", {})
+ other_msgs = []
+ while True:
+ msg_string = await self.websocket.receive_text()
+ message = json.loads(msg_string)
+ if "messageType" not in message or "data" not in message:
+ continue
+ message_type = message["messageType"]
+ data = message["data"]
+ if message_type == "workspaceDirectory":
+ self.workspace_directory = data["workspaceDirectory"]
+ elif message_type == "uniqueId":
+ self.unique_id = data["uniqueId"]
+ else:
+ other_msgs.append(msg_string)
+
+ if self.workspace_directory is not None and self.unique_id is not None:
+ break
+ return other_msgs
+
async def _send_json(self, message_type: str, data: Any):
+ if self.websocket.application_state == WebSocketState.DISCONNECTED:
+ return
await self.websocket.send_json({
"messageType": message_type,
"data": data
})
- async def _receive_json(self, message_type: str) -> Any:
- return await self.sub_queue.get(message_type)
+ async def _receive_json(self, message_type: str, timeout: int = 5) -> Any:
+ try:
+ return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)
+ except asyncio.TimeoutError:
+ raise Exception(
+ "IDE Protocol _receive_json timed out after 5 seconds")
async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
await self._send_json(message_type, data)
@@ -130,8 +170,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
return resp_model.parse_obj(resp)
async def handle_json(self, message_type: str, data: Any):
- if message_type == "openGUI":
- await self.openGUI()
+ if message_type == "getSessionId":
+ await self.getSessionId()
elif message_type == "setFileOpen":
await self.setFileOpen(data["filepath"], data["open"])
elif message_type == "setSuggestionsLocked":
@@ -154,8 +194,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
self.onMainUserInput(data["input"])
elif message_type == "deleteAtIndex":
self.onDeleteAtIndex(data["index"])
- elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory", "getUserSecret", "runCommand", "uniqueId"]:
+ elif message_type in ["highlightedCode", "openFiles", "visibleFiles", "readFile", "editFile", "getUserSecret", "runCommand"]:
self.sub_queue.post(message_type, data)
+ elif message_type == "workspaceDirectory":
+ self.workspace_directory = data["workspaceDirectory"]
+ elif message_type == "uniqueId":
+ self.unique_id = data["uniqueId"]
else:
raise ValueError("Unknown message type", message_type)
@@ -180,6 +224,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
"open": open
})
+ async def showVirtualFile(self, name: str, contents: str):
+ await self._send_json("showVirtualFile", {
+ "name": name,
+ "contents": contents
+ })
+
async def setSuggestionsLocked(self, filepath: str, locked: bool = True):
# Lock suggestions in the file so they don't ruin the offset before others are inserted
await self._send_json("setSuggestionsLocked", {
@@ -187,9 +237,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
"locked": locked
})
- async def openGUI(self):
- session_id = self.session_manager.new_session(self)
- await self._send_json("openGUI", {
+ async def getSessionId(self):
+ session_id = (await self.session_manager.new_session(
+ self, self.session_id)).session_id
+ await self._send_json("getSessionId", {
"sessionId": session_id
})
@@ -242,53 +293,42 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
def onOpenGUIRequest(self):
pass
+ def __get_autopilot(self):
+ if self.session_id not in self.session_manager.sessions:
+ return None
+ return self.session_manager.sessions[self.session_id].autopilot
+
def onFileEdits(self, edits: List[FileEditWithFullContents]):
- # Send the file edits to ALL autopilots.
- # Maybe not ideal behavior
- for _, session in self.session_manager.sessions.items():
- session.autopilot.handle_manual_edits(edits)
+ if autopilot := self.__get_autopilot():
+ autopilot.handle_manual_edits(edits)
def onDeleteAtIndex(self, index: int):
- for _, session in self.session_manager.sessions.items():
- asyncio.create_task(session.autopilot.delete_at_index(index))
+ if autopilot := self.__get_autopilot():
+ create_async_task(autopilot.delete_at_index(index), self.unique_id)
def onCommandOutput(self, output: str):
- # Send the output to ALL autopilots.
- # Maybe not ideal behavior
- for _, session in self.session_manager.sessions.items():
- asyncio.create_task(
- session.autopilot.handle_command_output(output))
+ if autopilot := self.__get_autopilot():
+ create_async_task(
+ autopilot.handle_command_output(output), self.unique_id)
def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]):
- for _, session in self.session_manager.sessions.items():
- asyncio.create_task(
- session.autopilot.handle_highlighted_code(range_in_files))
+ if autopilot := self.__get_autopilot():
+ create_async_task(autopilot.handle_highlighted_code(
+ range_in_files), self.unique_id)
def onMainUserInput(self, input: str):
- for _, session in self.session_manager.sessions.items():
- asyncio.create_task(
- session.autopilot.accept_user_input(input))
+ if autopilot := self.__get_autopilot():
+ create_async_task(
+ autopilot.accept_user_input(input), self.unique_id)
# Request information. Session doesn't matter.
async def getOpenFiles(self) -> List[str]:
resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles")
return resp.openFiles
- async def getWorkspaceDirectory(self) -> str:
- resp = await self._send_and_receive_json({}, WorkspaceDirectoryResponse, "workspaceDirectory")
- return resp.workspaceDirectory
-
- async def get_unique_id(self) -> str:
- resp = await self._send_and_receive_json({}, UniqueIdResponse, "uniqueId")
- return resp.uniqueId
-
- @property
- def workspace_directory(self) -> str:
- return asyncio.run(self.getWorkspaceDirectory())
-
- @cached_property_no_none
- def unique_id(self) -> str:
- return asyncio.run(self.get_unique_id())
+ async def getVisibleFiles(self) -> List[str]:
+ resp = await self._send_and_receive_json({}, VisibleFilesResponse, "visibleFiles")
+ return resp.visibleFiles
async def getHighlightedCode(self) -> List[RangeInFile]:
resp = await self._send_and_receive_json({}, HighlightedCodeResponse, "highlightedCode")
@@ -389,28 +429,49 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
@router.websocket("/ws")
-async def websocket_endpoint(websocket: WebSocket):
+async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
try:
await websocket.accept()
print("Accepted websocket connection from, ", websocket.client)
await websocket.send_json({"messageType": "connected", "data": {}})
- ideProtocolServer = IdeProtocolServer(session_manager, websocket)
-
- while AppStatus.should_exit is False:
- message = await websocket.receive_text()
- message = json.loads(message)
+ def handle_msg(msg):
+ message = json.loads(msg)
if "messageType" not in message or "data" not in message:
- continue
+ return
message_type = message["messageType"]
data = message["data"]
- await ideProtocolServer.handle_json(message_type, data)
+ create_async_task(
+ ideProtocolServer.handle_json(message_type, data))
+
+ ideProtocolServer = IdeProtocolServer(session_manager, websocket)
+ if session_id is not None:
+ session_manager.registered_ides[session_id] = ideProtocolServer
+ other_msgs = await ideProtocolServer.initialize(session_id)
+ capture_event(ideProtocolServer.unique_id, "session_started", {
+ "session_id": ideProtocolServer.session_id})
+
+ for other_msg in other_msgs:
+ handle_msg(other_msg)
+
+ while AppStatus.should_exit is False:
+ message = await websocket.receive_text()
+ handle_msg(message)
print("Closing ide websocket")
- await websocket.close()
+ except WebSocketDisconnect as e:
+ print("IDE wbsocket disconnected")
except Exception as e:
print("Error in ide websocket: ", e)
- await websocket.close()
+ capture_event(ideProtocolServer.unique_id, "gui_error", {
+ "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))})
raise e
+ finally:
+ if websocket.client_state != WebSocketState.DISCONNECTED:
+ await websocket.close()
+
+ capture_event(ideProtocolServer.unique_id, "session_ended", {
+ "session_id": ideProtocolServer.session_id})
+ session_manager.registered_ides.pop(ideProtocolServer.session_id)
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
index dfdca504..0ae7e7fa 100644
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ b/continuedev/src/continuedev/server/ide_protocol.py
@@ -1,5 +1,6 @@
-from typing import Any, List
+from typing import Any, List, Union
from abc import ABC, abstractmethod, abstractproperty
+from fastapi import WebSocket
from ..models.main import Traceback
from ..models.filesystem_edit import FileEdit, FileSystemEdit, EditDiff
@@ -7,6 +8,9 @@ from ..models.filesystem import RangeInFile, RangeInFileWithContents
class AbstractIdeProtocolServer(ABC):
+ websocket: WebSocket
+ session_id: Union[str, None]
+
@abstractmethod
async def handle_json(self, data: Any):
"""Handle a json message"""
@@ -16,20 +20,20 @@ class AbstractIdeProtocolServer(ABC):
"""Show a suggestion to the user"""
@abstractmethod
- async def getWorkspaceDirectory(self):
- """Get the workspace directory"""
-
- @abstractmethod
async def setFileOpen(self, filepath: str, open: bool = True):
"""Set whether a file is open"""
@abstractmethod
+ async def showVirtualFile(self, name: str, contents: str):
+ """Show a virtual file"""
+
+ @abstractmethod
async def setSuggestionsLocked(self, filepath: str, locked: bool = True):
"""Set whether suggestions are locked"""
@abstractmethod
- async def openGUI(self):
- """Open a GUI"""
+ async def getSessionId(self):
+ """Get a new session ID"""
@abstractmethod
async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool:
@@ -56,6 +60,10 @@ class AbstractIdeProtocolServer(ABC):
"""Get a list of open files"""
@abstractmethod
+ async def getVisibleFiles(self) -> List[str]:
+ """Get a list of visible files"""
+
+ @abstractmethod
async def getHighlightedCode(self) -> List[RangeInFile]:
"""Get a list of highlighted code"""
@@ -103,10 +111,5 @@ class AbstractIdeProtocolServer(ABC):
async def showDiff(self, filepath: str, replacement: str, step_index: int):
"""Show a diff"""
- @abstractproperty
- def workspace_directory(self) -> str:
- """Get the workspace directory"""
-
- @abstractproperty
- def unique_id(self) -> str:
- """Get a unique ID for this IDE"""
+ workspace_directory: str
+ unique_id: str
diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py
index f4d82903..42dc0cc1 100644
--- a/continuedev/src/continuedev/server/main.py
+++ b/continuedev/src/continuedev/server/main.py
@@ -1,10 +1,12 @@
+import time
+import psutil
import os
-import sys
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .ide import router as ide_router
from .gui import router as gui_router
-import logging
+from .session_manager import session_manager
+import atexit
import uvicorn
import argparse
@@ -44,5 +46,38 @@ def run_server():
uvicorn.run(app, host="0.0.0.0", port=args.port)
+def cleanup():
+ print("Cleaning up sessions")
+ for session_id in session_manager.sessions:
+ session_manager.persist_session(session_id)
+
+
+def cpu_usage_report():
+ process = psutil.Process(os.getpid())
+ # Call cpu_percent once to start measurement, but ignore the result
+ process.cpu_percent(interval=None)
+ # Wait for a short period of time
+ time.sleep(1)
+ # Call cpu_percent again to get the CPU usage over the interval
+ cpu_usage = process.cpu_percent(interval=None)
+ print(f"CPU usage: {cpu_usage}%")
+
+
+atexit.register(cleanup)
+
if __name__ == "__main__":
- run_server()
+ try:
+ # import threading
+
+ # def cpu_usage_loop():
+ # while True:
+ # cpu_usage_report()
+ # time.sleep(2)
+
+ # cpu_thread = threading.Thread(target=cpu_usage_loop)
+ # cpu_thread.start()
+
+ run_server()
+ except Exception as e:
+ cleanup()
+ raise e
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index 99a38146..90172a4e 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -1,20 +1,24 @@
+import os
from fastapi import WebSocket
from typing import Any, Dict, List, Union
from uuid import uuid4
+import json
+from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath
from ..models.filesystem_edit import FileEditWithFullContents
+from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER
from ..core.policy import DemoPolicy
from ..core.main import FullState
from ..core.autopilot import Autopilot
from .ide_protocol import AbstractIdeProtocolServer
-import asyncio
-import nest_asyncio
-nest_asyncio.apply()
+from ..libs.util.create_async_task import create_async_task
+from ..libs.util.errors import SessionNotFound
class Session:
session_id: str
autopilot: Autopilot
+ # The GUI websocket for the session
ws: Union[WebSocket, None]
def __init__(self, session_id: str, autopilot: Autopilot):
@@ -38,18 +42,35 @@ class DemoAutopilot(Autopilot):
class SessionManager:
sessions: Dict[str, Session] = {}
- _event_loop: Union[asyncio.BaseEventLoop, None] = None
+ # Mapping of session_id to IDE, where the IDE is still alive
+ registered_ides: Dict[str, AbstractIdeProtocolServer] = {}
- def get_session(self, session_id: str) -> Session:
+ async def get_session(self, session_id: str) -> Session:
if session_id not in self.sessions:
+ # Check then whether it is persisted by listing all files in the sessions folder
+ # And only if the IDE is still alive
+ sessions_folder = getSessionsFolderPath()
+ session_files = os.listdir(sessions_folder)
+ if f"{session_id}.json" in session_files and session_id in self.registered_ides:
+ if self.registered_ides[session_id].session_id is not None:
+ return await self.new_session(self.registered_ides[session_id], session_id=session_id)
+
raise KeyError("Session ID not recognized", session_id)
return self.sessions[session_id]
- def new_session(self, ide: AbstractIdeProtocolServer) -> str:
- autopilot = DemoAutopilot(policy=DemoPolicy(), ide=ide)
- session_id = str(uuid4())
+ async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
+ full_state = None
+ if session_id is not None and os.path.exists(getSessionFilePath(session_id)):
+ with open(getSessionFilePath(session_id), "r") as f:
+ full_state = FullState(**json.load(f))
+
+ autopilot = await DemoAutopilot.create(
+ policy=DemoPolicy(), ide=ide, full_state=full_state)
+ session_id = session_id or str(uuid4())
+ ide.session_id = session_id
session = Session(session_id=session_id, autopilot=autopilot)
self.sessions[session_id] = session
+ self.registered_ides[session_id] = ide
async def on_update(state: FullState):
await session_manager.send_ws_data(session_id, "state_update", {
@@ -57,19 +78,29 @@ class SessionManager:
})
autopilot.on_update(on_update)
- asyncio.create_task(autopilot.run_policy())
- return session_id
+ create_async_task(autopilot.run_policy())
+ return session
def remove_session(self, session_id: str):
del self.sessions[session_id]
+ def persist_session(self, session_id: str):
+ """Save the session's FullState as a json file"""
+ full_state = self.sessions[session_id].autopilot.get_full_state()
+ if not os.path.exists(getSessionsFolderPath()):
+ os.mkdir(getSessionsFolderPath())
+ with open(getSessionFilePath(session_id), "w") as f:
+ json.dump(full_state.dict(), f)
+
def register_websocket(self, session_id: str, ws: WebSocket):
self.sessions[session_id].ws = ws
print("Registered websocket for session", session_id)
async def send_ws_data(self, session_id: str, message_type: str, data: Any):
+ if session_id not in self.sessions:
+ raise SessionNotFound(f"Session {session_id} not found")
if self.sessions[session_id].ws is None:
- print(f"Session {session_id} has no websocket")
+ # print(f"Session {session_id} has no websocket")
return
await self.sessions[session_id].ws.send_json({
diff --git a/continuedev/src/continuedev/server/state_manager.py b/continuedev/src/continuedev/server/state_manager.py
deleted file mode 100644
index c9bd760b..00000000
--- a/continuedev/src/continuedev/server/state_manager.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from typing import Any, List, Tuple, Union
-from fastapi import WebSocket
-from pydantic import BaseModel
-from ..core.main import FullState
-
-# State updates represented as (path, replacement) pairs
-StateUpdate = Tuple[List[Union[str, int]], Any]
-
-
-class StateManager:
- """
- A class that acts as the source of truth for state, ingesting changes to the entire object and streaming only the updated portions to client.
- """
-
- def __init__(self, ws: WebSocket):
- self.ws = ws
-
- def _send_update(self, updates: List[StateUpdate]):
- self.ws.send_json(
- [update.dict() for update in updates]
- )