diff options
Diffstat (limited to 'continuedev')
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 31 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/main.py | 9 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/util/paths.py | 9 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/util/telemetry.py | 1 | ||||
| -rw-r--r-- | continuedev/src/continuedev/models/generate_json_schema.py | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/gui.py | 12 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/main.py | 3 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/session_manager.py | 69 | 
8 files changed, 109 insertions, 29 deletions
| diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index f3a17d47..256f3439 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -1,10 +1,11 @@  from functools import cached_property  import traceback  import time -from typing import Callable, Coroutine, Dict, List, Union +from typing import Callable, Coroutine, Dict, List, Optional, Union  from aiohttp import ClientPayloadError  from pydantic import root_validator +from ..libs.util.strings import remove_quotes_and_escapes  from ..models.filesystem import RangeInFileWithContents  from ..models.filesystem_edit import FileEditWithFullContents  from .observation import Observation, InternalErrorObservation @@ -15,7 +16,7 @@ from ..plugins.context_providers.highlighted_code import HighlightedCodeContextP  from ..server.ide_protocol import AbstractIdeProtocolServer  from ..libs.util.queue import AsyncSubscriptionQueue  from ..models.main import ContinueBaseModel -from .main import Context, ContinueCustomException, Policy, History, FullState, Step, HistoryNode +from .main import Context, ContinueCustomException, Policy, History, FullState, SessionInfo, Step, HistoryNode  from ..plugins.steps.core.core import DisplayErrorStep, ReversibleStep, ManualEditStep, UserInputStep  from .sdk import ContinueSDK  from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback @@ -53,7 +54,8 @@ class Autopilot(ContinueBaseModel):      policy: Policy = DefaultPolicy()      history: History = History.from_empty()      context: Context = Context() -    full_state: Union[FullState, None] = None +    full_state: Optional[FullState] = None +    session_info: Optional[SessionInfo] = None      context_manager: ContextManager = ContextManager()      continue_sdk: ContinueSDK = None @@ -68,7 +70,7 @@ class Autopilot(ContinueBaseModel):      started: bool = False -    async def start(self): +    async def start(self, full_state: Optional[FullState] = None):          self.continue_sdk = await ContinueSDK.create(self)          if override_policy := self.continue_sdk.config.policy_override:              self.policy = override_policy @@ -84,6 +86,12 @@ class Autopilot(ContinueBaseModel):          logger.debug("Loading index")          create_async_task(self.context_manager.load_index(              self.ide.workspace_directory)) + +        if full_state is not None: +            self.history = full_state.history +            self.context_manager.context_providers["code"].adding_highlighted_code = full_state.adding_highlighted_code +            self.session_info = full_state.session_info +          self.started = True      class Config: @@ -106,6 +114,7 @@ class Autopilot(ContinueBaseModel):              adding_highlighted_code=self.context_manager.context_providers[                  "code"].adding_highlighted_code if "code" in self.context_manager.context_providers else False,              selected_context_items=await self.context_manager.get_selected_items() if self.context_manager is not None else [], +            session_info=self.session_info          )          self.full_state = full_state          return full_state @@ -369,6 +378,20 @@ class Autopilot(ContinueBaseModel):          self._main_user_input_queue.append(user_input)          await self.update_subscribers() +        # Use the first input to create title for session info, and make the session saveable +        if self.session_info is None: +            async def create_title(): +                title = await self.continue_sdk.models.medium.complete(f"Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: \"{user_input}\". The title is: ") +                title = remove_quotes_and_escapes(title) +                self.session_info = SessionInfo( +                    title=title, +                    session_id=self.ide.session_id, +                    date_created=str(time.time()) +                ) + +            create_async_task(create_title(), on_error=lambda e: self.continue_sdk.run_step( +                DisplayErrorStep(e=e))) +          if len(self._main_user_input_queue) > 1:              return diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 2553850f..a33d777e 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -1,5 +1,5 @@  import json -from typing import Coroutine, Dict, List, Literal, Union +from typing import Coroutine, Dict, List, Literal, Optional, Union  from pydantic.schema import schema @@ -253,6 +253,12 @@ class ContextItem(BaseModel):      editable: bool = False +class SessionInfo(ContinueBaseModel): +    session_id: str +    title: str +    date_created: str + +  class FullState(ContinueBaseModel):      """A full state of the program, including the history"""      history: History @@ -261,6 +267,7 @@ class FullState(ContinueBaseModel):      slash_commands: List[SlashCommandDescription]      adding_highlighted_code: bool      selected_context_items: List[ContextItem] +    session_info: Optional[SessionInfo] = None  class ContinueSDK: diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py index 83a472ad..01b594cf 100644 --- a/continuedev/src/continuedev/libs/util/paths.py +++ b/continuedev/src/continuedev/libs/util/paths.py @@ -32,6 +32,15 @@ def getSessionFilePath(session_id: str):      return path +def getSessionsListFilePath(): +    path = os.path.join(getSessionsFolderPath(), "sessions.json") +    os.makedirs(os.path.dirname(path), exist_ok=True) +    if not os.path.exists(path): +        with open(path, 'w') as f: +            f.write("[]") +    return path + +  def getConfigFilePath() -> str:      path = os.path.join(getGlobalFolderPath(), "config.py")      os.makedirs(os.path.dirname(path), exist_ok=True) diff --git a/continuedev/src/continuedev/libs/util/telemetry.py b/continuedev/src/continuedev/libs/util/telemetry.py index 60c910bb..0f66ad8d 100644 --- a/continuedev/src/continuedev/libs/util/telemetry.py +++ b/continuedev/src/continuedev/libs/util/telemetry.py @@ -23,7 +23,6 @@ class PostHogLogger:          self.posthog = Posthog(self.api_key, host='https://app.posthog.com')      def setup(self, unique_id: str, allow_anonymous_telemetry: bool): -        logger.debug(f"Setting unique_id as {unique_id}")          self.unique_id = unique_id or "NO_UNIQUE_ID"          self.allow_anonymous_telemetry = allow_anonymous_telemetry or True diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py index 2166bc37..4262ac55 100644 --- a/continuedev/src/continuedev/models/generate_json_schema.py +++ b/continuedev/src/continuedev/models/generate_json_schema.py @@ -1,7 +1,7 @@  from .main import *  from .filesystem import RangeInFile, FileEdit  from .filesystem_edit import FileEditWithFullContents -from ..core.main import History, HistoryNode, FullState +from ..core.main import History, HistoryNode, FullState, SessionInfo  from ..core.context import ContextItem  from pydantic import schema_json_of  import os @@ -13,7 +13,7 @@ MODELS_TO_GENERATE = [  ] + [      FileEditWithFullContents  ] + [ -    History, HistoryNode, FullState +    History, HistoryNode, FullState, SessionInfo  ] + [      ContextItem  ] diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 7c89c5c2..4470999a 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -2,7 +2,7 @@ import asyncio  import json  from fastapi import Depends, Header, WebSocket, APIRouter  from starlette.websockets import WebSocketState, WebSocketDisconnect -from typing import Any, List, Type, TypeVar +from typing import Any, List, Optional, Type, TypeVar  from pydantic import BaseModel  import traceback  from uvicorn.main import Server @@ -99,6 +99,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):              self.on_show_logs_at_index(data["index"])          elif message_type == "select_context_item":              self.select_context_item(data["id"], data["query"]) +        elif message_type == "load_session": +            self.load_session(data.get("session_id", None))      def on_main_input(self, input: str):          # Do something with user input @@ -154,6 +156,14 @@ class GUIProtocolServer(AbstractGUIProtocolServer):          create_async_task(              self.session.autopilot.select_context_item(id, query), self.on_error) +    def load_session(self, session_id: Optional[str] = None): +        async def load_and_tell_to_reconnect(): +            new_session_id = await session_manager.load_session(self.session.session_id, session_id) +            await self._send_json("reconnect_at_session", {"session_id": new_session_id}) + +        create_async_task( +            load_and_tell_to_reconnect(), self.on_error) +  @router.websocket("/ws")  async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index f8dfb009..f0a3f094 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -11,13 +11,14 @@ import argparse  from .ide import router as ide_router  from .gui import router as gui_router -from .session_manager import session_manager +from .session_manager import session_manager, router as sessions_router  from ..libs.util.logging import logger  app = FastAPI()  app.include_router(ide_router)  app.include_router(gui_router) +app.include_router(sessions_router)  # Add CORS support  app.add_middleware( diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 56c92307..cde0344e 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -1,21 +1,23 @@  import os  import traceback -from fastapi import WebSocket -from typing import Any, Coroutine, Dict, Union +from fastapi import WebSocket, APIRouter +from typing import Any, Coroutine, Dict, Optional, Union  from uuid import uuid4  import json  from fastapi.websockets import WebSocketState  from ..plugins.steps.core.core import MessageStep -from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath -from ..core.main import FullState, HistoryNode +from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath, getSessionsListFilePath +from ..core.main import FullState, HistoryNode, SessionInfo  from ..core.autopilot import Autopilot  from .ide_protocol import AbstractIdeProtocolServer  from ..libs.util.create_async_task import create_async_task  from ..libs.util.errors import SessionNotFound  from ..libs.util.logging import logger +router = APIRouter(prefix="/sessions", tags=["sessions"]) +  class Session:      session_id: str @@ -47,7 +49,7 @@ class SessionManager:              raise KeyError("Session ID not recognized", session_id)          return self.sessions[session_id] -    async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: +    async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Optional[str] = None) -> Session:          logger.debug(f"New session: {session_id}")          # Load the persisted state (not being used right now) @@ -74,20 +76,9 @@ class SessionManager:          # Start the autopilot (must be after session is added to sessions) and the policy          try: -            await autopilot.start() +            await autopilot.start(full_state=full_state)          except Exception as e: -            # Have to manually add to history because autopilot isn't started -            formatted_err = '\n'.join(traceback.format_exception(e)) -            msg_step = MessageStep( -                name="Error loading context manager", message=formatted_err) -            msg_step.description = f"```\n{formatted_err}\n```" -            autopilot.history.add_node(HistoryNode( -                step=msg_step, -                observation=None, -                depth=0, -                active=False -            )) -            logger.warning(f"Error loading context manager: {e}") +            await self.on_error(e)          def on_error(e: Exception) -> Coroutine:              err_msg = '\n'.join(traceback.format_exception(e)) @@ -99,7 +90,7 @@ class SessionManager:      async def remove_session(self, session_id: str):          logger.debug(f"Removing session: {session_id}")          if session_id in self.sessions: -            if session_id in self.registered_ides: +            if session_id in self.registered_ides and self.registered_ides[session_id] is not None:                  ws_to_close = self.registered_ides[session_id].websocket                  if ws_to_close is not None and ws_to_close.client_state != WebSocketState.DISCONNECTED:                      await self.sessions[session_id].autopilot.ide.websocket.close() @@ -109,9 +100,37 @@ class SessionManager:      async def persist_session(self, session_id: str):          """Save the session's FullState as a json file"""          full_state = await self.sessions[session_id].autopilot.get_full_state() +        if full_state.session_info is None: +            return +          with open(getSessionFilePath(session_id), "w") as f:              json.dump(full_state.dict(), f) +        # Read and update the sessions list +        with open(getSessionsListFilePath(), "r") as f: +            sessions_list = json.load(f) + +        session_ids = [s["session_id"] for s in sessions_list] +        if session_id not in session_ids: +            sessions_list.append(full_state.session_info.dict()) + +        with open(getSessionsListFilePath(), "w") as f: +            json.dump(sessions_list, f) + +    async def load_session(self, old_session_id: str, new_session_id: Optional[str] = None) -> str: +        """Load the session's FullState from a json file""" + +        # First persist the current state +        await self.persist_session(old_session_id) + +        # Delete the old session, but keep the IDE +        ide = self.registered_ides[old_session_id] +        del self.registered_ides[old_session_id] + +        # Start the new session +        new_session = await self.new_session(ide, session_id=new_session_id) +        return new_session.session_id +      def register_websocket(self, session_id: str, ws: WebSocket):          self.sessions[session_id].ws = ws          logger.debug(f"Registered websocket for session {session_id}") @@ -130,3 +149,15 @@ class SessionManager:  session_manager = SessionManager() + + +@router.get("/list") +async def list_sessions(): +    """List all sessions""" +    sessions_list_file = getSessionsListFilePath() +    if not os.path.exists(sessions_list_file): +        print("Returning empty sessions list") +        return [] +    sessions = json.load(open(sessions_list_file, "r")) +    print("Returning sessions list: ", sessions) +    return sessions | 
