diff options
| author | Nate Sesti <sestinj@gmail.com> | 2023-07-12 21:53:06 -0700 | 
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-07-12 21:53:06 -0700 | 
| commit | 391764f1371dab06af30a29e10a826a516b69bb3 (patch) | |
| tree | a15ec97701ebe1ac2346c802964d64795cdaf008 /continuedev/src/continuedev/server | |
| parent | b3ab5bda368fcae690837f9ce8062dc7f17c6472 (diff) | |
| download | sncontinue-391764f1371dab06af30a29e10a826a516b69bb3.tar.gz sncontinue-391764f1371dab06af30a29e10a826a516b69bb3.tar.bz2 sncontinue-391764f1371dab06af30a29e10a826a516b69bb3.zip | |
persist state and reconnect automatically
Diffstat (limited to 'continuedev/src/continuedev/server')
| -rw-r--r-- | continuedev/src/continuedev/server/gui.py | 10 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/ide.py | 25 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/ide_protocol.py | 10 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/main.py | 16 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/session_manager.py | 41 | 
5 files changed, 78 insertions, 24 deletions
| diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 21089f30..8f6f68f6 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -31,12 +31,12 @@ class AppStatus:  Server.handle_exit = AppStatus.handle_exit -def session(x_continue_session_id: str = Header("anonymous")) -> Session: -    return session_manager.get_session(x_continue_session_id) +async def session(x_continue_session_id: str = Header("anonymous")) -> Session: +    return await session_manager.get_session(x_continue_session_id) -def websocket_session(session_id: str) -> Session: -    return session_manager.get_session(session_id) +async def websocket_session(session_id: str) -> Session: +    return await session_manager.get_session(session_id)  T = TypeVar("T", bound=BaseModel) @@ -199,4 +199,6 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we          print("Closing gui websocket")          if websocket.client_state != WebSocketState.DISCONNECTED:              await websocket.close() + +        session_manager.persist_session(session.session_id)          session_manager.remove_session(session.session_id) diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 4645b49e..12a21f19 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -52,9 +52,11 @@ class FileEditsUpdate(BaseModel):  class OpenFilesResponse(BaseModel):      openFiles: List[str] +  class VisibleFilesResponse(BaseModel):      visibleFiles: List[str] +  class HighlightedCodeResponse(BaseModel):      highlightedCode: List[RangeInFile] @@ -115,6 +117,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):      websocket: WebSocket      session_manager: SessionManager      sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue() +    session_id: Union[str, None] = None      def __init__(self, session_manager: SessionManager, websocket: WebSocket):          self.websocket = websocket @@ -132,8 +135,6 @@ class IdeProtocolServer(AbstractIdeProtocolServer):                  continue              message_type = message["messageType"]              data = message["data"] -            # if message_type == "openGUI": -            #     await self.openGUI()              if message_type == "workspaceDirectory":                  self.workspace_directory = data["workspaceDirectory"]                  break @@ -158,8 +159,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer):          return resp_model.parse_obj(resp)      async def handle_json(self, message_type: str, data: Any): -        if message_type == "openGUI": -            await self.openGUI() +        if message_type == "getSessionId": +            await self.getSessionId()          elif message_type == "setFileOpen":              await self.setFileOpen(data["filepath"], data["open"])          elif message_type == "setSuggestionsLocked": @@ -217,9 +218,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer):              "locked": locked          }) -    async def openGUI(self): -        session_id = self.session_manager.new_session(self) -        await self._send_json("openGUI", { +    async def getSessionId(self): +        session_id = self.session_manager.new_session( +            self, self.session_id).session_id +        await self._send_json("getSessionId", {              "sessionId": session_id          }) @@ -304,7 +306,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):      async def getOpenFiles(self) -> List[str]:          resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles")          return resp.openFiles -     +      async def getVisibleFiles(self) -> List[str]:          resp = await self._send_and_receive_json({}, VisibleFilesResponse, "visibleFiles")          return resp.visibleFiles @@ -416,7 +418,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):  @router.websocket("/ws") -async def websocket_endpoint(websocket: WebSocket): +async def websocket_endpoint(websocket: WebSocket, session_id: str = None):      try:          await websocket.accept()          print("Accepted websocket connection from, ", websocket.client) @@ -434,6 +436,9 @@ async def websocket_endpoint(websocket: WebSocket):                  ideProtocolServer.handle_json(message_type, data))          ideProtocolServer = IdeProtocolServer(session_manager, websocket) +        ideProtocolServer.session_id = session_id +        if session_id is not None: +            session_manager.registered_ides[session_id] = ideProtocolServer          other_msgs = await ideProtocolServer.initialize()          for other_msg in other_msgs: @@ -454,3 +459,5 @@ async def websocket_endpoint(websocket: WebSocket):      finally:          if websocket.client_state != WebSocketState.DISCONNECTED:              await websocket.close() + +        session_manager.registered_ides.pop(ideProtocolServer.session_id) diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index 2783dc61..2f78cf0e 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -1,5 +1,6 @@ -from typing import Any, List +from typing import Any, List, Union  from abc import ABC, abstractmethod, abstractproperty +from fastapi import WebSocket  from ..models.main import Traceback  from ..models.filesystem_edit import FileEdit, FileSystemEdit, EditDiff @@ -7,6 +8,9 @@ from ..models.filesystem import RangeInFile, RangeInFileWithContents  class AbstractIdeProtocolServer(ABC): +    websocket: WebSocket +    session_id: Union[str, None] +      @abstractmethod      async def handle_json(self, data: Any):          """Handle a json message""" @@ -24,8 +28,8 @@ class AbstractIdeProtocolServer(ABC):          """Set whether suggestions are locked"""      @abstractmethod -    async def openGUI(self): -        """Open a GUI""" +    async def getSessionId(self): +        """Get a new session ID"""      @abstractmethod      async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool: diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index f4d82903..aa093853 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -4,7 +4,8 @@ from fastapi import FastAPI  from fastapi.middleware.cors import CORSMiddleware  from .ide import router as ide_router  from .gui import router as gui_router -import logging +from .session_manager import session_manager +import atexit  import uvicorn  import argparse @@ -44,5 +45,16 @@ def run_server():      uvicorn.run(app, host="0.0.0.0", port=args.port) +def cleanup(): +    print("Cleaning up sessions") +    for session_id in session_manager.sessions: +        session_manager.persist_session(session_id) + + +atexit.register(cleanup)  if __name__ == "__main__": -    run_server() +    try: +        run_server() +    except Exception as e: +        cleanup() +        raise e diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 7147dcfa..fb8ac386 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -1,9 +1,12 @@ -from asyncio import BaseEventLoop +import os  from fastapi import WebSocket  from typing import Any, Dict, List, Union  from uuid import uuid4 +import json +from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath  from ..models.filesystem_edit import FileEditWithFullContents +from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER  from ..core.policy import DemoPolicy  from ..core.main import FullState  from ..core.autopilot import Autopilot @@ -39,17 +42,35 @@ class DemoAutopilot(Autopilot):  class SessionManager:      sessions: Dict[str, Session] = {} +    # Mapping of session_id to IDE, where the IDE is still alive +    registered_ides: Dict[str, AbstractIdeProtocolServer] = {} -    def get_session(self, session_id: str) -> Session: +    async def get_session(self, session_id: str) -> Session:          if session_id not in self.sessions: +            # Check then whether it is persisted by listing all files in the sessions folder +            # And only if the IDE is still alive +            sessions_folder = getSessionsFolderPath() +            session_files = os.listdir(sessions_folder) +            if f"{session_id}.json" in session_files and session_id in self.registered_ides: +                if self.registered_ides[session_id].session_id is not None: +                    return self.new_session(self.registered_ides[session_id], session_id=session_id) +              raise KeyError("Session ID not recognized", session_id)          return self.sessions[session_id] -    def new_session(self, ide: AbstractIdeProtocolServer) -> str: -        autopilot = DemoAutopilot(policy=DemoPolicy(), ide=ide) -        session_id = str(uuid4()) +    def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: +        full_state = None +        if session_id is not None and os.path.exists(getSessionFilePath(session_id)): +            with open(getSessionFilePath(session_id), "r") as f: +                full_state = FullState(**json.load(f)) + +        autopilot = DemoAutopilot( +            policy=DemoPolicy(), ide=ide, full_state=full_state) +        session_id = session_id or str(uuid4()) +        ide.session_id = session_id          session = Session(session_id=session_id, autopilot=autopilot)          self.sessions[session_id] = session +        self.registered_ides[session_id] = ide          async def on_update(state: FullState):              await session_manager.send_ws_data(session_id, "state_update", { @@ -58,11 +79,19 @@ class SessionManager:          autopilot.on_update(on_update)          create_async_task(autopilot.run_policy()) -        return session_id +        return session      def remove_session(self, session_id: str):          del self.sessions[session_id] +    def persist_session(self, session_id: str): +        """Save the session's FullState as a json file""" +        full_state = self.sessions[session_id].autopilot.get_full_state() +        if not os.path.exists(getSessionsFolderPath()): +            os.mkdir(getSessionsFolderPath()) +        with open(getSessionFilePath(session_id), "w") as f: +            json.dump(full_state.dict(), f) +      def register_websocket(self, session_id: str, ws: WebSocket):          self.sessions[session_id].ws = ws          print("Registered websocket for session", session_id) | 
