diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-08-06 15:39:16 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-08-06 15:39:16 -0700 |
commit | 19060a30faf94454f4d69d01828a33985d07f109 (patch) | |
tree | 10e983b351b39e51cc054e280074c65b54ac2c62 /continuedev/src | |
parent | c25527926ad1d1f861dbed01df577e962e08d746 (diff) | |
download | sncontinue-19060a30faf94454f4d69d01828a33985d07f109.tar.gz sncontinue-19060a30faf94454f4d69d01828a33985d07f109.tar.bz2 sncontinue-19060a30faf94454f4d69d01828a33985d07f109.zip |
feat: :construction: create new sessions
Diffstat (limited to 'continuedev/src')
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 1 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 10 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/session_manager.py | 13 |
3 files changed, 14 insertions, 10 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 6dd30db1..ee29dc88 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -89,6 +89,7 @@ class Autopilot(ContinueBaseModel): if full_state is not None: self.history = full_state.history self.context_manager.context_providers["code"].adding_highlighted_code = full_state.adding_highlighted_code + self.session_info = full_state.session_info self.started = True diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 661e1787..4470999a 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -2,7 +2,7 @@ import asyncio import json from fastapi import Depends, Header, WebSocket, APIRouter from starlette.websockets import WebSocketState, WebSocketDisconnect -from typing import Any, List, Type, TypeVar +from typing import Any, List, Optional, Type, TypeVar from pydantic import BaseModel import traceback from uvicorn.main import Server @@ -100,7 +100,7 @@ class GUIProtocolServer(AbstractGUIProtocolServer): elif message_type == "select_context_item": self.select_context_item(data["id"], data["query"]) elif message_type == "load_session": - self.load_session(data["session_id"]) + self.load_session(data.get("session_id", None)) def on_main_input(self, input: str): # Do something with user input @@ -156,10 +156,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer): create_async_task( self.session.autopilot.select_context_item(id, query), self.on_error) - def load_session(self, session_id: str): + def load_session(self, session_id: Optional[str] = None): async def load_and_tell_to_reconnect(): - await session_manager.load_session(self.session.session_id, session_id) - await self._send_json("reconnect_at_session", {"session_id": session_id}) + new_session_id = await session_manager.load_session(self.session.session_id, session_id) + await self._send_json("reconnect_at_session", {"session_id": new_session_id}) create_async_task( load_and_tell_to_reconnect(), self.on_error) diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 062f9527..cde0344e 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -1,7 +1,7 @@ import os import traceback from fastapi import WebSocket, APIRouter -from typing import Any, Coroutine, Dict, Union +from typing import Any, Coroutine, Dict, Optional, Union from uuid import uuid4 import json @@ -49,7 +49,7 @@ class SessionManager: raise KeyError("Session ID not recognized", session_id) return self.sessions[session_id] - async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: + async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Optional[str] = None) -> Session: logger.debug(f"New session: {session_id}") # Load the persisted state (not being used right now) @@ -110,12 +110,14 @@ class SessionManager: with open(getSessionsListFilePath(), "r") as f: sessions_list = json.load(f) - sessions_list.append(full_state.session_info.dict()) + session_ids = [s["session_id"] for s in sessions_list] + if session_id not in session_ids: + sessions_list.append(full_state.session_info.dict()) with open(getSessionsListFilePath(), "w") as f: json.dump(sessions_list, f) - async def load_session(self, old_session_id: str, new_session_id: str): + async def load_session(self, old_session_id: str, new_session_id: Optional[str] = None) -> str: """Load the session's FullState from a json file""" # First persist the current state @@ -126,7 +128,8 @@ class SessionManager: del self.registered_ides[old_session_id] # Start the new session - await self.new_session(ide, session_id=new_session_id) + new_session = await self.new_session(ide, session_id=new_session_id) + return new_session.session_id def register_websocket(self, session_id: str, ws: WebSocket): self.sessions[session_id].ws = ws |