diff options
Diffstat (limited to 'continuedev/src')
10 files changed, 203 insertions, 29 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index f3a17d47..256f3439 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -1,10 +1,11 @@ from functools import cached_property import traceback import time -from typing import Callable, Coroutine, Dict, List, Union +from typing import Callable, Coroutine, Dict, List, Optional, Union from aiohttp import ClientPayloadError from pydantic import root_validator +from ..libs.util.strings import remove_quotes_and_escapes from ..models.filesystem import RangeInFileWithContents from ..models.filesystem_edit import FileEditWithFullContents from .observation import Observation, InternalErrorObservation @@ -15,7 +16,7 @@ from ..plugins.context_providers.highlighted_code import HighlightedCodeContextP from ..server.ide_protocol import AbstractIdeProtocolServer from ..libs.util.queue import AsyncSubscriptionQueue from ..models.main import ContinueBaseModel -from .main import Context, ContinueCustomException, Policy, History, FullState, Step, HistoryNode +from .main import Context, ContinueCustomException, Policy, History, FullState, SessionInfo, Step, HistoryNode from ..plugins.steps.core.core import DisplayErrorStep, ReversibleStep, ManualEditStep, UserInputStep from .sdk import ContinueSDK from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback @@ -53,7 +54,8 @@ class Autopilot(ContinueBaseModel): policy: Policy = DefaultPolicy() history: History = History.from_empty() context: Context = Context() - full_state: Union[FullState, None] = None + full_state: Optional[FullState] = None + session_info: Optional[SessionInfo] = None context_manager: ContextManager = ContextManager() continue_sdk: ContinueSDK = None @@ -68,7 +70,7 @@ class Autopilot(ContinueBaseModel): started: bool = False - async def start(self): + async def start(self, full_state: Optional[FullState] = None): self.continue_sdk = await ContinueSDK.create(self) if override_policy := self.continue_sdk.config.policy_override: self.policy = override_policy @@ -84,6 +86,12 @@ class Autopilot(ContinueBaseModel): logger.debug("Loading index") create_async_task(self.context_manager.load_index( self.ide.workspace_directory)) + + if full_state is not None: + self.history = full_state.history + self.context_manager.context_providers["code"].adding_highlighted_code = full_state.adding_highlighted_code + self.session_info = full_state.session_info + self.started = True class Config: @@ -106,6 +114,7 @@ class Autopilot(ContinueBaseModel): adding_highlighted_code=self.context_manager.context_providers[ "code"].adding_highlighted_code if "code" in self.context_manager.context_providers else False, selected_context_items=await self.context_manager.get_selected_items() if self.context_manager is not None else [], + session_info=self.session_info ) self.full_state = full_state return full_state @@ -369,6 +378,20 @@ class Autopilot(ContinueBaseModel): self._main_user_input_queue.append(user_input) await self.update_subscribers() + # Use the first input to create title for session info, and make the session saveable + if self.session_info is None: + async def create_title(): + title = await self.continue_sdk.models.medium.complete(f"Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: \"{user_input}\". The title is: ") + title = remove_quotes_and_escapes(title) + self.session_info = SessionInfo( + title=title, + session_id=self.ide.session_id, + date_created=str(time.time()) + ) + + create_async_task(create_title(), on_error=lambda e: self.continue_sdk.run_step( + DisplayErrorStep(e=e))) + if len(self._main_user_input_queue) > 1: return diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 2553850f..a33d777e 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -1,5 +1,5 @@ import json -from typing import Coroutine, Dict, List, Literal, Union +from typing import Coroutine, Dict, List, Literal, Optional, Union from pydantic.schema import schema @@ -253,6 +253,12 @@ class ContextItem(BaseModel): editable: bool = False +class SessionInfo(ContinueBaseModel): + session_id: str + title: str + date_created: str + + class FullState(ContinueBaseModel): """A full state of the program, including the history""" history: History @@ -261,6 +267,7 @@ class FullState(ContinueBaseModel): slash_commands: List[SlashCommandDescription] adding_highlighted_code: bool selected_context_items: List[ContextItem] + session_info: Optional[SessionInfo] = None class ContinueSDK: diff --git a/continuedev/src/continuedev/libs/constants/default_config.py b/continuedev/src/continuedev/libs/constants/default_config.py index d3922091..f3b19f89 100644 --- a/continuedev/src/continuedev/libs/constants/default_config.py +++ b/continuedev/src/continuedev/libs/constants/default_config.py @@ -21,6 +21,7 @@ from continuedev.src.continuedev.plugins.steps.clear_history import ClearHistory from continuedev.src.continuedev.plugins.steps.feedback import FeedbackStep from continuedev.src.continuedev.plugins.steps.comment_code import CommentCodeStep from continuedev.src.continuedev.plugins.steps.main import EditHighlightedCodeStep +from continuedev.src.continuedev.plugins.context_providers.search import SearchContextProvider class CommitMessageStep(Step): @@ -122,6 +123,7 @@ config = ContinueConfig( # GoogleContextProvider( # serper_api_key="<your serper.dev api key>" # ) + SearchContextProvider() ], # Policies hold the main logic that decides which Step to take next diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py index 83a472ad..01b594cf 100644 --- a/continuedev/src/continuedev/libs/util/paths.py +++ b/continuedev/src/continuedev/libs/util/paths.py @@ -32,6 +32,15 @@ def getSessionFilePath(session_id: str): return path +def getSessionsListFilePath(): + path = os.path.join(getSessionsFolderPath(), "sessions.json") + os.makedirs(os.path.dirname(path), exist_ok=True) + if not os.path.exists(path): + with open(path, 'w') as f: + f.write("[]") + return path + + def getConfigFilePath() -> str: path = os.path.join(getGlobalFolderPath(), "config.py") os.makedirs(os.path.dirname(path), exist_ok=True) diff --git a/continuedev/src/continuedev/libs/util/telemetry.py b/continuedev/src/continuedev/libs/util/telemetry.py index 60c910bb..0f66ad8d 100644 --- a/continuedev/src/continuedev/libs/util/telemetry.py +++ b/continuedev/src/continuedev/libs/util/telemetry.py @@ -23,7 +23,6 @@ class PostHogLogger: self.posthog = Posthog(self.api_key, host='https://app.posthog.com') def setup(self, unique_id: str, allow_anonymous_telemetry: bool): - logger.debug(f"Setting unique_id as {unique_id}") self.unique_id = unique_id or "NO_UNIQUE_ID" self.allow_anonymous_telemetry = allow_anonymous_telemetry or True diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py index 2166bc37..4262ac55 100644 --- a/continuedev/src/continuedev/models/generate_json_schema.py +++ b/continuedev/src/continuedev/models/generate_json_schema.py @@ -1,7 +1,7 @@ from .main import * from .filesystem import RangeInFile, FileEdit from .filesystem_edit import FileEditWithFullContents -from ..core.main import History, HistoryNode, FullState +from ..core.main import History, HistoryNode, FullState, SessionInfo from ..core.context import ContextItem from pydantic import schema_json_of import os @@ -13,7 +13,7 @@ MODELS_TO_GENERATE = [ ] + [ FileEditWithFullContents ] + [ - History, HistoryNode, FullState + History, HistoryNode, FullState, SessionInfo ] + [ ContextItem ] diff --git a/continuedev/src/continuedev/plugins/context_providers/search.py b/continuedev/src/continuedev/plugins/context_providers/search.py new file mode 100644 index 00000000..17f2660c --- /dev/null +++ b/continuedev/src/continuedev/plugins/context_providers/search.py @@ -0,0 +1,92 @@ +import os +from typing import List +from ripgrepy import Ripgrepy + +from .util import remove_meilisearch_disallowed_chars +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from ...core.context import ContextProvider + + +class SearchContextProvider(ContextProvider): + title = "search" + + SEARCH_CONTEXT_ITEM_ID = "search" + + workspace_dir: str = None + + @property + def BASE_CONTEXT_ITEM(self): + return ContextItem( + content="", + description=ContextItemDescription( + name="Search", + description="Search the workspace for all matches of an exact string (e.g. '@search console.log')", + id=ContextItemId( + provider_title=self.title, + item_id=self.SEARCH_CONTEXT_ITEM_ID + ) + ) + ) + + def _get_rg_path(self): + if os.name == 'nt': + rg_path = f"C:\\Users\\{os.getlogin()}\\AppData\\Local\\Programs\\Microsoft VS Code\\resources\\app\\node_modules.asar.unpacked\\vscode-ripgrep\\bin\\rg.exe" + elif os.name == 'posix': + if 'darwin' in os.sys.platform: + rg_path = "/Applications/Visual Studio Code.app/Contents/Resources/app/node_modules.asar.unpacked/vscode-ripgrep/bin/rg" + else: + rg_path = "/usr/share/code/resources/app/node_modules.asar.unpacked/vscode-ripgrep/bin/rg" + else: + rg_path = "rg" + + if not os.path.exists(rg_path): + rg_path = "rg" + + return rg_path + + async def _search(self, query: str) -> str: + rg = Ripgrepy(query, self.workspace_dir, rg_path=self._get_rg_path()) + results = rg.I().context(2).run() + return f"Search results in workspace for '{query}':\n\n{results}" + + # Custom display below - TODO + + # Gather results per file + file_to_matches = {} + for result in results: + if result["type"] == "match": + data = result["data"] + filepath = data["path"]["text"] + if filepath not in file_to_matches: + file_to_matches[filepath] = [] + + line_num_and_line = f"{data['line_number']}: {data['lines']['text']}" + file_to_matches[filepath].append(line_num_and_line) + + # Format results + content = f"Search results in workspace for '{query}':\n\n" + for filepath, matches in file_to_matches.items(): + content += f"{filepath}\n" + for match in matches: + content += f"{match}\n" + content += "\n" + + return content + + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: + self.workspace_dir = workspace_dir + return [self.BASE_CONTEXT_ITEM] + + async def get_item(self, id: ContextItemId, query: str) -> ContextItem: + if not id.item_id == self.SEARCH_CONTEXT_ITEM_ID: + raise Exception("Invalid item id") + + query = query.lstrip("search ") + results = await self._search(query) + + ctx_item = self.BASE_CONTEXT_ITEM.copy() + ctx_item.content = results + ctx_item.description.name = f"Search: '{query}'" + ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars( + query) + return ctx_item diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 7c89c5c2..4470999a 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -2,7 +2,7 @@ import asyncio import json from fastapi import Depends, Header, WebSocket, APIRouter from starlette.websockets import WebSocketState, WebSocketDisconnect -from typing import Any, List, Type, TypeVar +from typing import Any, List, Optional, Type, TypeVar from pydantic import BaseModel import traceback from uvicorn.main import Server @@ -99,6 +99,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.on_show_logs_at_index(data["index"]) elif message_type == "select_context_item": self.select_context_item(data["id"], data["query"]) + elif message_type == "load_session": + self.load_session(data.get("session_id", None)) def on_main_input(self, input: str): # Do something with user input @@ -154,6 +156,14 @@ class GUIProtocolServer(AbstractGUIProtocolServer): create_async_task( self.session.autopilot.select_context_item(id, query), self.on_error) + def load_session(self, session_id: Optional[str] = None): + async def load_and_tell_to_reconnect(): + new_session_id = await session_manager.load_session(self.session.session_id, session_id) + await self._send_json("reconnect_at_session", {"session_id": new_session_id}) + + create_async_task( + load_and_tell_to_reconnect(), self.on_error) + @router.websocket("/ws") async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index f8dfb009..f0a3f094 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -11,13 +11,14 @@ import argparse from .ide import router as ide_router from .gui import router as gui_router -from .session_manager import session_manager +from .session_manager import session_manager, router as sessions_router from ..libs.util.logging import logger app = FastAPI() app.include_router(ide_router) app.include_router(gui_router) +app.include_router(sessions_router) # Add CORS support app.add_middleware( diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 56c92307..cde0344e 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -1,21 +1,23 @@ import os import traceback -from fastapi import WebSocket -from typing import Any, Coroutine, Dict, Union +from fastapi import WebSocket, APIRouter +from typing import Any, Coroutine, Dict, Optional, Union from uuid import uuid4 import json from fastapi.websockets import WebSocketState from ..plugins.steps.core.core import MessageStep -from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath -from ..core.main import FullState, HistoryNode +from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath, getSessionsListFilePath +from ..core.main import FullState, HistoryNode, SessionInfo from ..core.autopilot import Autopilot from .ide_protocol import AbstractIdeProtocolServer from ..libs.util.create_async_task import create_async_task from ..libs.util.errors import SessionNotFound from ..libs.util.logging import logger +router = APIRouter(prefix="/sessions", tags=["sessions"]) + class Session: session_id: str @@ -47,7 +49,7 @@ class SessionManager: raise KeyError("Session ID not recognized", session_id) return self.sessions[session_id] - async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: + async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Optional[str] = None) -> Session: logger.debug(f"New session: {session_id}") # Load the persisted state (not being used right now) @@ -74,20 +76,9 @@ class SessionManager: # Start the autopilot (must be after session is added to sessions) and the policy try: - await autopilot.start() + await autopilot.start(full_state=full_state) except Exception as e: - # Have to manually add to history because autopilot isn't started - formatted_err = '\n'.join(traceback.format_exception(e)) - msg_step = MessageStep( - name="Error loading context manager", message=formatted_err) - msg_step.description = f"```\n{formatted_err}\n```" - autopilot.history.add_node(HistoryNode( - step=msg_step, - observation=None, - depth=0, - active=False - )) - logger.warning(f"Error loading context manager: {e}") + await self.on_error(e) def on_error(e: Exception) -> Coroutine: err_msg = '\n'.join(traceback.format_exception(e)) @@ -99,7 +90,7 @@ class SessionManager: async def remove_session(self, session_id: str): logger.debug(f"Removing session: {session_id}") if session_id in self.sessions: - if session_id in self.registered_ides: + if session_id in self.registered_ides and self.registered_ides[session_id] is not None: ws_to_close = self.registered_ides[session_id].websocket if ws_to_close is not None and ws_to_close.client_state != WebSocketState.DISCONNECTED: await self.sessions[session_id].autopilot.ide.websocket.close() @@ -109,9 +100,37 @@ class SessionManager: async def persist_session(self, session_id: str): """Save the session's FullState as a json file""" full_state = await self.sessions[session_id].autopilot.get_full_state() + if full_state.session_info is None: + return + with open(getSessionFilePath(session_id), "w") as f: json.dump(full_state.dict(), f) + # Read and update the sessions list + with open(getSessionsListFilePath(), "r") as f: + sessions_list = json.load(f) + + session_ids = [s["session_id"] for s in sessions_list] + if session_id not in session_ids: + sessions_list.append(full_state.session_info.dict()) + + with open(getSessionsListFilePath(), "w") as f: + json.dump(sessions_list, f) + + async def load_session(self, old_session_id: str, new_session_id: Optional[str] = None) -> str: + """Load the session's FullState from a json file""" + + # First persist the current state + await self.persist_session(old_session_id) + + # Delete the old session, but keep the IDE + ide = self.registered_ides[old_session_id] + del self.registered_ides[old_session_id] + + # Start the new session + new_session = await self.new_session(ide, session_id=new_session_id) + return new_session.session_id + def register_websocket(self, session_id: str, ws: WebSocket): self.sessions[session_id].ws = ws logger.debug(f"Registered websocket for session {session_id}") @@ -130,3 +149,15 @@ class SessionManager: session_manager = SessionManager() + + +@router.get("/list") +async def list_sessions(): + """List all sessions""" + sessions_list_file = getSessionsListFilePath() + if not os.path.exists(sessions_list_file): + print("Returning empty sessions list") + return [] + sessions = json.load(open(sessions_list_file, "r")) + print("Returning sessions list: ", sessions) + return sessions |