diff options
author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-08-06 22:54:51 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-06 22:54:51 -0700 |
commit | 74005855304412f0401e29c83c166e99a8ab0944 (patch) | |
tree | 4e8ddb61fd52068e839ca4ccab268e013405d545 /continuedev/src/continuedev/server | |
parent | a0d3f29ee237484c66b0efe243c79d902f2da993 (diff) | |
parent | 8ada89b0f66f9e746394ee64591359537fe0c7f0 (diff) | |
download | sncontinue-74005855304412f0401e29c83c166e99a8ab0944.tar.gz sncontinue-74005855304412f0401e29c83c166e99a8ab0944.tar.bz2 sncontinue-74005855304412f0401e29c83c166e99a8ab0944.zip |
Merge pull request #351 from continuedev/history
Session History
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 12 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/main.py | 3 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/session_manager.py | 69 |
3 files changed, 63 insertions, 21 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 7c89c5c2..4470999a 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -2,7 +2,7 @@ import asyncio import json from fastapi import Depends, Header, WebSocket, APIRouter from starlette.websockets import WebSocketState, WebSocketDisconnect -from typing import Any, List, Type, TypeVar +from typing import Any, List, Optional, Type, TypeVar from pydantic import BaseModel import traceback from uvicorn.main import Server @@ -99,6 +99,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.on_show_logs_at_index(data["index"]) elif message_type == "select_context_item": self.select_context_item(data["id"], data["query"]) + elif message_type == "load_session": + self.load_session(data.get("session_id", None)) def on_main_input(self, input: str): # Do something with user input @@ -154,6 +156,14 @@ class GUIProtocolServer(AbstractGUIProtocolServer): create_async_task( self.session.autopilot.select_context_item(id, query), self.on_error) + def load_session(self, session_id: Optional[str] = None): + async def load_and_tell_to_reconnect(): + new_session_id = await session_manager.load_session(self.session.session_id, session_id) + await self._send_json("reconnect_at_session", {"session_id": new_session_id}) + + create_async_task( + load_and_tell_to_reconnect(), self.on_error) + @router.websocket("/ws") async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index f8dfb009..f0a3f094 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -11,13 +11,14 @@ import argparse from .ide import router as ide_router from .gui import router as gui_router -from .session_manager import session_manager +from .session_manager import session_manager, router as sessions_router from ..libs.util.logging import logger app = FastAPI() app.include_router(ide_router) app.include_router(gui_router) +app.include_router(sessions_router) # Add CORS support app.add_middleware( diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 56c92307..cde0344e 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -1,21 +1,23 @@ import os import traceback -from fastapi import WebSocket -from typing import Any, Coroutine, Dict, Union +from fastapi import WebSocket, APIRouter +from typing import Any, Coroutine, Dict, Optional, Union from uuid import uuid4 import json from fastapi.websockets import WebSocketState from ..plugins.steps.core.core import MessageStep -from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath -from ..core.main import FullState, HistoryNode +from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath, getSessionsListFilePath +from ..core.main import FullState, HistoryNode, SessionInfo from ..core.autopilot import Autopilot from .ide_protocol import AbstractIdeProtocolServer from ..libs.util.create_async_task import create_async_task from ..libs.util.errors import SessionNotFound from ..libs.util.logging import logger +router = APIRouter(prefix="/sessions", tags=["sessions"]) + class Session: session_id: str @@ -47,7 +49,7 @@ class SessionManager: raise KeyError("Session ID not recognized", session_id) return self.sessions[session_id] - async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: + async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Optional[str] = None) -> Session: logger.debug(f"New session: {session_id}") # Load the persisted state (not being used right now) @@ -74,20 +76,9 @@ class SessionManager: # Start the autopilot (must be after session is added to sessions) and the policy try: - await autopilot.start() + await autopilot.start(full_state=full_state) except Exception as e: - # Have to manually add to history because autopilot isn't started - formatted_err = '\n'.join(traceback.format_exception(e)) - msg_step = MessageStep( - name="Error loading context manager", message=formatted_err) - msg_step.description = f"```\n{formatted_err}\n```" - autopilot.history.add_node(HistoryNode( - step=msg_step, - observation=None, - depth=0, - active=False - )) - logger.warning(f"Error loading context manager: {e}") + await self.on_error(e) def on_error(e: Exception) -> Coroutine: err_msg = '\n'.join(traceback.format_exception(e)) @@ -99,7 +90,7 @@ class SessionManager: async def remove_session(self, session_id: str): logger.debug(f"Removing session: {session_id}") if session_id in self.sessions: - if session_id in self.registered_ides: + if session_id in self.registered_ides and self.registered_ides[session_id] is not None: ws_to_close = self.registered_ides[session_id].websocket if ws_to_close is not None and ws_to_close.client_state != WebSocketState.DISCONNECTED: await self.sessions[session_id].autopilot.ide.websocket.close() @@ -109,9 +100,37 @@ class SessionManager: async def persist_session(self, session_id: str): """Save the session's FullState as a json file""" full_state = await self.sessions[session_id].autopilot.get_full_state() + if full_state.session_info is None: + return + with open(getSessionFilePath(session_id), "w") as f: json.dump(full_state.dict(), f) + # Read and update the sessions list + with open(getSessionsListFilePath(), "r") as f: + sessions_list = json.load(f) + + session_ids = [s["session_id"] for s in sessions_list] + if session_id not in session_ids: + sessions_list.append(full_state.session_info.dict()) + + with open(getSessionsListFilePath(), "w") as f: + json.dump(sessions_list, f) + + async def load_session(self, old_session_id: str, new_session_id: Optional[str] = None) -> str: + """Load the session's FullState from a json file""" + + # First persist the current state + await self.persist_session(old_session_id) + + # Delete the old session, but keep the IDE + ide = self.registered_ides[old_session_id] + del self.registered_ides[old_session_id] + + # Start the new session + new_session = await self.new_session(ide, session_id=new_session_id) + return new_session.session_id + def register_websocket(self, session_id: str, ws: WebSocket): self.sessions[session_id].ws = ws logger.debug(f"Registered websocket for session {session_id}") @@ -130,3 +149,15 @@ class SessionManager: session_manager = SessionManager() + + +@router.get("/list") +async def list_sessions(): + """List all sessions""" + sessions_list_file = getSessionsListFilePath() + if not os.path.exists(sessions_list_file): + print("Returning empty sessions list") + return [] + sessions = json.load(open(sessions_list_file, "r")) + print("Returning sessions list: ", sessions) + return sessions |