summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-12 21:53:06 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-12 21:53:06 -0700
commitc6eff59c445017066cae8d2706521a694ef16a23 (patch)
tree3ecfb49dd0bc3ba20a78b685e503776f89dbf810 /continuedev
parentd78cb7b1e09bb9ff22fc9e3323ec3b18e03dbcbf (diff)
downloadsncontinue-c6eff59c445017066cae8d2706521a694ef16a23.tar.gz
sncontinue-c6eff59c445017066cae8d2706521a694ef16a23.tar.bz2
sncontinue-c6eff59c445017066cae8d2706521a694ef16a23.zip
persist state and reconnect automatically
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/core/autopilot.py17
-rw-r--r--continuedev/src/continuedev/libs/constants/main.py6
-rw-r--r--continuedev/src/continuedev/libs/util/paths.py17
-rw-r--r--continuedev/src/continuedev/server/gui.py10
-rw-r--r--continuedev/src/continuedev/server/ide.py25
-rw-r--r--continuedev/src/continuedev/server/ide_protocol.py10
-rw-r--r--continuedev/src/continuedev/server/main.py16
-rw-r--r--continuedev/src/continuedev/server/session_manager.py41
8 files changed, 114 insertions, 28 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 1b074435..e1c8a076 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -1,13 +1,13 @@
from functools import cached_property
import traceback
import time
-from typing import Any, Callable, Coroutine, Dict, List
+from typing import Any, Callable, Coroutine, Dict, List, Union
import os
from aiohttp import ClientPayloadError
+from pydantic import root_validator
from ..models.filesystem import RangeInFileWithContents
from ..models.filesystem_edit import FileEditWithFullContents
-from ..libs.llm import LLM
from .observation import Observation, InternalErrorObservation
from ..server.ide_protocol import AbstractIdeProtocolServer
from ..libs.util.queue import AsyncSubscriptionQueue
@@ -16,7 +16,6 @@ from .main import Context, ContinueCustomException, HighlightedRangeContext, Pol
from ..steps.core.core import ReversibleStep, ManualEditStep, UserInputStep
from ..libs.util.telemetry import capture_event
from .sdk import ContinueSDK
-import asyncio
from ..libs.util.step_name_to_steps import get_step_from_name
from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback
from openai import error as openai_errors
@@ -46,6 +45,7 @@ class Autopilot(ContinueBaseModel):
ide: AbstractIdeProtocolServer
history: History = History.from_empty()
context: Context = Context()
+ full_state: Union[FullState, None] = None
_on_update_callbacks: List[Callable[[FullState], None]] = []
_active: bool = False
@@ -63,8 +63,15 @@ class Autopilot(ContinueBaseModel):
arbitrary_types_allowed = True
keep_untouched = (cached_property,)
+ @root_validator(pre=True)
+ def fill_in_values(cls, values):
+ full_state: FullState = values.get('full_state')
+ if full_state is not None:
+ values['history'] = full_state.history
+ return values
+
def get_full_state(self) -> FullState:
- return FullState(
+ full_state = FullState(
history=self.history,
active=self._active,
user_input_queue=self._main_user_input_queue,
@@ -73,6 +80,8 @@ class Autopilot(ContinueBaseModel):
slash_commands=self.get_available_slash_commands(),
adding_highlighted_code=self._adding_highlighted_code,
)
+ self.full_state = full_state
+ return full_state
def get_available_slash_commands(self) -> List[Dict]:
custom_commands = list(map(lambda x: {
diff --git a/continuedev/src/continuedev/libs/constants/main.py b/continuedev/src/continuedev/libs/constants/main.py
new file mode 100644
index 00000000..96eb6e69
--- /dev/null
+++ b/continuedev/src/continuedev/libs/constants/main.py
@@ -0,0 +1,6 @@
+## PATHS ##
+
+CONTINUE_GLOBAL_FOLDER = ".continue"
+CONTINUE_SESSIONS_FOLDER = "sessions"
+CONTINUE_SERVER_FOLDER = "server"
+
diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py
new file mode 100644
index 00000000..fddef887
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/paths.py
@@ -0,0 +1,17 @@
+import os
+
+from ..constants.main import CONTINUE_SESSIONS_FOLDER, CONTINUE_GLOBAL_FOLDER, CONTINUE_SERVER_FOLDER
+
+def getGlobalFolderPath():
+ return os.path.join(os.path.expanduser("~"), CONTINUE_GLOBAL_FOLDER)
+
+
+
+def getSessionsFolderPath():
+ return os.path.join(getGlobalFolderPath(), CONTINUE_SESSIONS_FOLDER)
+
+def getServerFolderPath():
+ return os.path.join(getGlobalFolderPath(), CONTINUE_SERVER_FOLDER)
+
+def getSessionFilePath(session_id: str):
+ return os.path.join(getSessionsFolderPath(), f"{session_id}.json") \ No newline at end of file
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 21089f30..8f6f68f6 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -31,12 +31,12 @@ class AppStatus:
Server.handle_exit = AppStatus.handle_exit
-def session(x_continue_session_id: str = Header("anonymous")) -> Session:
- return session_manager.get_session(x_continue_session_id)
+async def session(x_continue_session_id: str = Header("anonymous")) -> Session:
+ return await session_manager.get_session(x_continue_session_id)
-def websocket_session(session_id: str) -> Session:
- return session_manager.get_session(session_id)
+async def websocket_session(session_id: str) -> Session:
+ return await session_manager.get_session(session_id)
T = TypeVar("T", bound=BaseModel)
@@ -199,4 +199,6 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
print("Closing gui websocket")
if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close()
+
+ session_manager.persist_session(session.session_id)
session_manager.remove_session(session.session_id)
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index 4645b49e..12a21f19 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -52,9 +52,11 @@ class FileEditsUpdate(BaseModel):
class OpenFilesResponse(BaseModel):
openFiles: List[str]
+
class VisibleFilesResponse(BaseModel):
visibleFiles: List[str]
+
class HighlightedCodeResponse(BaseModel):
highlightedCode: List[RangeInFile]
@@ -115,6 +117,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
websocket: WebSocket
session_manager: SessionManager
sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue()
+ session_id: Union[str, None] = None
def __init__(self, session_manager: SessionManager, websocket: WebSocket):
self.websocket = websocket
@@ -132,8 +135,6 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
continue
message_type = message["messageType"]
data = message["data"]
- # if message_type == "openGUI":
- # await self.openGUI()
if message_type == "workspaceDirectory":
self.workspace_directory = data["workspaceDirectory"]
break
@@ -158,8 +159,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
return resp_model.parse_obj(resp)
async def handle_json(self, message_type: str, data: Any):
- if message_type == "openGUI":
- await self.openGUI()
+ if message_type == "getSessionId":
+ await self.getSessionId()
elif message_type == "setFileOpen":
await self.setFileOpen(data["filepath"], data["open"])
elif message_type == "setSuggestionsLocked":
@@ -217,9 +218,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
"locked": locked
})
- async def openGUI(self):
- session_id = self.session_manager.new_session(self)
- await self._send_json("openGUI", {
+ async def getSessionId(self):
+ session_id = self.session_manager.new_session(
+ self, self.session_id).session_id
+ await self._send_json("getSessionId", {
"sessionId": session_id
})
@@ -304,7 +306,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
async def getOpenFiles(self) -> List[str]:
resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles")
return resp.openFiles
-
+
async def getVisibleFiles(self) -> List[str]:
resp = await self._send_and_receive_json({}, VisibleFilesResponse, "visibleFiles")
return resp.visibleFiles
@@ -416,7 +418,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
@router.websocket("/ws")
-async def websocket_endpoint(websocket: WebSocket):
+async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
try:
await websocket.accept()
print("Accepted websocket connection from, ", websocket.client)
@@ -434,6 +436,9 @@ async def websocket_endpoint(websocket: WebSocket):
ideProtocolServer.handle_json(message_type, data))
ideProtocolServer = IdeProtocolServer(session_manager, websocket)
+ ideProtocolServer.session_id = session_id
+ if session_id is not None:
+ session_manager.registered_ides[session_id] = ideProtocolServer
other_msgs = await ideProtocolServer.initialize()
for other_msg in other_msgs:
@@ -454,3 +459,5 @@ async def websocket_endpoint(websocket: WebSocket):
finally:
if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close()
+
+ session_manager.registered_ides.pop(ideProtocolServer.session_id)
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
index 2783dc61..2f78cf0e 100644
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ b/continuedev/src/continuedev/server/ide_protocol.py
@@ -1,5 +1,6 @@
-from typing import Any, List
+from typing import Any, List, Union
from abc import ABC, abstractmethod, abstractproperty
+from fastapi import WebSocket
from ..models.main import Traceback
from ..models.filesystem_edit import FileEdit, FileSystemEdit, EditDiff
@@ -7,6 +8,9 @@ from ..models.filesystem import RangeInFile, RangeInFileWithContents
class AbstractIdeProtocolServer(ABC):
+ websocket: WebSocket
+ session_id: Union[str, None]
+
@abstractmethod
async def handle_json(self, data: Any):
"""Handle a json message"""
@@ -24,8 +28,8 @@ class AbstractIdeProtocolServer(ABC):
"""Set whether suggestions are locked"""
@abstractmethod
- async def openGUI(self):
- """Open a GUI"""
+ async def getSessionId(self):
+ """Get a new session ID"""
@abstractmethod
async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool:
diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py
index f4d82903..aa093853 100644
--- a/continuedev/src/continuedev/server/main.py
+++ b/continuedev/src/continuedev/server/main.py
@@ -4,7 +4,8 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .ide import router as ide_router
from .gui import router as gui_router
-import logging
+from .session_manager import session_manager
+import atexit
import uvicorn
import argparse
@@ -44,5 +45,16 @@ def run_server():
uvicorn.run(app, host="0.0.0.0", port=args.port)
+def cleanup():
+ print("Cleaning up sessions")
+ for session_id in session_manager.sessions:
+ session_manager.persist_session(session_id)
+
+
+atexit.register(cleanup)
if __name__ == "__main__":
- run_server()
+ try:
+ run_server()
+ except Exception as e:
+ cleanup()
+ raise e
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index 7147dcfa..fb8ac386 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -1,9 +1,12 @@
-from asyncio import BaseEventLoop
+import os
from fastapi import WebSocket
from typing import Any, Dict, List, Union
from uuid import uuid4
+import json
+from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath
from ..models.filesystem_edit import FileEditWithFullContents
+from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER
from ..core.policy import DemoPolicy
from ..core.main import FullState
from ..core.autopilot import Autopilot
@@ -39,17 +42,35 @@ class DemoAutopilot(Autopilot):
class SessionManager:
sessions: Dict[str, Session] = {}
+ # Mapping of session_id to IDE, where the IDE is still alive
+ registered_ides: Dict[str, AbstractIdeProtocolServer] = {}
- def get_session(self, session_id: str) -> Session:
+ async def get_session(self, session_id: str) -> Session:
if session_id not in self.sessions:
+ # Check then whether it is persisted by listing all files in the sessions folder
+ # And only if the IDE is still alive
+ sessions_folder = getSessionsFolderPath()
+ session_files = os.listdir(sessions_folder)
+ if f"{session_id}.json" in session_files and session_id in self.registered_ides:
+ if self.registered_ides[session_id].session_id is not None:
+ return self.new_session(self.registered_ides[session_id], session_id=session_id)
+
raise KeyError("Session ID not recognized", session_id)
return self.sessions[session_id]
- def new_session(self, ide: AbstractIdeProtocolServer) -> str:
- autopilot = DemoAutopilot(policy=DemoPolicy(), ide=ide)
- session_id = str(uuid4())
+ def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
+ full_state = None
+ if session_id is not None and os.path.exists(getSessionFilePath(session_id)):
+ with open(getSessionFilePath(session_id), "r") as f:
+ full_state = FullState(**json.load(f))
+
+ autopilot = DemoAutopilot(
+ policy=DemoPolicy(), ide=ide, full_state=full_state)
+ session_id = session_id or str(uuid4())
+ ide.session_id = session_id
session = Session(session_id=session_id, autopilot=autopilot)
self.sessions[session_id] = session
+ self.registered_ides[session_id] = ide
async def on_update(state: FullState):
await session_manager.send_ws_data(session_id, "state_update", {
@@ -58,11 +79,19 @@ class SessionManager:
autopilot.on_update(on_update)
create_async_task(autopilot.run_policy())
- return session_id
+ return session
def remove_session(self, session_id: str):
del self.sessions[session_id]
+ def persist_session(self, session_id: str):
+ """Save the session's FullState as a json file"""
+ full_state = self.sessions[session_id].autopilot.get_full_state()
+ if not os.path.exists(getSessionsFolderPath()):
+ os.mkdir(getSessionsFolderPath())
+ with open(getSessionFilePath(session_id), "w") as f:
+ json.dump(full_state.dict(), f)
+
def register_websocket(self, session_id: str, ws: WebSocket):
self.sessions[session_id].ws = ws
print("Registered websocket for session", session_id)