summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--continuedev/src/continuedev/core/autopilot.py1
-rw-r--r--continuedev/src/continuedev/server/gui.py10
-rw-r--r--continuedev/src/continuedev/server/session_manager.py13
-rw-r--r--extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts2
-rw-r--r--extension/react-app/src/hooks/ContinueGUIClientProtocol.ts2
-rw-r--r--extension/react-app/src/pages/gui.tsx8
6 files changed, 20 insertions, 16 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 6dd30db1..ee29dc88 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -89,6 +89,7 @@ class Autopilot(ContinueBaseModel):
if full_state is not None:
self.history = full_state.history
self.context_manager.context_providers["code"].adding_highlighted_code = full_state.adding_highlighted_code
+ self.session_info = full_state.session_info
self.started = True
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 661e1787..4470999a 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -2,7 +2,7 @@ import asyncio
import json
from fastapi import Depends, Header, WebSocket, APIRouter
from starlette.websockets import WebSocketState, WebSocketDisconnect
-from typing import Any, List, Type, TypeVar
+from typing import Any, List, Optional, Type, TypeVar
from pydantic import BaseModel
import traceback
from uvicorn.main import Server
@@ -100,7 +100,7 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
elif message_type == "select_context_item":
self.select_context_item(data["id"], data["query"])
elif message_type == "load_session":
- self.load_session(data["session_id"])
+ self.load_session(data.get("session_id", None))
def on_main_input(self, input: str):
# Do something with user input
@@ -156,10 +156,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
create_async_task(
self.session.autopilot.select_context_item(id, query), self.on_error)
- def load_session(self, session_id: str):
+ def load_session(self, session_id: Optional[str] = None):
async def load_and_tell_to_reconnect():
- await session_manager.load_session(self.session.session_id, session_id)
- await self._send_json("reconnect_at_session", {"session_id": session_id})
+ new_session_id = await session_manager.load_session(self.session.session_id, session_id)
+ await self._send_json("reconnect_at_session", {"session_id": new_session_id})
create_async_task(
load_and_tell_to_reconnect(), self.on_error)
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index 062f9527..cde0344e 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -1,7 +1,7 @@
import os
import traceback
from fastapi import WebSocket, APIRouter
-from typing import Any, Coroutine, Dict, Union
+from typing import Any, Coroutine, Dict, Optional, Union
from uuid import uuid4
import json
@@ -49,7 +49,7 @@ class SessionManager:
raise KeyError("Session ID not recognized", session_id)
return self.sessions[session_id]
- async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
+ async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Optional[str] = None) -> Session:
logger.debug(f"New session: {session_id}")
# Load the persisted state (not being used right now)
@@ -110,12 +110,14 @@ class SessionManager:
with open(getSessionsListFilePath(), "r") as f:
sessions_list = json.load(f)
- sessions_list.append(full_state.session_info.dict())
+ session_ids = [s["session_id"] for s in sessions_list]
+ if session_id not in session_ids:
+ sessions_list.append(full_state.session_info.dict())
with open(getSessionsListFilePath(), "w") as f:
json.dump(sessions_list, f)
- async def load_session(self, old_session_id: str, new_session_id: str):
+ async def load_session(self, old_session_id: str, new_session_id: Optional[str] = None) -> str:
"""Load the session's FullState from a json file"""
# First persist the current state
@@ -126,7 +128,8 @@ class SessionManager:
del self.registered_ides[old_session_id]
# Start the new session
- await self.new_session(ide, session_id=new_session_id)
+ new_session = await self.new_session(ide, session_id=new_session_id)
+ return new_session.session_id
def register_websocket(self, session_id: str, ws: WebSocket):
self.sessions[session_id].ws = ws
diff --git a/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts
index 139c9d05..e018c03c 100644
--- a/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts
+++ b/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts
@@ -31,7 +31,7 @@ abstract class AbstractContinueGUIClientProtocol {
abstract selectContextItem(id: string, query: string): void;
- abstract loadSession(session_id: string): void;
+ abstract loadSession(session_id?: string): void;
abstract onReconnectAtSession(session_id: string): void;
}
diff --git a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts
index 6cfbf66a..c2285f6d 100644
--- a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts
+++ b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts
@@ -47,7 +47,7 @@ class ContinueGUIClientProtocol extends AbstractContinueGUIClientProtocol {
this.connectMessenger(serverUrlWithSessionId, useVscodeMessagePassing);
}
- loadSession(session_id: string): void {
+ loadSession(session_id?: string): void {
this.messenger?.send("load_session", { session_id });
}
diff --git a/extension/react-app/src/pages/gui.tsx b/extension/react-app/src/pages/gui.tsx
index d565e64f..dab429b5 100644
--- a/extension/react-app/src/pages/gui.tsx
+++ b/extension/react-app/src/pages/gui.tsx
@@ -16,7 +16,7 @@ import {
BookOpenIcon,
ChatBubbleOvalLeftEllipsisIcon,
TrashIcon,
- PlusCircleIcon,
+ PlusIcon,
FolderIcon,
} from "@heroicons/react/24/outline";
import ComboBox from "../components/ComboBox";
@@ -589,11 +589,11 @@ If you already have an LLM deployed on your own infrastructure, or would like to
</HeaderButtonWithText>
<HeaderButtonWithText
onClick={() => {
- client?.sendClear();
+ client?.loadSession(undefined);
}}
- text="Clear"
+ text="New Session"
>
- <PlusCircleIcon width="1.4em" height="1.4em" />
+ <PlusIcon width="1.4em" height="1.4em" />
</HeaderButtonWithText>
<HeaderButtonWithText
onClick={() => {