diff options
Diffstat (limited to 'continuedev/src')
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 214 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 159 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 112 |
3 files changed, 331 insertions, 154 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 2c58c6f4..f7808335 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -1,29 +1,44 @@ -from functools import cached_property -import traceback import time -from typing import Callable, Coroutine, Dict, List, Optional, Union +import traceback +from functools import cached_property +from typing import Callable, Coroutine, Dict, List, Optional + from aiohttp import ClientPayloadError +from openai import error as openai_errors from pydantic import root_validator +from ..libs.util.create_async_task import create_async_task +from ..libs.util.logging import logger +from ..libs.util.queue import AsyncSubscriptionQueue from ..libs.util.strings import remove_quotes_and_escapes +from ..libs.util.telemetry import posthog_logger +from ..libs.util.traceback_parsers import get_javascript_traceback, get_python_traceback from ..models.filesystem import RangeInFileWithContents from ..models.filesystem_edit import FileEditWithFullContents -from .observation import Observation, InternalErrorObservation -from .context import ContextManager -from ..plugins.policies.default import DefaultPolicy +from ..models.main import ContinueBaseModel from ..plugins.context_providers.file import FileContextProvider from ..plugins.context_providers.highlighted_code import HighlightedCodeContextProvider +from ..plugins.policies.default import DefaultPolicy +from ..plugins.steps.core.core import ( + DisplayErrorStep, + ManualEditStep, + ReversibleStep, + UserInputStep, +) from ..server.ide_protocol import AbstractIdeProtocolServer -from ..libs.util.queue import AsyncSubscriptionQueue -from ..models.main import ContinueBaseModel -from .main import Context, ContinueCustomException, Policy, History, FullState, SessionInfo, Step, HistoryNode -from ..plugins.steps.core.core import DisplayErrorStep, ReversibleStep, ManualEditStep, UserInputStep +from .context import ContextManager +from .main import ( + Context, + ContinueCustomException, + FullState, + History, + HistoryNode, + Policy, + SessionInfo, + Step, +) +from .observation import InternalErrorObservation, Observation from .sdk import ContinueSDK -from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback -from openai import error as openai_errors -from ..libs.util.create_async_task import create_async_task -from ..libs.util.telemetry import posthog_logger -from ..libs.util.logging import logger def get_error_title(e: Exception) -> str: @@ -33,18 +48,23 @@ def get_error_title(e: Exception) -> str: return "This OpenAI API key has been rate limited. Please try again." elif isinstance(e, openai_errors.Timeout): return "OpenAI timed out. Please try again." - elif isinstance(e, openai_errors.InvalidRequestError) and e.code == "context_length_exceeded": + elif ( + isinstance(e, openai_errors.InvalidRequestError) + and e.code == "context_length_exceeded" + ): return e._message elif isinstance(e, ClientPayloadError): return "The request to OpenAI failed. Please try again." elif isinstance(e, openai_errors.APIConnectionError): - return "The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to \"\"" + return 'The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to ""' elif isinstance(e, openai_errors.InvalidRequestError): - return 'Invalid request sent to OpenAI. Please try again.' + return "Invalid request sent to OpenAI. Please try again." elif "rate_limit_ip_middleware" in e.__str__(): - return 'You have reached your limit for free usage of our token. You can continue using Continue by entering your own OpenAI API key in VS Code settings.' + return "You have reached your limit for free usage of our token. You can continue using Continue by entering your own OpenAI API key in VS Code settings." elif e.__str__().startswith("Cannot connect to host"): - return "The request failed. Please check your internet connection and try again." + return ( + "The request failed. Please check your internet connection and try again." + ) return e.__str__() or e.__repr__() @@ -78,10 +98,13 @@ class Autopilot(ContinueBaseModel): # Load documents into the search index logger.debug("Starting context manager") await self.context_manager.start( - self.continue_sdk.config.context_providers + [ + self.continue_sdk.config.context_providers + + [ HighlightedCodeContextProvider(ide=self.ide), - FileContextProvider(workspace_dir=self.ide.workspace_directory) - ], self.ide.workspace_directory) + FileContextProvider(workspace_dir=self.ide.workspace_directory), + ], + self.ide.workspace_directory, + ) if full_state is not None: self.history = full_state.history @@ -95,9 +118,9 @@ class Autopilot(ContinueBaseModel): @root_validator(pre=True) def fill_in_values(cls, values): - full_state: FullState = values.get('full_state') + full_state: FullState = values.get("full_state") if full_state is not None: - values['history'] = full_state.history + values["history"] = full_state.history return values async def get_full_state(self) -> FullState: @@ -107,18 +130,37 @@ class Autopilot(ContinueBaseModel): user_input_queue=self._main_user_input_queue, slash_commands=self.get_available_slash_commands(), adding_highlighted_code=self.context_manager.context_providers[ - "code"].adding_highlighted_code if "code" in self.context_manager.context_providers else False, - selected_context_items=await self.context_manager.get_selected_items() if self.context_manager is not None else [], - session_info=self.session_info + "code" + ].adding_highlighted_code + if "code" in self.context_manager.context_providers + else False, + selected_context_items=await self.context_manager.get_selected_items() + if self.context_manager is not None + else [], + session_info=self.session_info, ) self.full_state = full_state return full_state def get_available_slash_commands(self) -> List[Dict]: - custom_commands = list(map(lambda x: { - "name": x.name, "description": x.description}, self.continue_sdk.config.custom_commands)) or [] - slash_commands = list(map(lambda x: { - "name": x.name, "description": x.description}, self.continue_sdk.config.slash_commands)) or [] + custom_commands = ( + list( + map( + lambda x: {"name": x.name, "description": x.description}, + self.continue_sdk.config.custom_commands, + ) + ) + or [] + ) + slash_commands = ( + list( + map( + lambda x: {"name": x.name, "description": x.description}, + self.continue_sdk.config.slash_commands, + ) + ) + or [] + ) return custom_commands + slash_commands async def clear_history(self): @@ -182,13 +224,16 @@ class Autopilot(ContinueBaseModel): step = tb_step.step({"output": output, **tb_step.params}) await self._run_singular_step(step) - async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]): + async def handle_highlighted_code( + self, range_in_files: List[RangeInFileWithContents] + ): if "code" not in self.context_manager.context_providers: return # Add to context manager await self.context_manager.context_providers["code"].handle_highlighted_code( - range_in_files) + range_in_files + ) await self.update_subscribers() @@ -205,6 +250,23 @@ class Autopilot(ContinueBaseModel): await self.update_subscribers() + async def edit_step_at_index(self, user_input: str, index: int): + step_to_rerun = self.history.timeline[index].step.copy() + step_to_rerun.user_input = user_input + + # Halt the agent's currently running jobs (delete them) + while len(self.history.timeline) > index: + # Remove from timeline + node_to_delete = self.history.timeline.pop() + # Delete so it is stopped if in the middle of running + node_to_delete.deleted = True + + self.history.current_index = index - 1 + await self.update_subscribers() + + # Rerun from the current step + await self.run_from_step(step_to_rerun) + async def delete_context_with_ids(self, ids: List[str]): await self.context_manager.delete_context_with_ids(ids) await self.update_subscribers() @@ -213,7 +275,11 @@ class Autopilot(ContinueBaseModel): if "code" not in self.context_manager.context_providers: return - self.context_manager.context_providers["code"].adding_highlighted_code = not self.context_manager.context_providers["code"].adding_highlighted_code + self.context_manager.context_providers[ + "code" + ].adding_highlighted_code = not self.context_manager.context_providers[ + "code" + ].adding_highlighted_code await self.update_subscribers() async def set_editing_at_ids(self, ids: List[str]): @@ -223,7 +289,9 @@ class Autopilot(ContinueBaseModel): await self.context_manager.context_providers["code"].set_editing_at_ids(ids) await self.update_subscribers() - async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]: + async def _run_singular_step( + self, step: "Step", is_future_step: bool = False + ) -> Coroutine[Observation, None, None]: # Allow config to set disallowed steps if step.__class__.__name__ in self.continue_sdk.config.disallowed_steps: return None @@ -239,19 +307,22 @@ class Autopilot(ContinueBaseModel): # i -= 1 posthog_logger.capture_event( - 'step run', {'step_name': step.name, 'params': step.dict()}) + "step run", {"step_name": step.name, "params": step.dict()} + ) if not is_future_step: # Check manual edits buffer, clear out if needed by creating a ManualEditStep if len(self._manual_edits_buffer) > 0: manualEditsStep = ManualEditStep.from_sequence( - self._manual_edits_buffer) + self._manual_edits_buffer + ) self._manual_edits_buffer = [] await self._run_singular_step(manualEditsStep) # Update history - do this first so we get top-first tree ordering - index_of_history_node = self.history.add_node(HistoryNode( - step=step, observation=None, depth=self._step_depth)) + index_of_history_node = self.history.add_node( + HistoryNode(step=step, observation=None, depth=self._step_depth) + ) # Call all subscribed callbacks await self.update_subscribers() @@ -263,28 +334,43 @@ class Autopilot(ContinueBaseModel): try: observation = await step(self.continue_sdk) except Exception as e: - if index_of_history_node >= len(self.history.timeline) or self.history.timeline[index_of_history_node].deleted: + if ( + index_of_history_node >= len(self.history.timeline) + or self.history.timeline[index_of_history_node].deleted + ): # If step was deleted/cancelled, don't show error or allow retry return None caught_error = True is_continue_custom_exception = issubclass( - e.__class__, ContinueCustomException) - - error_string = e.message if is_continue_custom_exception else '\n'.join( - traceback.format_exception(e)) - error_title = e.title if is_continue_custom_exception else get_error_title( - e) + e.__class__, ContinueCustomException + ) + + error_string = ( + e.message + if is_continue_custom_exception + else "\n".join(traceback.format_exception(e)) + ) + error_title = ( + e.title if is_continue_custom_exception else get_error_title(e) + ) # Attach an InternalErrorObservation to the step and unhide it. - logger.error( - f"Error while running step: \n{error_string}\n{error_title}") - posthog_logger.capture_event('step error', { - 'error_message': error_string, 'error_title': error_title, 'step_name': step.name, 'params': step.dict()}) + logger.error(f"Error while running step: \n{error_string}\n{error_title}") + posthog_logger.capture_event( + "step error", + { + "error_message": error_string, + "error_title": error_title, + "step_name": step.name, + "params": step.dict(), + }, + ) observation = InternalErrorObservation( - error=error_string, title=error_title) + error=error_string, title=error_title + ) # Reveal this step, but hide all of the following steps (its substeps) step_was_hidden = step.hide @@ -331,8 +417,10 @@ class Autopilot(ContinueBaseModel): # Update subscribers with new description await self.update_subscribers() - create_async_task(update_description( - ), on_error=lambda e: self.continue_sdk.run_step(DisplayErrorStep(e=e))) + create_async_task( + update_description(), + on_error=lambda e: self.continue_sdk.run_step(DisplayErrorStep(e=e)), + ) return observation @@ -384,17 +472,22 @@ class Autopilot(ContinueBaseModel): # Use the first input to create title for session info, and make the session saveable if self.session_info is None: + async def create_title(): - title = await self.continue_sdk.models.medium.complete(f"Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: \"{user_input}\". The title is: ") + title = await self.continue_sdk.models.medium.complete( + f'Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: "{user_input}". The title is: ' + ) title = remove_quotes_and_escapes(title) self.session_info = SessionInfo( title=title, session_id=self.ide.session_id, - date_created=str(time.time()) + date_created=str(time.time()), ) - create_async_task(create_title(), on_error=lambda e: self.continue_sdk.run_step( - DisplayErrorStep(e=e))) + create_async_task( + create_title(), + on_error=lambda e: self.continue_sdk.run_step(DisplayErrorStep(e=e)), + ) if len(self._main_user_input_queue) > 1: return @@ -407,8 +500,9 @@ class Autopilot(ContinueBaseModel): await self.run_from_step(UserInputStep(user_input=user_input)) while len(self._main_user_input_queue) > 0: - await self.run_from_step(UserInputStep( - user_input=self._main_user_input_queue.pop(0))) + await self.run_from_step( + UserInputStep(user_input=self._main_user_input_queue.pop(0)) + ) async def accept_refinement_input(self, user_input: str, index: int): await self._request_halt() diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 024d5cea..778f81b3 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -1,23 +1,36 @@ +import os import traceback from typing import Coroutine, Union -import os -import importlib -from ..plugins.steps.core.core import DefaultModelEditCodeStep +from ..libs.llm import LLM +from ..libs.util.logging import logger +from ..libs.util.paths import getConfigFilePath +from ..libs.util.telemetry import posthog_logger +from ..models.filesystem import RangeInFile +from ..models.filesystem_edit import ( + AddDirectory, + AddFile, + DeleteDirectory, + DeleteFile, + FileEdit, + FileSystemEdit, +) from ..models.main import Range +from ..plugins.steps.core.core import * +from ..plugins.steps.core.core import DefaultModelEditCodeStep +from ..server.ide_protocol import AbstractIdeProtocolServer from .abstract_sdk import AbstractContinueSDK from .config import ContinueConfig -from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory -from ..models.filesystem import RangeInFile -from ..libs.llm import LLM -from .observation import Observation -from ..server.ide_protocol import AbstractIdeProtocolServer -from .main import Context, ContinueCustomException, History, HistoryNode, Step, ChatMessage -from ..plugins.steps.core.core import * -from ..libs.util.telemetry import posthog_logger -from ..libs.util.paths import getConfigFilePath +from .main import ( + ChatMessage, + Context, + ContinueCustomException, + History, + HistoryNode, + Step, +) from .models import Models -from ..libs.util.logging import logger +from .observation import Observation class Autopilot: @@ -26,6 +39,7 @@ class Autopilot: class ContinueSDK(AbstractContinueSDK): """The SDK provided as parameters to a step""" + ide: AbstractIdeProtocolServer models: Models context: Context @@ -46,30 +60,29 @@ class ContinueSDK(AbstractContinueSDK): config = sdk._load_config_dot_py() sdk.config = config except Exception as e: - logger.error( - f"Failed to load config.py: {traceback.format_exception(e)}") + logger.error(f"Failed to load config.py: {traceback.format_exception(e)}") - sdk.config = ContinueConfig( - ) if sdk._last_valid_config is None else sdk._last_valid_config + sdk.config = ( + ContinueConfig() + if sdk._last_valid_config is None + else sdk._last_valid_config + ) - formatted_err = '\n'.join(traceback.format_exception(e)) + formatted_err = "\n".join(traceback.format_exception(e)) msg_step = MessageStep( - name="Invalid Continue Config File", message=formatted_err) + name="Invalid Continue Config File", message=formatted_err + ) msg_step.description = f"Falling back to default config settings due to the following error in `~/.continue/config.py`.\n```\n{formatted_err}\n```\n\nIt's possible this was caused by an update to the Continue config format. If you'd like to see the new recommended default `config.py`, check [here](https://github.com/continuedev/continue/blob/main/continuedev/src/continuedev/libs/constants/default_config.py)." - sdk.history.add_node(HistoryNode( - step=msg_step, - observation=None, - depth=0, - active=False - )) + sdk.history.add_node( + HistoryNode(step=msg_step, observation=None, depth=0, active=False) + ) await sdk.ide.setFileOpen(getConfigFilePath()) sdk.models = sdk.config.models await sdk.models.start(sdk) # When the config is loaded, setup posthog logger - posthog_logger.setup( - sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry) + posthog_logger.setup(sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry) return sdk @@ -109,8 +122,14 @@ class ContinueSDK(AbstractContinueSDK): async def run_step(self, step: Step) -> Coroutine[Observation, None, None]: return await self.__autopilot._run_singular_step(step) - async def apply_filesystem_edit(self, edit: FileSystemEdit, name: str = None, description: str = None): - return await self.run_step(FileSystemEditStep(edit=edit, description=description, **({'name': name} if name else {}))) + async def apply_filesystem_edit( + self, edit: FileSystemEdit, name: str = None, description: str = None + ): + return await self.run_step( + FileSystemEditStep( + edit=edit, description=description, **({"name": name} if name else {}) + ) + ) async def wait_for_user_input(self) -> str: return await self.__autopilot.wait_for_user_input() @@ -118,22 +137,51 @@ class ContinueSDK(AbstractContinueSDK): async def wait_for_user_confirmation(self, prompt: str): return await self.run_step(WaitForUserConfirmationStep(prompt=prompt)) - async def run(self, commands: Union[List[str], str], cwd: str = None, name: str = None, description: str = None, handle_error: bool = True) -> Coroutine[str, None, None]: + async def run( + self, + commands: Union[List[str], str], + cwd: str = None, + name: str = None, + description: str = None, + handle_error: bool = True, + ) -> Coroutine[str, None, None]: commands = commands if isinstance(commands, List) else [commands] - return (await self.run_step(ShellCommandsStep(cmds=commands, cwd=cwd, description=description, handle_error=handle_error, **({'name': name} if name else {})))).text - - async def edit_file(self, filename: str, prompt: str, name: str = None, description: str = "", range: Range = None): + return ( + await self.run_step( + ShellCommandsStep( + cmds=commands, + cwd=cwd, + description=description, + handle_error=handle_error, + **({"name": name} if name else {}), + ) + ) + ).text + + async def edit_file( + self, + filename: str, + prompt: str, + name: str = None, + description: str = "", + range: Range = None, + ): filepath = await self._ensure_absolute_path(filename) await self.ide.setFileOpen(filepath) contents = await self.ide.readFile(filepath) - await self.run_step(DefaultModelEditCodeStep( - range_in_files=[RangeInFile(filepath=filepath, range=range) if range is not None else RangeInFile.from_entire_file( - filepath, contents)], - user_input=prompt, - description=description, - **({'name': name} if name else {}) - )) + await self.run_step( + DefaultModelEditCodeStep( + range_in_files=[ + RangeInFile(filepath=filepath, range=range) + if range is not None + else RangeInFile.from_entire_file(filepath, contents) + ], + user_input=prompt, + description=description, + **({"name": name} if name else {}), + ) + ) async def append_to_file(self, filename: str, content: str): filepath = await self._ensure_absolute_path(filename) @@ -145,11 +193,15 @@ class ContinueSDK(AbstractContinueSDK): filepath = await self._ensure_absolute_path(filename) dir_name = os.path.dirname(filepath) os.makedirs(dir_name, exist_ok=True) - return await self.run_step(FileSystemEditStep(edit=AddFile(filepath=filepath, content=content))) + return await self.run_step( + FileSystemEditStep(edit=AddFile(filepath=filepath, content=content)) + ) async def delete_file(self, filename: str): filename = await self._ensure_absolute_path(filename) - return await self.run_step(FileSystemEditStep(edit=DeleteFile(filepath=filename))) + return await self.run_step( + FileSystemEditStep(edit=DeleteFile(filepath=filename)) + ) async def add_directory(self, path: str): path = await self._ensure_absolute_path(path) @@ -170,6 +222,7 @@ class ContinueSDK(AbstractContinueSDK): path = getConfigFilePath() import importlib.util + spec = importlib.util.spec_from_file_location("config", path) config = importlib.util.module_from_spec(spec) spec.loader.exec_module(config) @@ -177,24 +230,34 @@ class ContinueSDK(AbstractContinueSDK): return config.config - def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]: + def get_code_context( + self, only_editing: bool = False + ) -> List[RangeInFileWithContents]: highlighted_ranges = self.__autopilot.context_manager.context_providers[ - "code"].highlighted_ranges - context = list(filter(lambda x: x.item.editing, highlighted_ranges) - ) if only_editing else highlighted_ranges + "code" + ].highlighted_ranges + context = ( + list(filter(lambda x: x.item.editing, highlighted_ranges)) + if only_editing + else highlighted_ranges + ) return [c.rif for c in context] def set_loading_message(self, message: str): # self.__autopilot.set_loading_message(message) raise NotImplementedError() - def raise_exception(self, message: str, title: str, with_step: Union[Step, None] = None): + def raise_exception( + self, message: str, title: str, with_step: Union[Step, None] = None + ): raise ContinueCustomException(message, title, with_step) async def get_chat_context(self) -> List[ChatMessage]: history_context = self.history.to_chat_history() - context_messages: List[ChatMessage] = await self.__autopilot.context_manager.get_chat_messages() + context_messages: List[ + ChatMessage + ] = await self.__autopilot.context_manager.get_chat_messages() # Insert at the end, but don't insert after latest user message or function call for msg in context_messages: diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 49d46be3..7497e777 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -1,19 +1,20 @@ import asyncio import json -from fastapi import Depends, Header, WebSocket, APIRouter -from starlette.websockets import WebSocketState, WebSocketDisconnect +import traceback from typing import Any, List, Optional, Type, TypeVar + +from fastapi import APIRouter, Depends, WebSocket from pydantic import BaseModel -import traceback +from starlette.websockets import WebSocketDisconnect, WebSocketState from uvicorn.main import Server -from .session_manager import session_manager, Session -from ..plugins.steps.core.core import DisplayErrorStep, MessageStep -from .gui_protocol import AbstractGUIProtocolServer -from ..libs.util.queue import AsyncSubscriptionQueue -from ..libs.util.telemetry import posthog_logger from ..libs.util.create_async_task import create_async_task from ..libs.util.logging import logger +from ..libs.util.queue import AsyncSubscriptionQueue +from ..libs.util.telemetry import posthog_logger +from ..plugins.steps.core.core import DisplayErrorStep +from .gui_protocol import AbstractGUIProtocolServer +from .session_manager import Session, session_manager router = APIRouter(prefix="/gui", tags=["gui"]) @@ -54,19 +55,19 @@ class GUIProtocolServer(AbstractGUIProtocolServer): async def _send_json(self, message_type: str, data: Any): if self.websocket.application_state == WebSocketState.DISCONNECTED: return - await self.websocket.send_json({ - "messageType": message_type, - "data": data - }) + await self.websocket.send_json({"messageType": message_type, "data": data}) async def _receive_json(self, message_type: str, timeout: int = 20) -> Any: try: - return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout) + return await asyncio.wait_for( + self.sub_queue.get(message_type), timeout=timeout + ) except asyncio.TimeoutError: - raise Exception( - "GUI Protocol _receive_json timed out after 20 seconds") + raise Exception("GUI Protocol _receive_json timed out after 20 seconds") - async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: + async def _send_and_receive_json( + self, data: Any, resp_model: Type[T], message_type: str + ) -> T: await self._send_json(message_type, data) resp = await self._receive_json(message_type) return resp_model.parse_obj(resp) @@ -101,78 +102,95 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.select_context_item(data["id"], data["query"]) elif message_type == "load_session": self.load_session(data.get("session_id", None)) + elif message_type == "edit_step_at_index": + self.edit_step_at_index(data.get("user_input", ""), data["index"]) def on_main_input(self, input: str): # Do something with user input create_async_task( - self.session.autopilot.accept_user_input(input), self.on_error) + self.session.autopilot.accept_user_input(input), self.on_error + ) def on_reverse_to_index(self, index: int): # Reverse the history to the given index - create_async_task( - self.session.autopilot.reverse_to_index(index), self.on_error) + create_async_task(self.session.autopilot.reverse_to_index(index), self.on_error) def on_step_user_input(self, input: str, index: int): create_async_task( - self.session.autopilot.give_user_input(input, index), self.on_error) + self.session.autopilot.give_user_input(input, index), self.on_error + ) def on_refinement_input(self, input: str, index: int): create_async_task( - self.session.autopilot.accept_refinement_input(input, index), self.on_error) + self.session.autopilot.accept_refinement_input(input, index), self.on_error + ) def on_retry_at_index(self, index: int): - create_async_task( - self.session.autopilot.retry_at_index(index), self.on_error) + create_async_task(self.session.autopilot.retry_at_index(index), self.on_error) def on_clear_history(self): - create_async_task( - self.session.autopilot.clear_history(), self.on_error) + create_async_task(self.session.autopilot.clear_history(), self.on_error) def on_delete_at_index(self, index: int): + create_async_task(self.session.autopilot.delete_at_index(index), self.on_error) + + def edit_step_at_index(self, user_input: str, index: int): create_async_task( - self.session.autopilot.delete_at_index(index), self.on_error) + self.session.autopilot.edit_step_at_index(user_input, index), + self.on_error, + ) def on_delete_context_with_ids(self, ids: List[str]): create_async_task( - self.session.autopilot.delete_context_with_ids(ids), self.on_error) + self.session.autopilot.delete_context_with_ids(ids), self.on_error + ) def on_toggle_adding_highlighted_code(self): create_async_task( - self.session.autopilot.toggle_adding_highlighted_code(), self.on_error) + self.session.autopilot.toggle_adding_highlighted_code(), self.on_error + ) posthog_logger.capture_event("toggle_adding_highlighted_code", {}) def on_set_editing_at_ids(self, ids: List[str]): - create_async_task( - self.session.autopilot.set_editing_at_ids(ids), self.on_error) + create_async_task(self.session.autopilot.set_editing_at_ids(ids), self.on_error) def on_show_logs_at_index(self, index: int): - name = f"continue_logs.txt" + name = "continue_logs.txt" logs = "\n\n############################################\n\n".join( - ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs) + [ + "This is a log of the exact prompt/completion pairs sent/received from the LLM during this step" + ] + + self.session.autopilot.continue_sdk.history.timeline[index].logs + ) create_async_task( - self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error) + self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error + ) posthog_logger.capture_event("show_logs_at_index", {}) def select_context_item(self, id: str, query: str): """Called when user selects an item from the dropdown""" create_async_task( - self.session.autopilot.select_context_item(id, query), self.on_error) + self.session.autopilot.select_context_item(id, query), self.on_error + ) def load_session(self, session_id: Optional[str] = None): async def load_and_tell_to_reconnect(): - new_session_id = await session_manager.load_session(self.session.session_id, session_id) - await self._send_json("reconnect_at_session", {"session_id": new_session_id}) + new_session_id = await session_manager.load_session( + self.session.session_id, session_id + ) + await self._send_json( + "reconnect_at_session", {"session_id": new_session_id} + ) - create_async_task( - load_and_tell_to_reconnect(), self.on_error) + create_async_task(load_and_tell_to_reconnect(), self.on_error) - posthog_logger.capture_event("load_session", { - "session_id": session_id - }) + posthog_logger.capture_event("load_session", {"session_id": session_id}) @router.websocket("/ws") -async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): +async def websocket_endpoint( + websocket: WebSocket, session: Session = Depends(websocket_session) +): try: logger.debug(f"Received websocket connection at url: {websocket.url}") await websocket.accept() @@ -197,14 +215,16 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we data = message["data"] protocol.handle_json(message_type, data) - except WebSocketDisconnect as e: + except WebSocketDisconnect: logger.debug("GUI websocket disconnected") except Exception as e: # Log, send to PostHog, and send to GUI logger.debug(f"ERROR in gui websocket: {e}") - err_msg = '\n'.join(traceback.format_exception(e)) - posthog_logger.capture_event("gui_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": err_msg}) + err_msg = "\n".join(traceback.format_exception(e)) + posthog_logger.capture_event( + "gui_error", + {"error_title": e.__str__() or e.__repr__(), "error_message": err_msg}, + ) await session.autopilot.ide.showMessage(err_msg) |