summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/server
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-12 00:32:51 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-12 00:32:51 -0700
commit5593c72720dec57f8b734d442e7792fb30632d90 (patch)
tree9a2c1d41e9dab7867bf7a2910d91ecd268f1925d /continuedev/src/continuedev/server
parent4d7b530fb453383bc4082a4d946cb6a9d76178d3 (diff)
downloadsncontinue-5593c72720dec57f8b734d442e7792fb30632d90.tar.gz
sncontinue-5593c72720dec57f8b734d442e7792fb30632d90.tar.bz2
sncontinue-5593c72720dec57f8b734d442e7792fb30632d90.zip
finally found the snag
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r--continuedev/src/continuedev/server/gui.py11
-rw-r--r--continuedev/src/continuedev/server/ide.py67
-rw-r--r--continuedev/src/continuedev/server/ide_protocol.py8
-rw-r--r--continuedev/src/continuedev/server/session_manager.py5
-rw-r--r--continuedev/src/continuedev/server/state_manager.py21
5 files changed, 62 insertions, 50 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index dbc063c8..21089f30 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -1,5 +1,6 @@
import json
from fastapi import Depends, Header, WebSocket, APIRouter
+from starlette.websockets import WebSocketState, WebSocketDisconnect
from typing import Any, List, Type, TypeVar, Union
from pydantic import BaseModel
import traceback
@@ -52,6 +53,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.session = session
async def _send_json(self, message_type: str, data: Any):
+ if self.websocket.client_state == WebSocketState.DISCONNECTED:
+ return
await self.websocket.send_json({
"messageType": message_type,
"data": data
@@ -171,7 +174,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
protocol.websocket = websocket
# Update any history that may have happened before connection
- await protocol.send_state_update()
+ # await protocol.send_state_update()
while AppStatus.should_exit is False:
message = await websocket.receive_text()
@@ -185,7 +188,8 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
data = message["data"]
protocol.handle_json(message_type, data)
-
+ except WebSocketDisconnect as e:
+ print("GUI websocket disconnected")
except Exception as e:
print("ERROR in gui websocket: ", e)
capture_event(session.autopilot.continue_sdk.ide.unique_id, "gui_error", {
@@ -193,5 +197,6 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
raise e
finally:
print("Closing gui websocket")
- await websocket.close()
+ if websocket.client_state != WebSocketState.DISCONNECTED:
+ await websocket.close()
session_manager.remove_session(session.session_id)
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index 782c2ba6..400ad740 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -5,6 +5,7 @@ import os
from typing import Any, Dict, List, Type, TypeVar, Union
import uuid
from fastapi import WebSocket, Body, APIRouter
+from starlette.websockets import WebSocketState, WebSocketDisconnect
from uvicorn.main import Server
import traceback
@@ -17,6 +18,8 @@ from .gui import SessionManager, session_manager
from .ide_protocol import AbstractIdeProtocolServer
import asyncio
from ..libs.util.create_async_task import create_async_task
+import nest_asyncio
+nest_asyncio.apply()
router = APIRouter(prefix="/ide", tags=["ide"])
@@ -115,7 +118,30 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
self.websocket = websocket
self.session_manager = session_manager
+ workspace_directory: str
+
+ async def initialize(self) -> List[str]:
+ await self._send_json("workspaceDirectory", {})
+ other_msgs = []
+ while True:
+ msg_string = await self.websocket.receive_text()
+ message = json.loads(msg_string)
+ if "messageType" not in message or "data" not in message:
+ continue
+ message_type = message["messageType"]
+ data = message["data"]
+ # if message_type == "openGUI":
+ # await self.openGUI()
+ if message_type == "workspaceDirectory":
+ self.workspace_directory = data["workspaceDirectory"]
+ break
+ else:
+ other_msgs.append(msg_string)
+ return other_msgs
+
async def _send_json(self, message_type: str, data: Any):
+ if self.websocket.client_state == WebSocketState.DISCONNECTED:
+ return
await self.websocket.send_json({
"messageType": message_type,
"data": data
@@ -154,8 +180,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
self.onMainUserInput(data["input"])
elif message_type == "deleteAtIndex":
self.onDeleteAtIndex(data["index"])
- elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory", "getUserSecret", "runCommand", "uniqueId"]:
+ elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "getUserSecret", "runCommand", "uniqueId"]:
self.sub_queue.post(message_type, data)
+ elif message_type == "workspaceDirectory":
+ self.workspace_directory = data["workspaceDirectory"]
else:
raise ValueError("Unknown message type", message_type)
@@ -275,18 +303,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles")
return resp.openFiles
- async def getWorkspaceDirectory(self) -> str:
- resp = await self._send_and_receive_json({}, WorkspaceDirectoryResponse, "workspaceDirectory")
- return resp.workspaceDirectory
-
async def get_unique_id(self) -> str:
resp = await self._send_and_receive_json({}, UniqueIdResponse, "uniqueId")
return resp.uniqueId
- @property
- def workspace_directory(self) -> str:
- return asyncio.run(self.getWorkspaceDirectory())
-
@cached_property_no_none
def unique_id(self) -> str:
return asyncio.run(self.get_unique_id())
@@ -396,24 +416,35 @@ async def websocket_endpoint(websocket: WebSocket):
print("Accepted websocket connection from, ", websocket.client)
await websocket.send_json({"messageType": "connected", "data": {}})
- ideProtocolServer = IdeProtocolServer(session_manager, websocket)
-
- while AppStatus.should_exit is False:
- message = await websocket.receive_text()
- message = json.loads(message)
+ def handle_msg(msg):
+ message = json.loads(msg)
if "messageType" not in message or "data" not in message:
- continue
+ return
message_type = message["messageType"]
data = message["data"]
- await ideProtocolServer.handle_json(message_type, data)
+ create_async_task(
+ ideProtocolServer.handle_json(message_type, data))
+
+ ideProtocolServer = IdeProtocolServer(session_manager, websocket)
+ other_msgs = await ideProtocolServer.initialize()
+
+ for other_msg in other_msgs:
+ handle_msg(other_msg)
+
+ while AppStatus.should_exit is False:
+ message = await websocket.receive_text()
+ handle_msg(message)
print("Closing ide websocket")
- await websocket.close()
+ except WebSocketDisconnect as e:
+ print("IDE wbsocket disconnected")
except Exception as e:
print("Error in ide websocket: ", e)
capture_event(ideProtocolServer.unique_id, "gui_error", {
"error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))})
- await websocket.close()
raise e
+ finally:
+ if websocket.client_state != WebSocketState.DISCONNECTED:
+ await websocket.close()
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
index dfdca504..69cb6c10 100644
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ b/continuedev/src/continuedev/server/ide_protocol.py
@@ -16,10 +16,6 @@ class AbstractIdeProtocolServer(ABC):
"""Show a suggestion to the user"""
@abstractmethod
- async def getWorkspaceDirectory(self):
- """Get the workspace directory"""
-
- @abstractmethod
async def setFileOpen(self, filepath: str, open: bool = True):
"""Set whether a file is open"""
@@ -103,9 +99,7 @@ class AbstractIdeProtocolServer(ABC):
async def showDiff(self, filepath: str, replacement: str, step_index: int):
"""Show a diff"""
- @abstractproperty
- def workspace_directory(self) -> str:
- """Get the workspace directory"""
+ workspace_directory: str
@abstractproperty
def unique_id(self) -> str:
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index 873a379e..7147dcfa 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -9,11 +9,13 @@ from ..core.main import FullState
from ..core.autopilot import Autopilot
from .ide_protocol import AbstractIdeProtocolServer
from ..libs.util.create_async_task import create_async_task
+from ..libs.util.errors import SessionNotFound
class Session:
session_id: str
autopilot: Autopilot
+ # The GUI websocket for the session
ws: Union[WebSocket, None]
def __init__(self, session_id: str, autopilot: Autopilot):
@@ -37,7 +39,6 @@ class DemoAutopilot(Autopilot):
class SessionManager:
sessions: Dict[str, Session] = {}
- _event_loop: Union[BaseEventLoop, None] = None
def get_session(self, session_id: str) -> Session:
if session_id not in self.sessions:
@@ -67,6 +68,8 @@ class SessionManager:
print("Registered websocket for session", session_id)
async def send_ws_data(self, session_id: str, message_type: str, data: Any):
+ if session_id not in self.sessions:
+ raise SessionNotFound(f"Session {session_id} not found")
if self.sessions[session_id].ws is None:
print(f"Session {session_id} has no websocket")
return
diff --git a/continuedev/src/continuedev/server/state_manager.py b/continuedev/src/continuedev/server/state_manager.py
deleted file mode 100644
index c9bd760b..00000000
--- a/continuedev/src/continuedev/server/state_manager.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from typing import Any, List, Tuple, Union
-from fastapi import WebSocket
-from pydantic import BaseModel
-from ..core.main import FullState
-
-# State updates represented as (path, replacement) pairs
-StateUpdate = Tuple[List[Union[str, int]], Any]
-
-
-class StateManager:
- """
- A class that acts as the source of truth for state, ingesting changes to the entire object and streaming only the updated portions to client.
- """
-
- def __init__(self, ws: WebSocket):
- self.ws = ws
-
- def _send_update(self, updates: List[StateUpdate]):
- self.ws.send_json(
- [update.dict() for update in updates]
- )