summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/server
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-16 16:25:02 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-16 16:25:02 -0700
commit4c3a25a1c8938f8132233e021c74d98eb19d7ddd (patch)
tree8460e5703f224e7ef5c2c7eca6b470f338b93e1e /continuedev/src/continuedev/server
parent3ded151331933c9a1352cc46c3cc67c5733d1c86 (diff)
parenta4a815628f702af806603015ec6805edd151328b (diff)
downloadsncontinue-4c3a25a1c8938f8132233e021c74d98eb19d7ddd.tar.gz
sncontinue-4c3a25a1c8938f8132233e021c74d98eb19d7ddd.tar.bz2
sncontinue-4c3a25a1c8938f8132233e021c74d98eb19d7ddd.zip
Merge branch 'main' into ggml-server
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r--continuedev/src/continuedev/server/gui.py83
-rw-r--r--continuedev/src/continuedev/server/ide.py149
-rw-r--r--continuedev/src/continuedev/server/ide_protocol.py27
-rw-r--r--continuedev/src/continuedev/server/main.py16
-rw-r--r--continuedev/src/continuedev/server/session_manager.py51
-rw-r--r--continuedev/src/continuedev/server/state_manager.py21
6 files changed, 220 insertions, 127 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 8e9b1fb9..4201353e 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -1,15 +1,17 @@
+import asyncio
import json
from fastapi import Depends, Header, WebSocket, APIRouter
+from starlette.websockets import WebSocketState, WebSocketDisconnect
from typing import Any, List, Type, TypeVar, Union
from pydantic import BaseModel
+import traceback
from uvicorn.main import Server
from .session_manager import SessionManager, session_manager, Session
from .gui_protocol import AbstractGUIProtocolServer
from ..libs.util.queue import AsyncSubscriptionQueue
-import asyncio
-import nest_asyncio
-nest_asyncio.apply()
+from ..libs.util.telemetry import capture_event
+from ..libs.util.create_async_task import create_async_task
router = APIRouter(prefix="/gui", tags=["gui"])
@@ -30,12 +32,12 @@ class AppStatus:
Server.handle_exit = AppStatus.handle_exit
-def session(x_continue_session_id: str = Header("anonymous")) -> Session:
- return session_manager.get_session(x_continue_session_id)
+async def session(x_continue_session_id: str = Header("anonymous")) -> Session:
+ return await session_manager.get_session(x_continue_session_id)
-def websocket_session(session_id: str) -> Session:
- return session_manager.get_session(session_id)
+async def websocket_session(session_id: str) -> Session:
+ return await session_manager.get_session(session_id)
T = TypeVar("T", bound=BaseModel)
@@ -52,13 +54,19 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.session = session
async def _send_json(self, message_type: str, data: Any):
+ if self.websocket.application_state == WebSocketState.DISCONNECTED:
+ return
await self.websocket.send_json({
"messageType": message_type,
"data": data
})
- async def _receive_json(self, message_type: str) -> Any:
- return await self.sub_queue.get(message_type)
+ async def _receive_json(self, message_type: str, timeout: int = 5) -> Any:
+ try:
+ return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)
+ except asyncio.TimeoutError:
+ raise Exception(
+ "GUI Protocol _receive_json timed out after 5 seconds")
async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
await self._send_json(message_type, data)
@@ -102,51 +110,60 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
def on_main_input(self, input: str):
# Do something with user input
- asyncio.create_task(self.session.autopilot.accept_user_input(input))
+ create_async_task(self.session.autopilot.accept_user_input(
+ input), self.session.autopilot.continue_sdk.ide.unique_id)
def on_reverse_to_index(self, index: int):
# Reverse the history to the given index
- asyncio.create_task(self.session.autopilot.reverse_to_index(index))
+ create_async_task(self.session.autopilot.reverse_to_index(
+ index), self.session.autopilot.continue_sdk.ide.unique_id)
def on_step_user_input(self, input: str, index: int):
- asyncio.create_task(
- self.session.autopilot.give_user_input(input, index))
+ create_async_task(
+ self.session.autopilot.give_user_input(input, index), self.session.autopilot.continue_sdk.ide.unique_id)
def on_refinement_input(self, input: str, index: int):
- asyncio.create_task(
- self.session.autopilot.accept_refinement_input(input, index))
+ create_async_task(
+ self.session.autopilot.accept_refinement_input(input, index), self.session.autopilot.continue_sdk.ide.unique_id)
def on_retry_at_index(self, index: int):
- asyncio.create_task(
- self.session.autopilot.retry_at_index(index))
+ create_async_task(
+ self.session.autopilot.retry_at_index(index), self.session.autopilot.continue_sdk.ide.unique_id)
def on_change_default_model(self, model: str):
- asyncio.create_task(self.session.autopilot.change_default_model(model))
+ create_async_task(self.session.autopilot.change_default_model(
+ model), self.session.autopilot.continue_sdk.ide.unique_id)
def on_clear_history(self):
- asyncio.create_task(self.session.autopilot.clear_history())
+ create_async_task(self.session.autopilot.clear_history(
+ ), self.session.autopilot.continue_sdk.ide.unique_id)
def on_delete_at_index(self, index: int):
- asyncio.create_task(self.session.autopilot.delete_at_index(index))
+ create_async_task(self.session.autopilot.delete_at_index(
+ index), self.session.autopilot.continue_sdk.ide.unique_id)
def on_delete_context_at_indices(self, indices: List[int]):
- asyncio.create_task(
- self.session.autopilot.delete_context_at_indices(indices)
+ create_async_task(
+ self.session.autopilot.delete_context_at_indices(
+ indices), self.session.autopilot.continue_sdk.ide.unique_id
)
def on_toggle_adding_highlighted_code(self):
- asyncio.create_task(
- self.session.autopilot.toggle_adding_highlighted_code()
+ create_async_task(
+ self.session.autopilot.toggle_adding_highlighted_code(
+ ), self.session.autopilot.continue_sdk.ide.unique_id
)
def on_set_editing_at_indices(self, indices: List[int]):
- asyncio.create_task(
- self.session.autopilot.set_editing_at_indices(indices)
+ create_async_task(
+ self.session.autopilot.set_editing_at_indices(
+ indices), self.session.autopilot.continue_sdk.ide.unique_id
)
def on_set_pinned_at_indices(self, indices: List[int]):
- asyncio.create_task(
- self.session.autopilot.set_pinned_at_indices(indices)
+ create_async_task(
+ self.session.autopilot.set_pinned_at_indices(
+ indices), self.session.autopilot.continue_sdk.ide.unique_id
)
@@ -176,11 +193,17 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
data = message["data"]
protocol.handle_json(message_type, data)
-
+ except WebSocketDisconnect as e:
+ print("GUI websocket disconnected")
except Exception as e:
print("ERROR in gui websocket: ", e)
+ capture_event(session.autopilot.continue_sdk.ide.unique_id, "gui_error", {
+ "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))})
raise e
finally:
print("Closing gui websocket")
- await websocket.close()
+ if websocket.client_state != WebSocketState.DISCONNECTED:
+ await websocket.close()
+
+ session_manager.persist_session(session.session_id)
session_manager.remove_session(session.session_id)
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index e4a6266a..a91708ec 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -5,7 +5,9 @@ import os
from typing import Any, Dict, List, Type, TypeVar, Union
import uuid
from fastapi import WebSocket, Body, APIRouter
+from starlette.websockets import WebSocketState, WebSocketDisconnect
from uvicorn.main import Server
+import traceback
from ..libs.util.telemetry import capture_event
from ..libs.util.queue import AsyncSubscriptionQueue
@@ -15,6 +17,7 @@ from pydantic import BaseModel
from .gui import SessionManager, session_manager
from .ide_protocol import AbstractIdeProtocolServer
import asyncio
+from ..libs.util.create_async_task import create_async_task
import nest_asyncio
nest_asyncio.apply()
@@ -50,6 +53,10 @@ class OpenFilesResponse(BaseModel):
openFiles: List[str]
+class VisibleFilesResponse(BaseModel):
+ visibleFiles: List[str]
+
+
class HighlightedCodeResponse(BaseModel):
highlightedCode: List[RangeInFile]
@@ -110,19 +117,52 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
websocket: WebSocket
session_manager: SessionManager
sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue()
+ session_id: Union[str, None] = None
def __init__(self, session_manager: SessionManager, websocket: WebSocket):
self.websocket = websocket
self.session_manager = session_manager
+ workspace_directory: str = None
+ unique_id: str = None
+
+ async def initialize(self, session_id: str) -> List[str]:
+ self.session_id = session_id
+ await self._send_json("workspaceDirectory", {})
+ await self._send_json("uniqueId", {})
+ other_msgs = []
+ while True:
+ msg_string = await self.websocket.receive_text()
+ message = json.loads(msg_string)
+ if "messageType" not in message or "data" not in message:
+ continue
+ message_type = message["messageType"]
+ data = message["data"]
+ if message_type == "workspaceDirectory":
+ self.workspace_directory = data["workspaceDirectory"]
+ elif message_type == "uniqueId":
+ self.unique_id = data["uniqueId"]
+ else:
+ other_msgs.append(msg_string)
+
+ if self.workspace_directory is not None and self.unique_id is not None:
+ break
+ return other_msgs
+
async def _send_json(self, message_type: str, data: Any):
+ if self.websocket.application_state == WebSocketState.DISCONNECTED:
+ return
await self.websocket.send_json({
"messageType": message_type,
"data": data
})
- async def _receive_json(self, message_type: str) -> Any:
- return await self.sub_queue.get(message_type)
+ async def _receive_json(self, message_type: str, timeout: int = 5) -> Any:
+ try:
+ return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)
+ except asyncio.TimeoutError:
+ raise Exception(
+ "IDE Protocol _receive_json timed out after 5 seconds")
async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
await self._send_json(message_type, data)
@@ -130,8 +170,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
return resp_model.parse_obj(resp)
async def handle_json(self, message_type: str, data: Any):
- if message_type == "openGUI":
- await self.openGUI()
+ if message_type == "getSessionId":
+ await self.getSessionId()
elif message_type == "setFileOpen":
await self.setFileOpen(data["filepath"], data["open"])
elif message_type == "setSuggestionsLocked":
@@ -154,8 +194,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
self.onMainUserInput(data["input"])
elif message_type == "deleteAtIndex":
self.onDeleteAtIndex(data["index"])
- elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory", "getUserSecret", "runCommand", "uniqueId"]:
+ elif message_type in ["highlightedCode", "openFiles", "visibleFiles", "readFile", "editFile", "getUserSecret", "runCommand"]:
self.sub_queue.post(message_type, data)
+ elif message_type == "workspaceDirectory":
+ self.workspace_directory = data["workspaceDirectory"]
+ elif message_type == "uniqueId":
+ self.unique_id = data["uniqueId"]
else:
raise ValueError("Unknown message type", message_type)
@@ -187,9 +231,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
"locked": locked
})
- async def openGUI(self):
- session_id = self.session_manager.new_session(self)
- await self._send_json("openGUI", {
+ async def getSessionId(self):
+ session_id = (await self.session_manager.new_session(
+ self, self.session_id)).session_id
+ await self._send_json("getSessionId", {
"sessionId": session_id
})
@@ -242,53 +287,40 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
def onOpenGUIRequest(self):
pass
+ def __get_autopilot(self):
+ return self.session_manager.sessions[self.session_id].autopilot
+
def onFileEdits(self, edits: List[FileEditWithFullContents]):
- # Send the file edits to ALL autopilots.
- # Maybe not ideal behavior
- for _, session in self.session_manager.sessions.items():
- session.autopilot.handle_manual_edits(edits)
+ if autopilot := self.__get_autopilot():
+ autopilot.handle_manual_edits(edits)
def onDeleteAtIndex(self, index: int):
- for _, session in self.session_manager.sessions.items():
- asyncio.create_task(session.autopilot.delete_at_index(index))
+ if autopilot := self.__get_autopilot():
+ create_async_task(autopilot.delete_at_index(index), self.unique_id)
def onCommandOutput(self, output: str):
- # Send the output to ALL autopilots.
- # Maybe not ideal behavior
- for _, session in self.session_manager.sessions.items():
- asyncio.create_task(
- session.autopilot.handle_command_output(output))
+ if autopilot := self.__get_autopilot():
+ create_async_task(
+ autopilot.handle_command_output(output), self.unique_id)
def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]):
- for _, session in self.session_manager.sessions.items():
- asyncio.create_task(
- session.autopilot.handle_highlighted_code(range_in_files))
+ if autopilot := self.__get_autopilot():
+ create_async_task(autopilot.handle_highlighted_code(
+ range_in_files), self.unique_id)
def onMainUserInput(self, input: str):
- for _, session in self.session_manager.sessions.items():
- asyncio.create_task(
- session.autopilot.accept_user_input(input))
+ if autopilot := self.__get_autopilot():
+ create_async_task(
+ autopilot.accept_user_input(input), self.unique_id)
# Request information. Session doesn't matter.
async def getOpenFiles(self) -> List[str]:
resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles")
return resp.openFiles
- async def getWorkspaceDirectory(self) -> str:
- resp = await self._send_and_receive_json({}, WorkspaceDirectoryResponse, "workspaceDirectory")
- return resp.workspaceDirectory
-
- async def get_unique_id(self) -> str:
- resp = await self._send_and_receive_json({}, UniqueIdResponse, "uniqueId")
- return resp.uniqueId
-
- @property
- def workspace_directory(self) -> str:
- return asyncio.run(self.getWorkspaceDirectory())
-
- @cached_property_no_none
- def unique_id(self) -> str:
- return asyncio.run(self.get_unique_id())
+ async def getVisibleFiles(self) -> List[str]:
+ resp = await self._send_and_receive_json({}, VisibleFilesResponse, "visibleFiles")
+ return resp.visibleFiles
async def getHighlightedCode(self) -> List[RangeInFile]:
resp = await self._send_and_receive_json({}, HighlightedCodeResponse, "highlightedCode")
@@ -389,28 +421,45 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
@router.websocket("/ws")
-async def websocket_endpoint(websocket: WebSocket):
+async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
try:
await websocket.accept()
print("Accepted websocket connection from, ", websocket.client)
await websocket.send_json({"messageType": "connected", "data": {}})
- ideProtocolServer = IdeProtocolServer(session_manager, websocket)
-
- while AppStatus.should_exit is False:
- message = await websocket.receive_text()
- message = json.loads(message)
+ def handle_msg(msg):
+ message = json.loads(msg)
if "messageType" not in message or "data" not in message:
- continue
+ return
message_type = message["messageType"]
data = message["data"]
- await ideProtocolServer.handle_json(message_type, data)
+ create_async_task(
+ ideProtocolServer.handle_json(message_type, data))
+
+ ideProtocolServer = IdeProtocolServer(session_manager, websocket)
+ if session_id is not None:
+ session_manager.registered_ides[session_id] = ideProtocolServer
+ other_msgs = await ideProtocolServer.initialize(session_id)
+
+ for other_msg in other_msgs:
+ handle_msg(other_msg)
+
+ while AppStatus.should_exit is False:
+ message = await websocket.receive_text()
+ handle_msg(message)
print("Closing ide websocket")
- await websocket.close()
+ except WebSocketDisconnect as e:
+ print("IDE wbsocket disconnected")
except Exception as e:
print("Error in ide websocket: ", e)
- await websocket.close()
+ capture_event(ideProtocolServer.unique_id, "gui_error", {
+ "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))})
raise e
+ finally:
+ if websocket.client_state != WebSocketState.DISCONNECTED:
+ await websocket.close()
+
+ session_manager.registered_ides.pop(ideProtocolServer.session_id)
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
index dfdca504..d0fb0bf8 100644
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ b/continuedev/src/continuedev/server/ide_protocol.py
@@ -1,5 +1,6 @@
-from typing import Any, List
+from typing import Any, List, Union
from abc import ABC, abstractmethod, abstractproperty
+from fastapi import WebSocket
from ..models.main import Traceback
from ..models.filesystem_edit import FileEdit, FileSystemEdit, EditDiff
@@ -7,6 +8,9 @@ from ..models.filesystem import RangeInFile, RangeInFileWithContents
class AbstractIdeProtocolServer(ABC):
+ websocket: WebSocket
+ session_id: Union[str, None]
+
@abstractmethod
async def handle_json(self, data: Any):
"""Handle a json message"""
@@ -16,10 +20,6 @@ class AbstractIdeProtocolServer(ABC):
"""Show a suggestion to the user"""
@abstractmethod
- async def getWorkspaceDirectory(self):
- """Get the workspace directory"""
-
- @abstractmethod
async def setFileOpen(self, filepath: str, open: bool = True):
"""Set whether a file is open"""
@@ -28,8 +28,8 @@ class AbstractIdeProtocolServer(ABC):
"""Set whether suggestions are locked"""
@abstractmethod
- async def openGUI(self):
- """Open a GUI"""
+ async def getSessionId(self):
+ """Get a new session ID"""
@abstractmethod
async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool:
@@ -56,6 +56,10 @@ class AbstractIdeProtocolServer(ABC):
"""Get a list of open files"""
@abstractmethod
+ async def getVisibleFiles(self) -> List[str]:
+ """Get a list of visible files"""
+
+ @abstractmethod
async def getHighlightedCode(self) -> List[RangeInFile]:
"""Get a list of highlighted code"""
@@ -103,10 +107,5 @@ class AbstractIdeProtocolServer(ABC):
async def showDiff(self, filepath: str, replacement: str, step_index: int):
"""Show a diff"""
- @abstractproperty
- def workspace_directory(self) -> str:
- """Get the workspace directory"""
-
- @abstractproperty
- def unique_id(self) -> str:
- """Get a unique ID for this IDE"""
+ workspace_directory: str
+ unique_id: str
diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py
index f4d82903..aa093853 100644
--- a/continuedev/src/continuedev/server/main.py
+++ b/continuedev/src/continuedev/server/main.py
@@ -4,7 +4,8 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .ide import router as ide_router
from .gui import router as gui_router
-import logging
+from .session_manager import session_manager
+import atexit
import uvicorn
import argparse
@@ -44,5 +45,16 @@ def run_server():
uvicorn.run(app, host="0.0.0.0", port=args.port)
+def cleanup():
+ print("Cleaning up sessions")
+ for session_id in session_manager.sessions:
+ session_manager.persist_session(session_id)
+
+
+atexit.register(cleanup)
if __name__ == "__main__":
- run_server()
+ try:
+ run_server()
+ except Exception as e:
+ cleanup()
+ raise e
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index 99a38146..6d109ca6 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -1,20 +1,24 @@
+import os
from fastapi import WebSocket
from typing import Any, Dict, List, Union
from uuid import uuid4
+import json
+from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath
from ..models.filesystem_edit import FileEditWithFullContents
+from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER
from ..core.policy import DemoPolicy
from ..core.main import FullState
from ..core.autopilot import Autopilot
from .ide_protocol import AbstractIdeProtocolServer
-import asyncio
-import nest_asyncio
-nest_asyncio.apply()
+from ..libs.util.create_async_task import create_async_task
+from ..libs.util.errors import SessionNotFound
class Session:
session_id: str
autopilot: Autopilot
+ # The GUI websocket for the session
ws: Union[WebSocket, None]
def __init__(self, session_id: str, autopilot: Autopilot):
@@ -38,18 +42,35 @@ class DemoAutopilot(Autopilot):
class SessionManager:
sessions: Dict[str, Session] = {}
- _event_loop: Union[asyncio.BaseEventLoop, None] = None
+ # Mapping of session_id to IDE, where the IDE is still alive
+ registered_ides: Dict[str, AbstractIdeProtocolServer] = {}
- def get_session(self, session_id: str) -> Session:
+ async def get_session(self, session_id: str) -> Session:
if session_id not in self.sessions:
+ # Check then whether it is persisted by listing all files in the sessions folder
+ # And only if the IDE is still alive
+ sessions_folder = getSessionsFolderPath()
+ session_files = os.listdir(sessions_folder)
+ if f"{session_id}.json" in session_files and session_id in self.registered_ides:
+ if self.registered_ides[session_id].session_id is not None:
+ return await self.new_session(self.registered_ides[session_id], session_id=session_id)
+
raise KeyError("Session ID not recognized", session_id)
return self.sessions[session_id]
- def new_session(self, ide: AbstractIdeProtocolServer) -> str:
- autopilot = DemoAutopilot(policy=DemoPolicy(), ide=ide)
- session_id = str(uuid4())
+ async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
+ full_state = None
+ if session_id is not None and os.path.exists(getSessionFilePath(session_id)):
+ with open(getSessionFilePath(session_id), "r") as f:
+ full_state = FullState(**json.load(f))
+
+ autopilot = await DemoAutopilot.create(
+ policy=DemoPolicy(), ide=ide, full_state=full_state)
+ session_id = session_id or str(uuid4())
+ ide.session_id = session_id
session = Session(session_id=session_id, autopilot=autopilot)
self.sessions[session_id] = session
+ self.registered_ides[session_id] = ide
async def on_update(state: FullState):
await session_manager.send_ws_data(session_id, "state_update", {
@@ -57,17 +78,27 @@ class SessionManager:
})
autopilot.on_update(on_update)
- asyncio.create_task(autopilot.run_policy())
- return session_id
+ create_async_task(autopilot.run_policy())
+ return session
def remove_session(self, session_id: str):
del self.sessions[session_id]
+ def persist_session(self, session_id: str):
+ """Save the session's FullState as a json file"""
+ full_state = self.sessions[session_id].autopilot.get_full_state()
+ if not os.path.exists(getSessionsFolderPath()):
+ os.mkdir(getSessionsFolderPath())
+ with open(getSessionFilePath(session_id), "w") as f:
+ json.dump(full_state.dict(), f)
+
def register_websocket(self, session_id: str, ws: WebSocket):
self.sessions[session_id].ws = ws
print("Registered websocket for session", session_id)
async def send_ws_data(self, session_id: str, message_type: str, data: Any):
+ if session_id not in self.sessions:
+ raise SessionNotFound(f"Session {session_id} not found")
if self.sessions[session_id].ws is None:
print(f"Session {session_id} has no websocket")
return
diff --git a/continuedev/src/continuedev/server/state_manager.py b/continuedev/src/continuedev/server/state_manager.py
deleted file mode 100644
index c9bd760b..00000000
--- a/continuedev/src/continuedev/server/state_manager.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from typing import Any, List, Tuple, Union
-from fastapi import WebSocket
-from pydantic import BaseModel
-from ..core.main import FullState
-
-# State updates represented as (path, replacement) pairs
-StateUpdate = Tuple[List[Union[str, int]], Any]
-
-
-class StateManager:
- """
- A class that acts as the source of truth for state, ingesting changes to the entire object and streaming only the updated portions to client.
- """
-
- def __init__(self, ws: WebSocket):
- self.ws = ws
-
- def _send_update(self, updates: List[StateUpdate]):
- self.ws.send_json(
- [update.dict() for update in updates]
- )