From f19345c652cfcf1bdf13d0a44a2f302e0cd1aa4c Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sun, 6 Aug 2023 09:28:22 -0700 Subject: feat: :construction: Router and new history page --- .../src/hooks/AbstractContinueGUIClientProtocol.ts | 2 + .../src/hooks/ContinueGUIClientProtocol.ts | 63 ++++++++++++++++------ 2 files changed, 48 insertions(+), 17 deletions(-) (limited to 'extension/react-app/src/hooks') diff --git a/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts index 8d8b7b7e..168fb156 100644 --- a/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts +++ b/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts @@ -30,6 +30,8 @@ abstract class AbstractContinueGUIClientProtocol { abstract showLogsAtIndex(index: number): void; abstract selectContextItem(id: string, query: string): void; + + abstract onReconnectAtSession(session_id: string): void; } export default AbstractContinueGUIClientProtocol; diff --git a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts index b6dd43d9..830954c5 100644 --- a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts +++ b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts @@ -4,15 +4,18 @@ import { Messenger, WebsocketMessenger } from "./messenger"; import { VscodeMessenger } from "./vscodeMessenger"; class ContinueGUIClientProtocol extends AbstractContinueGUIClientProtocol { - messenger: Messenger; + messenger?: Messenger; // Server URL must contain the session ID param serverUrlWithSessionId: string; + useVscodeMessagePassing: boolean; - constructor( + private connectMessenger( serverUrlWithSessionId: string, useVscodeMessagePassing: boolean ) { - super(); + if (this.messenger) { + // this.messenger.close(); TODO + } this.serverUrlWithSessionId = serverUrlWithSessionId; this.messenger = useVscodeMessagePassing ? new VscodeMessenger(serverUrlWithSessionId) @@ -24,26 +27,52 @@ class ContinueGUIClientProtocol extends AbstractContinueGUIClientProtocol { this.messenger.onError((error) => { console.log("GUI -> IDE websocket error", error); }); + + this.messenger.onMessageType("reconnect_at_session", (data: any) => { + if (data.session_id) { + this.onReconnectAtSession(data.session_id); + } + }); + } + + constructor( + serverUrlWithSessionId: string, + useVscodeMessagePassing: boolean + ) { + super(); + this.serverUrlWithSessionId = serverUrlWithSessionId; + this.useVscodeMessagePassing = useVscodeMessagePassing; + this.connectMessenger(serverUrlWithSessionId, useVscodeMessagePassing); + } + + onReconnectAtSession(session_id: string): void { + this.connectMessenger( + this.serverUrlWithSessionId.replace( + /\/session\/[a-zA-Z0-9-]+/, + `/session/${session_id}` + ), + this.useVscodeMessagePassing + ); } sendMainInput(input: string) { - this.messenger.send("main_input", { input }); + this.messenger?.send("main_input", { input }); } reverseToIndex(index: number) { - this.messenger.send("reverse_to_index", { index }); + this.messenger?.send("reverse_to_index", { index }); } sendRefinementInput(input: string, index: number) { - this.messenger.send("refinement_input", { input, index }); + this.messenger?.send("refinement_input", { input, index }); } sendStepUserInput(input: string, index: number) { - this.messenger.send("step_user_input", { input, index }); + this.messenger?.send("step_user_input", { input, index }); } onStateUpdate(callback: (state: any) => void) { - this.messenger.onMessageType("state_update", (data: any) => { + this.messenger?.onMessageType("state_update", (data: any) => { if (data.state) { callback(data.state); } @@ -53,7 +82,7 @@ class ContinueGUIClientProtocol extends AbstractContinueGUIClientProtocol { onAvailableSlashCommands( callback: (commands: { name: string; description: string }[]) => void ) { - this.messenger.onMessageType("available_slash_commands", (data: any) => { + this.messenger?.onMessageType("available_slash_commands", (data: any) => { if (data.commands) { callback(data.commands); } @@ -61,37 +90,37 @@ class ContinueGUIClientProtocol extends AbstractContinueGUIClientProtocol { } sendClear() { - this.messenger.send("clear_history", {}); + this.messenger?.send("clear_history", {}); } retryAtIndex(index: number) { - this.messenger.send("retry_at_index", { index }); + this.messenger?.send("retry_at_index", { index }); } deleteAtIndex(index: number) { - this.messenger.send("delete_at_index", { index }); + this.messenger?.send("delete_at_index", { index }); } deleteContextWithIds(ids: ContextItemId[]) { - this.messenger.send("delete_context_with_ids", { + this.messenger?.send("delete_context_with_ids", { ids: ids.map((id) => `${id.provider_title}-${id.item_id}`), }); } setEditingAtIds(ids: string[]) { - this.messenger.send("set_editing_at_ids", { ids }); + this.messenger?.send("set_editing_at_ids", { ids }); } toggleAddingHighlightedCode(): void { - this.messenger.send("toggle_adding_highlighted_code", {}); + this.messenger?.send("toggle_adding_highlighted_code", {}); } showLogsAtIndex(index: number): void { - this.messenger.send("show_logs_at_index", { index }); + this.messenger?.send("show_logs_at_index", { index }); } selectContextItem(id: string, query: string): void { - this.messenger.send("select_context_item", { id, query }); + this.messenger?.send("select_context_item", { id, query }); } } -- cgit v1.2.3-70-g09d2 From c25527926ad1d1f861dbed01df577e962e08d746 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sun, 6 Aug 2023 15:24:13 -0700 Subject: feat: :construction: successfully loading past sessions --- continuedev/src/continuedev/core/autopilot.py | 21 ++++++++++++--- continuedev/src/continuedev/core/main.py | 15 ++++++----- continuedev/src/continuedev/libs/util/paths.py | 3 +++ continuedev/src/continuedev/libs/util/telemetry.py | 1 - continuedev/src/continuedev/server/gui.py | 11 ++++++-- .../src/continuedev/server/session_manager.py | 27 ++++++++----------- .../src/hooks/AbstractContinueGUIClientProtocol.ts | 2 ++ .../src/hooks/ContinueGUIClientProtocol.ts | 13 ++++++---- extension/react-app/src/hooks/messenger.ts | 6 +++++ extension/react-app/src/hooks/vscodeMessenger.ts | 4 +++ extension/react-app/src/pages/history.tsx | 30 +++++++++++++++++----- .../react-app/src/redux/slices/configSlice.ts | 4 +-- extension/src/debugPanel.ts | 9 +++++++ extension/src/util/messenger.ts | 6 +++++ 14 files changed, 108 insertions(+), 44 deletions(-) (limited to 'extension/react-app/src/hooks') diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 7e7ce5d8..6dd30db1 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -15,7 +15,7 @@ from ..plugins.context_providers.highlighted_code import HighlightedCodeContextP from ..server.ide_protocol import AbstractIdeProtocolServer from ..libs.util.queue import AsyncSubscriptionQueue from ..models.main import ContinueBaseModel -from .main import Context, ContinueCustomException, Policy, History, FullState, Step, HistoryNode +from .main import Context, ContinueCustomException, Policy, History, FullState, SessionInfo, Step, HistoryNode from ..plugins.steps.core.core import DisplayErrorStep, ReversibleStep, ManualEditStep, UserInputStep from .sdk import ContinueSDK from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback @@ -53,7 +53,8 @@ class Autopilot(ContinueBaseModel): policy: Policy = DefaultPolicy() history: History = History.from_empty() context: Context = Context() - full_state: Union[FullState, None] = None + full_state: Optional[FullState] = None + session_info: Optional[SessionInfo] = None context_manager: ContextManager = ContextManager() continue_sdk: ContinueSDK = None @@ -88,7 +89,6 @@ class Autopilot(ContinueBaseModel): if full_state is not None: self.history = full_state.history self.context_manager.context_providers["code"].adding_highlighted_code = full_state.adding_highlighted_code - await self.context_manager.set_selected_items(full_state.selected_context_items) self.started = True @@ -112,6 +112,7 @@ class Autopilot(ContinueBaseModel): adding_highlighted_code=self.context_manager.context_providers[ "code"].adding_highlighted_code if "code" in self.context_manager.context_providers else False, selected_context_items=await self.context_manager.get_selected_items() if self.context_manager is not None else [], + session_info=self.session_info ) self.full_state = full_state return full_state @@ -375,6 +376,20 @@ class Autopilot(ContinueBaseModel): self._main_user_input_queue.append(user_input) await self.update_subscribers() + # Use the first input to create title for session info, and make the session saveable + if self.session_info is None: + async def create_title(): + title = await self.continue_sdk.models.medium.complete(f"Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: \"{user_input}\". The title is: ") + self.session_info = SessionInfo( + title=title, + session_id=self.ide.session_id, + date_created=time.strftime( + "%Y-%m-%d %H:%M:%S", time.gmtime()) + ) + + create_async_task(create_title(), on_error=lambda e: self.continue_sdk.run_step( + DisplayErrorStep(e=e))) + if len(self._main_user_input_queue) > 1: return diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 9a06f2e1..a33d777e 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -1,5 +1,5 @@ import json -from typing import Coroutine, Dict, List, Literal, Union +from typing import Coroutine, Dict, List, Literal, Optional, Union from pydantic.schema import schema @@ -253,6 +253,12 @@ class ContextItem(BaseModel): editable: bool = False +class SessionInfo(ContinueBaseModel): + session_id: str + title: str + date_created: str + + class FullState(ContinueBaseModel): """A full state of the program, including the history""" history: History @@ -261,12 +267,7 @@ class FullState(ContinueBaseModel): slash_commands: List[SlashCommandDescription] adding_highlighted_code: bool selected_context_items: List[ContextItem] - - -class SessionInfo(ContinueBaseModel): - session_id: str - title: str - date_created: str + session_info: Optional[SessionInfo] = None class ContinueSDK: diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py index 66c921f7..01b594cf 100644 --- a/continuedev/src/continuedev/libs/util/paths.py +++ b/continuedev/src/continuedev/libs/util/paths.py @@ -35,6 +35,9 @@ def getSessionFilePath(session_id: str): def getSessionsListFilePath(): path = os.path.join(getSessionsFolderPath(), "sessions.json") os.makedirs(os.path.dirname(path), exist_ok=True) + if not os.path.exists(path): + with open(path, 'w') as f: + f.write("[]") return path diff --git a/continuedev/src/continuedev/libs/util/telemetry.py b/continuedev/src/continuedev/libs/util/telemetry.py index 60c910bb..0f66ad8d 100644 --- a/continuedev/src/continuedev/libs/util/telemetry.py +++ b/continuedev/src/continuedev/libs/util/telemetry.py @@ -23,7 +23,6 @@ class PostHogLogger: self.posthog = Posthog(self.api_key, host='https://app.posthog.com') def setup(self, unique_id: str, allow_anonymous_telemetry: bool): - logger.debug(f"Setting unique_id as {unique_id}") self.unique_id = unique_id or "NO_UNIQUE_ID" self.allow_anonymous_telemetry = allow_anonymous_telemetry or True diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index b6f7b141..661e1787 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -99,6 +99,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.on_show_logs_at_index(data["index"]) elif message_type == "select_context_item": self.select_context_item(data["id"], data["query"]) + elif message_type == "load_session": + self.load_session(data["session_id"]) def on_main_input(self, input: str): # Do something with user input @@ -154,8 +156,13 @@ class GUIProtocolServer(AbstractGUIProtocolServer): create_async_task( self.session.autopilot.select_context_item(id, query), self.on_error) - async def reconnect_at_session(self, session_id: str): - await self._send_json("reconnect_at_session", {"session_id": session_id}) + def load_session(self, session_id: str): + async def load_and_tell_to_reconnect(): + await session_manager.load_session(self.session.session_id, session_id) + await self._send_json("reconnect_at_session", {"session_id": session_id}) + + create_async_task( + load_and_tell_to_reconnect(), self.on_error) @router.websocket("/ws") diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index f876c9a9..062f9527 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -78,18 +78,7 @@ class SessionManager: try: await autopilot.start(full_state=full_state) except Exception as e: - # Have to manually add to history because autopilot isn't started - formatted_err = '\n'.join(traceback.format_exception(e)) - msg_step = MessageStep( - name="Error starting Autopilot", message=formatted_err) - msg_step.description = f"```\n{formatted_err}\n```" - autopilot.history.add_node(HistoryNode( - step=msg_step, - observation=None, - depth=0, - active=False - )) - logger.warning(f"Error starting Autopilot: {e}") + await self.on_error(e) def on_error(e: Exception) -> Coroutine: err_msg = '\n'.join(traceback.format_exception(e)) @@ -101,7 +90,7 @@ class SessionManager: async def remove_session(self, session_id: str): logger.debug(f"Removing session: {session_id}") if session_id in self.sessions: - if session_id in self.registered_ides: + if session_id in self.registered_ides and self.registered_ides[session_id] is not None: ws_to_close = self.registered_ides[session_id].websocket if ws_to_close is not None and ws_to_close.client_state != WebSocketState.DISCONNECTED: await self.sessions[session_id].autopilot.ide.websocket.close() @@ -111,6 +100,9 @@ class SessionManager: async def persist_session(self, session_id: str): """Save the session's FullState as a json file""" full_state = await self.sessions[session_id].autopilot.get_full_state() + if full_state.session_info is None: + return + with open(getSessionFilePath(session_id), "w") as f: json.dump(full_state.dict(), f) @@ -118,9 +110,10 @@ class SessionManager: with open(getSessionsListFilePath(), "r") as f: sessions_list = json.load(f) - sessions_list.append(SessionInfo( - session_info=full_state.session_info - )) + sessions_list.append(full_state.session_info.dict()) + + with open(getSessionsListFilePath(), "w") as f: + json.dump(sessions_list, f) async def load_session(self, old_session_id: str, new_session_id: str): """Load the session's FullState from a json file""" @@ -130,7 +123,7 @@ class SessionManager: # Delete the old session, but keep the IDE ide = self.registered_ides[old_session_id] - self.registered_ides[old_session_id] = None + del self.registered_ides[old_session_id] # Start the new session await self.new_session(ide, session_id=new_session_id) diff --git a/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts index 168fb156..139c9d05 100644 --- a/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts +++ b/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts @@ -31,6 +31,8 @@ abstract class AbstractContinueGUIClientProtocol { abstract selectContextItem(id: string, query: string): void; + abstract loadSession(session_id: string): void; + abstract onReconnectAtSession(session_id: string): void; } diff --git a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts index 830954c5..6cfbf66a 100644 --- a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts +++ b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts @@ -14,9 +14,11 @@ class ContinueGUIClientProtocol extends AbstractContinueGUIClientProtocol { useVscodeMessagePassing: boolean ) { if (this.messenger) { - // this.messenger.close(); TODO + console.log("Closing session: ", this.serverUrlWithSessionId); + this.messenger.close(); } this.serverUrlWithSessionId = serverUrlWithSessionId; + this.useVscodeMessagePassing = useVscodeMessagePassing; this.messenger = useVscodeMessagePassing ? new VscodeMessenger(serverUrlWithSessionId) : new WebsocketMessenger(serverUrlWithSessionId); @@ -45,12 +47,13 @@ class ContinueGUIClientProtocol extends AbstractContinueGUIClientProtocol { this.connectMessenger(serverUrlWithSessionId, useVscodeMessagePassing); } + loadSession(session_id: string): void { + this.messenger?.send("load_session", { session_id }); + } + onReconnectAtSession(session_id: string): void { this.connectMessenger( - this.serverUrlWithSessionId.replace( - /\/session\/[a-zA-Z0-9-]+/, - `/session/${session_id}` - ), + `${this.serverUrlWithSessionId.split("?")[0]}?session_id=${session_id}`, this.useVscodeMessagePassing ); } diff --git a/extension/react-app/src/hooks/messenger.ts b/extension/react-app/src/hooks/messenger.ts index ecf646c7..0bfbe00c 100644 --- a/extension/react-app/src/hooks/messenger.ts +++ b/extension/react-app/src/hooks/messenger.ts @@ -15,6 +15,8 @@ export abstract class Messenger { abstract sendAndReceive(messageType: string, data: any): Promise; abstract onError(callback: (error: any) => void): void; + + abstract close(): void; } export class WebsocketMessenger extends Messenger { @@ -105,4 +107,8 @@ export class WebsocketMessenger extends Messenger { onError(callback: (error: any) => void): void { this.websocket.addEventListener("error", callback); } + + close(): void { + this.websocket.close(); + } } diff --git a/extension/react-app/src/hooks/vscodeMessenger.ts b/extension/react-app/src/hooks/vscodeMessenger.ts index 13f5092b..cf626721 100644 --- a/extension/react-app/src/hooks/vscodeMessenger.ts +++ b/extension/react-app/src/hooks/vscodeMessenger.ts @@ -76,4 +76,8 @@ export class VscodeMessenger extends Messenger { } }); } + + close(): void { + postVscMessage("websocketForwardingClose", { url: this.serverUrl }); + } } diff --git a/extension/react-app/src/pages/history.tsx b/extension/react-app/src/pages/history.tsx index 052fe5be..0142836f 100644 --- a/extension/react-app/src/pages/history.tsx +++ b/extension/react-app/src/pages/history.tsx @@ -4,6 +4,23 @@ import { GUIClientContext } from "../App"; import { useSelector } from "react-redux"; import { RootStore } from "../redux/store"; import { useNavigate } from "react-router-dom"; +import { secondaryDark, vscBackground } from "../components"; +import styled from "styled-components"; + +const Tr = styled.tr` + &:hover { + background-color: ${secondaryDark}; + } +`; + +const TdDiv = styled.div` + cursor: pointer; + padding-left: 1rem; + padding-right: 1rem; + padding-top: 0.5rem; + padding-bottom: 0.5rem; + border-bottom: 1px solid ${secondaryDark}; +`; function History() { const navigate = useNavigate(); @@ -30,24 +47,23 @@ function History() { return (
-

History

+

History

{sessions.map((session, index) => ( - + - + ))}
-
{ - // client?.loadSession(session.id); + client?.loadSession(session.session_id); navigate("/"); }} >
{session.title}
{session.date_created}
-
+
diff --git a/extension/react-app/src/redux/slices/configSlice.ts b/extension/react-app/src/redux/slices/configSlice.ts index 57c4f860..59c76066 100644 --- a/extension/react-app/src/redux/slices/configSlice.ts +++ b/extension/react-app/src/redux/slices/configSlice.ts @@ -50,7 +50,7 @@ export const configSlice = createSlice({ ) => ({ ...state, dataSwitchOn: action.payload, - }) + }), }, }); @@ -60,6 +60,6 @@ export const { setWorkspacePath, setSessionId, setVscMediaUrl, - setDataSwitchOn + setDataSwitchOn, } = configSlice.actions; export default configSlice.reducer; diff --git a/extension/src/debugPanel.ts b/extension/src/debugPanel.ts index b687c3e4..d133080b 100644 --- a/extension/src/debugPanel.ts +++ b/extension/src/debugPanel.ts @@ -221,6 +221,15 @@ export function setupDebugPanel( } break; } + case "websocketForwardingClose": { + let url = data.url; + let connection = websocketConnections[url]; + if (typeof connection !== "undefined") { + connection.close(); + websocketConnections[url] = undefined; + } + break; + } case "websocketForwardingMessage": { let url = data.url; let connection = websocketConnections[url]; diff --git a/extension/src/util/messenger.ts b/extension/src/util/messenger.ts index bcc88fe1..152d4a1f 100644 --- a/extension/src/util/messenger.ts +++ b/extension/src/util/messenger.ts @@ -18,6 +18,8 @@ export abstract class Messenger { abstract onError(callback: () => void): void; abstract sendAndReceive(messageType: string, data: any): Promise; + + abstract close(): void; } export class WebsocketMessenger extends Messenger { @@ -160,4 +162,8 @@ export class WebsocketMessenger extends Messenger { onError(callback: () => void): void { this.websocket.addEventListener("error", callback); } + + close(): void { + this.websocket.close(); + } } -- cgit v1.2.3-70-g09d2 From 19060a30faf94454f4d69d01828a33985d07f109 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sun, 6 Aug 2023 15:39:16 -0700 Subject: feat: :construction: create new sessions --- continuedev/src/continuedev/core/autopilot.py | 1 + continuedev/src/continuedev/server/gui.py | 10 +++++----- continuedev/src/continuedev/server/session_manager.py | 13 ++++++++----- .../src/hooks/AbstractContinueGUIClientProtocol.ts | 2 +- extension/react-app/src/hooks/ContinueGUIClientProtocol.ts | 2 +- extension/react-app/src/pages/gui.tsx | 8 ++++---- 6 files changed, 20 insertions(+), 16 deletions(-) (limited to 'extension/react-app/src/hooks') diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 6dd30db1..ee29dc88 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -89,6 +89,7 @@ class Autopilot(ContinueBaseModel): if full_state is not None: self.history = full_state.history self.context_manager.context_providers["code"].adding_highlighted_code = full_state.adding_highlighted_code + self.session_info = full_state.session_info self.started = True diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 661e1787..4470999a 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -2,7 +2,7 @@ import asyncio import json from fastapi import Depends, Header, WebSocket, APIRouter from starlette.websockets import WebSocketState, WebSocketDisconnect -from typing import Any, List, Type, TypeVar +from typing import Any, List, Optional, Type, TypeVar from pydantic import BaseModel import traceback from uvicorn.main import Server @@ -100,7 +100,7 @@ class GUIProtocolServer(AbstractGUIProtocolServer): elif message_type == "select_context_item": self.select_context_item(data["id"], data["query"]) elif message_type == "load_session": - self.load_session(data["session_id"]) + self.load_session(data.get("session_id", None)) def on_main_input(self, input: str): # Do something with user input @@ -156,10 +156,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer): create_async_task( self.session.autopilot.select_context_item(id, query), self.on_error) - def load_session(self, session_id: str): + def load_session(self, session_id: Optional[str] = None): async def load_and_tell_to_reconnect(): - await session_manager.load_session(self.session.session_id, session_id) - await self._send_json("reconnect_at_session", {"session_id": session_id}) + new_session_id = await session_manager.load_session(self.session.session_id, session_id) + await self._send_json("reconnect_at_session", {"session_id": new_session_id}) create_async_task( load_and_tell_to_reconnect(), self.on_error) diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 062f9527..cde0344e 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -1,7 +1,7 @@ import os import traceback from fastapi import WebSocket, APIRouter -from typing import Any, Coroutine, Dict, Union +from typing import Any, Coroutine, Dict, Optional, Union from uuid import uuid4 import json @@ -49,7 +49,7 @@ class SessionManager: raise KeyError("Session ID not recognized", session_id) return self.sessions[session_id] - async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: + async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Optional[str] = None) -> Session: logger.debug(f"New session: {session_id}") # Load the persisted state (not being used right now) @@ -110,12 +110,14 @@ class SessionManager: with open(getSessionsListFilePath(), "r") as f: sessions_list = json.load(f) - sessions_list.append(full_state.session_info.dict()) + session_ids = [s["session_id"] for s in sessions_list] + if session_id not in session_ids: + sessions_list.append(full_state.session_info.dict()) with open(getSessionsListFilePath(), "w") as f: json.dump(sessions_list, f) - async def load_session(self, old_session_id: str, new_session_id: str): + async def load_session(self, old_session_id: str, new_session_id: Optional[str] = None) -> str: """Load the session's FullState from a json file""" # First persist the current state @@ -126,7 +128,8 @@ class SessionManager: del self.registered_ides[old_session_id] # Start the new session - await self.new_session(ide, session_id=new_session_id) + new_session = await self.new_session(ide, session_id=new_session_id) + return new_session.session_id def register_websocket(self, session_id: str, ws: WebSocket): self.sessions[session_id].ws = ws diff --git a/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts index 139c9d05..e018c03c 100644 --- a/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts +++ b/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts @@ -31,7 +31,7 @@ abstract class AbstractContinueGUIClientProtocol { abstract selectContextItem(id: string, query: string): void; - abstract loadSession(session_id: string): void; + abstract loadSession(session_id?: string): void; abstract onReconnectAtSession(session_id: string): void; } diff --git a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts index 6cfbf66a..c2285f6d 100644 --- a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts +++ b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts @@ -47,7 +47,7 @@ class ContinueGUIClientProtocol extends AbstractContinueGUIClientProtocol { this.connectMessenger(serverUrlWithSessionId, useVscodeMessagePassing); } - loadSession(session_id: string): void { + loadSession(session_id?: string): void { this.messenger?.send("load_session", { session_id }); } diff --git a/extension/react-app/src/pages/gui.tsx b/extension/react-app/src/pages/gui.tsx index d565e64f..dab429b5 100644 --- a/extension/react-app/src/pages/gui.tsx +++ b/extension/react-app/src/pages/gui.tsx @@ -16,7 +16,7 @@ import { BookOpenIcon, ChatBubbleOvalLeftEllipsisIcon, TrashIcon, - PlusCircleIcon, + PlusIcon, FolderIcon, } from "@heroicons/react/24/outline"; import ComboBox from "../components/ComboBox"; @@ -589,11 +589,11 @@ If you already have an LLM deployed on your own infrastructure, or would like to { - client?.sendClear(); + client?.loadSession(undefined); }} - text="Clear" + text="New Session" > - + { -- cgit v1.2.3-70-g09d2