diff options
6 files changed, 20 insertions, 16 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 6dd30db1..ee29dc88 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -89,6 +89,7 @@ class Autopilot(ContinueBaseModel): if full_state is not None: self.history = full_state.history self.context_manager.context_providers["code"].adding_highlighted_code = full_state.adding_highlighted_code + self.session_info = full_state.session_info self.started = True diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 661e1787..4470999a 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -2,7 +2,7 @@ import asyncio import json from fastapi import Depends, Header, WebSocket, APIRouter from starlette.websockets import WebSocketState, WebSocketDisconnect -from typing import Any, List, Type, TypeVar +from typing import Any, List, Optional, Type, TypeVar from pydantic import BaseModel import traceback from uvicorn.main import Server @@ -100,7 +100,7 @@ class GUIProtocolServer(AbstractGUIProtocolServer): elif message_type == "select_context_item": self.select_context_item(data["id"], data["query"]) elif message_type == "load_session": - self.load_session(data["session_id"]) + self.load_session(data.get("session_id", None)) def on_main_input(self, input: str): # Do something with user input @@ -156,10 +156,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer): create_async_task( self.session.autopilot.select_context_item(id, query), self.on_error) - def load_session(self, session_id: str): + def load_session(self, session_id: Optional[str] = None): async def load_and_tell_to_reconnect(): - await session_manager.load_session(self.session.session_id, session_id) - await self._send_json("reconnect_at_session", {"session_id": session_id}) + new_session_id = await session_manager.load_session(self.session.session_id, session_id) + await self._send_json("reconnect_at_session", {"session_id": new_session_id}) create_async_task( load_and_tell_to_reconnect(), self.on_error) diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 062f9527..cde0344e 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -1,7 +1,7 @@ import os import traceback from fastapi import WebSocket, APIRouter -from typing import Any, Coroutine, Dict, Union +from typing import Any, Coroutine, Dict, Optional, Union from uuid import uuid4 import json @@ -49,7 +49,7 @@ class SessionManager: raise KeyError("Session ID not recognized", session_id) return self.sessions[session_id] - async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: + async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Optional[str] = None) -> Session: logger.debug(f"New session: {session_id}") # Load the persisted state (not being used right now) @@ -110,12 +110,14 @@ class SessionManager: with open(getSessionsListFilePath(), "r") as f: sessions_list = json.load(f) - sessions_list.append(full_state.session_info.dict()) + session_ids = [s["session_id"] for s in sessions_list] + if session_id not in session_ids: + sessions_list.append(full_state.session_info.dict()) with open(getSessionsListFilePath(), "w") as f: json.dump(sessions_list, f) - async def load_session(self, old_session_id: str, new_session_id: str): + async def load_session(self, old_session_id: str, new_session_id: Optional[str] = None) -> str: """Load the session's FullState from a json file""" # First persist the current state @@ -126,7 +128,8 @@ class SessionManager: del self.registered_ides[old_session_id] # Start the new session - await self.new_session(ide, session_id=new_session_id) + new_session = await self.new_session(ide, session_id=new_session_id) + return new_session.session_id def register_websocket(self, session_id: str, ws: WebSocket): self.sessions[session_id].ws = ws diff --git a/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts index 139c9d05..e018c03c 100644 --- a/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts +++ b/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts @@ -31,7 +31,7 @@ abstract class AbstractContinueGUIClientProtocol { abstract selectContextItem(id: string, query: string): void; - abstract loadSession(session_id: string): void; + abstract loadSession(session_id?: string): void; abstract onReconnectAtSession(session_id: string): void; } diff --git a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts index 6cfbf66a..c2285f6d 100644 --- a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts +++ b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts @@ -47,7 +47,7 @@ class ContinueGUIClientProtocol extends AbstractContinueGUIClientProtocol { this.connectMessenger(serverUrlWithSessionId, useVscodeMessagePassing); } - loadSession(session_id: string): void { + loadSession(session_id?: string): void { this.messenger?.send("load_session", { session_id }); } diff --git a/extension/react-app/src/pages/gui.tsx b/extension/react-app/src/pages/gui.tsx index d565e64f..dab429b5 100644 --- a/extension/react-app/src/pages/gui.tsx +++ b/extension/react-app/src/pages/gui.tsx @@ -16,7 +16,7 @@ import { BookOpenIcon, ChatBubbleOvalLeftEllipsisIcon, TrashIcon, - PlusCircleIcon, + PlusIcon, FolderIcon, } from "@heroicons/react/24/outline"; import ComboBox from "../components/ComboBox"; @@ -589,11 +589,11 @@ If you already have an LLM deployed on your own infrastructure, or would like to </HeaderButtonWithText> <HeaderButtonWithText onClick={() => { - client?.sendClear(); + client?.loadSession(undefined); }} - text="Clear" + text="New Session" > - <PlusCircleIcon width="1.4em" height="1.4em" /> + <PlusIcon width="1.4em" height="1.4em" /> </HeaderButtonWithText> <HeaderButtonWithText onClick={() => { |