summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/core/autopilot.py31
-rw-r--r--continuedev/src/continuedev/core/main.py9
-rw-r--r--continuedev/src/continuedev/libs/util/paths.py9
-rw-r--r--continuedev/src/continuedev/libs/util/telemetry.py1
-rw-r--r--continuedev/src/continuedev/models/generate_json_schema.py4
-rw-r--r--continuedev/src/continuedev/server/gui.py12
-rw-r--r--continuedev/src/continuedev/server/main.py3
-rw-r--r--continuedev/src/continuedev/server/session_manager.py69
8 files changed, 109 insertions, 29 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index f3a17d47..256f3439 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -1,10 +1,11 @@
from functools import cached_property
import traceback
import time
-from typing import Callable, Coroutine, Dict, List, Union
+from typing import Callable, Coroutine, Dict, List, Optional, Union
from aiohttp import ClientPayloadError
from pydantic import root_validator
+from ..libs.util.strings import remove_quotes_and_escapes
from ..models.filesystem import RangeInFileWithContents
from ..models.filesystem_edit import FileEditWithFullContents
from .observation import Observation, InternalErrorObservation
@@ -15,7 +16,7 @@ from ..plugins.context_providers.highlighted_code import HighlightedCodeContextP
from ..server.ide_protocol import AbstractIdeProtocolServer
from ..libs.util.queue import AsyncSubscriptionQueue
from ..models.main import ContinueBaseModel
-from .main import Context, ContinueCustomException, Policy, History, FullState, Step, HistoryNode
+from .main import Context, ContinueCustomException, Policy, History, FullState, SessionInfo, Step, HistoryNode
from ..plugins.steps.core.core import DisplayErrorStep, ReversibleStep, ManualEditStep, UserInputStep
from .sdk import ContinueSDK
from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback
@@ -53,7 +54,8 @@ class Autopilot(ContinueBaseModel):
policy: Policy = DefaultPolicy()
history: History = History.from_empty()
context: Context = Context()
- full_state: Union[FullState, None] = None
+ full_state: Optional[FullState] = None
+ session_info: Optional[SessionInfo] = None
context_manager: ContextManager = ContextManager()
continue_sdk: ContinueSDK = None
@@ -68,7 +70,7 @@ class Autopilot(ContinueBaseModel):
started: bool = False
- async def start(self):
+ async def start(self, full_state: Optional[FullState] = None):
self.continue_sdk = await ContinueSDK.create(self)
if override_policy := self.continue_sdk.config.policy_override:
self.policy = override_policy
@@ -84,6 +86,12 @@ class Autopilot(ContinueBaseModel):
logger.debug("Loading index")
create_async_task(self.context_manager.load_index(
self.ide.workspace_directory))
+
+ if full_state is not None:
+ self.history = full_state.history
+ self.context_manager.context_providers["code"].adding_highlighted_code = full_state.adding_highlighted_code
+ self.session_info = full_state.session_info
+
self.started = True
class Config:
@@ -106,6 +114,7 @@ class Autopilot(ContinueBaseModel):
adding_highlighted_code=self.context_manager.context_providers[
"code"].adding_highlighted_code if "code" in self.context_manager.context_providers else False,
selected_context_items=await self.context_manager.get_selected_items() if self.context_manager is not None else [],
+ session_info=self.session_info
)
self.full_state = full_state
return full_state
@@ -369,6 +378,20 @@ class Autopilot(ContinueBaseModel):
self._main_user_input_queue.append(user_input)
await self.update_subscribers()
+ # Use the first input to create title for session info, and make the session saveable
+ if self.session_info is None:
+ async def create_title():
+ title = await self.continue_sdk.models.medium.complete(f"Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: \"{user_input}\". The title is: ")
+ title = remove_quotes_and_escapes(title)
+ self.session_info = SessionInfo(
+ title=title,
+ session_id=self.ide.session_id,
+ date_created=str(time.time())
+ )
+
+ create_async_task(create_title(), on_error=lambda e: self.continue_sdk.run_step(
+ DisplayErrorStep(e=e)))
+
if len(self._main_user_input_queue) > 1:
return
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index 2553850f..a33d777e 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -1,5 +1,5 @@
import json
-from typing import Coroutine, Dict, List, Literal, Union
+from typing import Coroutine, Dict, List, Literal, Optional, Union
from pydantic.schema import schema
@@ -253,6 +253,12 @@ class ContextItem(BaseModel):
editable: bool = False
+class SessionInfo(ContinueBaseModel):
+ session_id: str
+ title: str
+ date_created: str
+
+
class FullState(ContinueBaseModel):
"""A full state of the program, including the history"""
history: History
@@ -261,6 +267,7 @@ class FullState(ContinueBaseModel):
slash_commands: List[SlashCommandDescription]
adding_highlighted_code: bool
selected_context_items: List[ContextItem]
+ session_info: Optional[SessionInfo] = None
class ContinueSDK:
diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py
index 83a472ad..01b594cf 100644
--- a/continuedev/src/continuedev/libs/util/paths.py
+++ b/continuedev/src/continuedev/libs/util/paths.py
@@ -32,6 +32,15 @@ def getSessionFilePath(session_id: str):
return path
+def getSessionsListFilePath():
+ path = os.path.join(getSessionsFolderPath(), "sessions.json")
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ if not os.path.exists(path):
+ with open(path, 'w') as f:
+ f.write("[]")
+ return path
+
+
def getConfigFilePath() -> str:
path = os.path.join(getGlobalFolderPath(), "config.py")
os.makedirs(os.path.dirname(path), exist_ok=True)
diff --git a/continuedev/src/continuedev/libs/util/telemetry.py b/continuedev/src/continuedev/libs/util/telemetry.py
index 60c910bb..0f66ad8d 100644
--- a/continuedev/src/continuedev/libs/util/telemetry.py
+++ b/continuedev/src/continuedev/libs/util/telemetry.py
@@ -23,7 +23,6 @@ class PostHogLogger:
self.posthog = Posthog(self.api_key, host='https://app.posthog.com')
def setup(self, unique_id: str, allow_anonymous_telemetry: bool):
- logger.debug(f"Setting unique_id as {unique_id}")
self.unique_id = unique_id or "NO_UNIQUE_ID"
self.allow_anonymous_telemetry = allow_anonymous_telemetry or True
diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py
index 2166bc37..4262ac55 100644
--- a/continuedev/src/continuedev/models/generate_json_schema.py
+++ b/continuedev/src/continuedev/models/generate_json_schema.py
@@ -1,7 +1,7 @@
from .main import *
from .filesystem import RangeInFile, FileEdit
from .filesystem_edit import FileEditWithFullContents
-from ..core.main import History, HistoryNode, FullState
+from ..core.main import History, HistoryNode, FullState, SessionInfo
from ..core.context import ContextItem
from pydantic import schema_json_of
import os
@@ -13,7 +13,7 @@ MODELS_TO_GENERATE = [
] + [
FileEditWithFullContents
] + [
- History, HistoryNode, FullState
+ History, HistoryNode, FullState, SessionInfo
] + [
ContextItem
]
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 7c89c5c2..4470999a 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -2,7 +2,7 @@ import asyncio
import json
from fastapi import Depends, Header, WebSocket, APIRouter
from starlette.websockets import WebSocketState, WebSocketDisconnect
-from typing import Any, List, Type, TypeVar
+from typing import Any, List, Optional, Type, TypeVar
from pydantic import BaseModel
import traceback
from uvicorn.main import Server
@@ -99,6 +99,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.on_show_logs_at_index(data["index"])
elif message_type == "select_context_item":
self.select_context_item(data["id"], data["query"])
+ elif message_type == "load_session":
+ self.load_session(data.get("session_id", None))
def on_main_input(self, input: str):
# Do something with user input
@@ -154,6 +156,14 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
create_async_task(
self.session.autopilot.select_context_item(id, query), self.on_error)
+ def load_session(self, session_id: Optional[str] = None):
+ async def load_and_tell_to_reconnect():
+ new_session_id = await session_manager.load_session(self.session.session_id, session_id)
+ await self._send_json("reconnect_at_session", {"session_id": new_session_id})
+
+ create_async_task(
+ load_and_tell_to_reconnect(), self.on_error)
+
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py
index f8dfb009..f0a3f094 100644
--- a/continuedev/src/continuedev/server/main.py
+++ b/continuedev/src/continuedev/server/main.py
@@ -11,13 +11,14 @@ import argparse
from .ide import router as ide_router
from .gui import router as gui_router
-from .session_manager import session_manager
+from .session_manager import session_manager, router as sessions_router
from ..libs.util.logging import logger
app = FastAPI()
app.include_router(ide_router)
app.include_router(gui_router)
+app.include_router(sessions_router)
# Add CORS support
app.add_middleware(
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index 56c92307..cde0344e 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -1,21 +1,23 @@
import os
import traceback
-from fastapi import WebSocket
-from typing import Any, Coroutine, Dict, Union
+from fastapi import WebSocket, APIRouter
+from typing import Any, Coroutine, Dict, Optional, Union
from uuid import uuid4
import json
from fastapi.websockets import WebSocketState
from ..plugins.steps.core.core import MessageStep
-from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath
-from ..core.main import FullState, HistoryNode
+from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath, getSessionsListFilePath
+from ..core.main import FullState, HistoryNode, SessionInfo
from ..core.autopilot import Autopilot
from .ide_protocol import AbstractIdeProtocolServer
from ..libs.util.create_async_task import create_async_task
from ..libs.util.errors import SessionNotFound
from ..libs.util.logging import logger
+router = APIRouter(prefix="/sessions", tags=["sessions"])
+
class Session:
session_id: str
@@ -47,7 +49,7 @@ class SessionManager:
raise KeyError("Session ID not recognized", session_id)
return self.sessions[session_id]
- async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
+ async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Optional[str] = None) -> Session:
logger.debug(f"New session: {session_id}")
# Load the persisted state (not being used right now)
@@ -74,20 +76,9 @@ class SessionManager:
# Start the autopilot (must be after session is added to sessions) and the policy
try:
- await autopilot.start()
+ await autopilot.start(full_state=full_state)
except Exception as e:
- # Have to manually add to history because autopilot isn't started
- formatted_err = '\n'.join(traceback.format_exception(e))
- msg_step = MessageStep(
- name="Error loading context manager", message=formatted_err)
- msg_step.description = f"```\n{formatted_err}\n```"
- autopilot.history.add_node(HistoryNode(
- step=msg_step,
- observation=None,
- depth=0,
- active=False
- ))
- logger.warning(f"Error loading context manager: {e}")
+ await self.on_error(e)
def on_error(e: Exception) -> Coroutine:
err_msg = '\n'.join(traceback.format_exception(e))
@@ -99,7 +90,7 @@ class SessionManager:
async def remove_session(self, session_id: str):
logger.debug(f"Removing session: {session_id}")
if session_id in self.sessions:
- if session_id in self.registered_ides:
+ if session_id in self.registered_ides and self.registered_ides[session_id] is not None:
ws_to_close = self.registered_ides[session_id].websocket
if ws_to_close is not None and ws_to_close.client_state != WebSocketState.DISCONNECTED:
await self.sessions[session_id].autopilot.ide.websocket.close()
@@ -109,9 +100,37 @@ class SessionManager:
async def persist_session(self, session_id: str):
"""Save the session's FullState as a json file"""
full_state = await self.sessions[session_id].autopilot.get_full_state()
+ if full_state.session_info is None:
+ return
+
with open(getSessionFilePath(session_id), "w") as f:
json.dump(full_state.dict(), f)
+ # Read and update the sessions list
+ with open(getSessionsListFilePath(), "r") as f:
+ sessions_list = json.load(f)
+
+ session_ids = [s["session_id"] for s in sessions_list]
+ if session_id not in session_ids:
+ sessions_list.append(full_state.session_info.dict())
+
+ with open(getSessionsListFilePath(), "w") as f:
+ json.dump(sessions_list, f)
+
+ async def load_session(self, old_session_id: str, new_session_id: Optional[str] = None) -> str:
+ """Load the session's FullState from a json file"""
+
+ # First persist the current state
+ await self.persist_session(old_session_id)
+
+ # Delete the old session, but keep the IDE
+ ide = self.registered_ides[old_session_id]
+ del self.registered_ides[old_session_id]
+
+ # Start the new session
+ new_session = await self.new_session(ide, session_id=new_session_id)
+ return new_session.session_id
+
def register_websocket(self, session_id: str, ws: WebSocket):
self.sessions[session_id].ws = ws
logger.debug(f"Registered websocket for session {session_id}")
@@ -130,3 +149,15 @@ class SessionManager:
session_manager = SessionManager()
+
+
+@router.get("/list")
+async def list_sessions():
+ """List all sessions"""
+ sessions_list_file = getSessionsListFilePath()
+ if not os.path.exists(sessions_list_file):
+ print("Returning empty sessions list")
+ return []
+ sessions = json.load(open(sessions_list_file, "r"))
+ print("Returning sessions list: ", sessions)
+ return sessions