summaryrefslogtreecommitdiff
path: root/continuedev/src
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-08-06 15:39:16 -0700
committerNate Sesti <sestinj@gmail.com>2023-08-06 15:39:16 -0700
commit19060a30faf94454f4d69d01828a33985d07f109 (patch)
tree10e983b351b39e51cc054e280074c65b54ac2c62 /continuedev/src
parentc25527926ad1d1f861dbed01df577e962e08d746 (diff)
downloadsncontinue-19060a30faf94454f4d69d01828a33985d07f109.tar.gz
sncontinue-19060a30faf94454f4d69d01828a33985d07f109.tar.bz2
sncontinue-19060a30faf94454f4d69d01828a33985d07f109.zip
feat: :construction: create new sessions
Diffstat (limited to 'continuedev/src')
-rw-r--r--continuedev/src/continuedev/core/autopilot.py1
-rw-r--r--continuedev/src/continuedev/server/gui.py10
-rw-r--r--continuedev/src/continuedev/server/session_manager.py13
3 files changed, 14 insertions, 10 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 6dd30db1..ee29dc88 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -89,6 +89,7 @@ class Autopilot(ContinueBaseModel):
if full_state is not None:
self.history = full_state.history
self.context_manager.context_providers["code"].adding_highlighted_code = full_state.adding_highlighted_code
+ self.session_info = full_state.session_info
self.started = True
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 661e1787..4470999a 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -2,7 +2,7 @@ import asyncio
import json
from fastapi import Depends, Header, WebSocket, APIRouter
from starlette.websockets import WebSocketState, WebSocketDisconnect
-from typing import Any, List, Type, TypeVar
+from typing import Any, List, Optional, Type, TypeVar
from pydantic import BaseModel
import traceback
from uvicorn.main import Server
@@ -100,7 +100,7 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
elif message_type == "select_context_item":
self.select_context_item(data["id"], data["query"])
elif message_type == "load_session":
- self.load_session(data["session_id"])
+ self.load_session(data.get("session_id", None))
def on_main_input(self, input: str):
# Do something with user input
@@ -156,10 +156,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
create_async_task(
self.session.autopilot.select_context_item(id, query), self.on_error)
- def load_session(self, session_id: str):
+ def load_session(self, session_id: Optional[str] = None):
async def load_and_tell_to_reconnect():
- await session_manager.load_session(self.session.session_id, session_id)
- await self._send_json("reconnect_at_session", {"session_id": session_id})
+ new_session_id = await session_manager.load_session(self.session.session_id, session_id)
+ await self._send_json("reconnect_at_session", {"session_id": new_session_id})
create_async_task(
load_and_tell_to_reconnect(), self.on_error)
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index 062f9527..cde0344e 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -1,7 +1,7 @@
import os
import traceback
from fastapi import WebSocket, APIRouter
-from typing import Any, Coroutine, Dict, Union
+from typing import Any, Coroutine, Dict, Optional, Union
from uuid import uuid4
import json
@@ -49,7 +49,7 @@ class SessionManager:
raise KeyError("Session ID not recognized", session_id)
return self.sessions[session_id]
- async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
+ async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Optional[str] = None) -> Session:
logger.debug(f"New session: {session_id}")
# Load the persisted state (not being used right now)
@@ -110,12 +110,14 @@ class SessionManager:
with open(getSessionsListFilePath(), "r") as f:
sessions_list = json.load(f)
- sessions_list.append(full_state.session_info.dict())
+ session_ids = [s["session_id"] for s in sessions_list]
+ if session_id not in session_ids:
+ sessions_list.append(full_state.session_info.dict())
with open(getSessionsListFilePath(), "w") as f:
json.dump(sessions_list, f)
- async def load_session(self, old_session_id: str, new_session_id: str):
+ async def load_session(self, old_session_id: str, new_session_id: Optional[str] = None) -> str:
"""Load the session's FullState from a json file"""
# First persist the current state
@@ -126,7 +128,8 @@ class SessionManager:
del self.registered_ides[old_session_id]
# Start the new session
- await self.new_session(ide, session_id=new_session_id)
+ new_session = await self.new_session(ide, session_id=new_session_id)
+ return new_session.session_id
def register_websocket(self, session_id: str, ws: WebSocket):
self.sessions[session_id].ws = ws