diff options
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 6 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 2 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/util/create_async_task.py | 8 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/core/core.py | 33 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/search_directory.py | 2 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 100 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide.py | 20 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/session_manager.py | 4 | ||||
-rw-r--r-- | extension/package-lock.json | 4 | ||||
-rw-r--r-- | extension/package.json | 2 |
10 files changed, 108 insertions, 73 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 931cfb75..3f25e64e 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -15,7 +15,7 @@ from ..server.ide_protocol import AbstractIdeProtocolServer from ..libs.util.queue import AsyncSubscriptionQueue from ..models.main import ContinueBaseModel from .main import Context, ContinueCustomException, Policy, History, FullState, Step, HistoryNode -from ..plugins.steps.core.core import ReversibleStep, ManualEditStep, UserInputStep +from ..plugins.steps.core.core import DisplayErrorStep, ReversibleStep, ManualEditStep, UserInputStep from .sdk import ContinueSDK from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback from openai import error as openai_errors @@ -312,8 +312,8 @@ class Autopilot(ContinueBaseModel): # Update subscribers with new description await self.update_subscribers() - create_async_task(update_description(), - self.continue_sdk.ide.unique_id) + create_async_task(update_description( + ), on_error=lambda e: self.continue_sdk.run_step(DisplayErrorStep(e=e))) return observation diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 5bb88b92..4b76a121 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -1,6 +1,6 @@ from functools import cached_property import traceback -from typing import Coroutine, Dict, Union +from typing import Coroutine, Dict, Literal, Union import os from ..plugins.steps.core.core import DefaultModelEditCodeStep diff --git a/continuedev/src/continuedev/libs/util/create_async_task.py b/continuedev/src/continuedev/libs/util/create_async_task.py index 00e87445..4c6d3c95 100644 --- a/continuedev/src/continuedev/libs/util/create_async_task.py +++ b/continuedev/src/continuedev/libs/util/create_async_task.py @@ -1,4 +1,4 @@ -from typing import Coroutine, Union +from typing import Callable, Coroutine, Optional, Union import traceback from .telemetry import posthog_logger from .logging import logger @@ -7,7 +7,7 @@ import nest_asyncio nest_asyncio.apply() -def create_async_task(coro: Coroutine, unique_id: Union[str, None] = None): +def create_async_task(coro: Coroutine, on_error: Optional[Callable[[Exception], Coroutine]] = None): """asyncio.create_task and log errors by adding a callback""" task = asyncio.create_task(coro) @@ -22,5 +22,9 @@ def create_async_task(coro: Coroutine, unique_id: Union[str, None] = None): "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e)) }) + # Log the error to the GUI + if on_error is not None: + asyncio.create_task(on_error(e)) + task.add_done_callback(callback) return task diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index 5a81e5ee..c80cecc3 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -1,20 +1,22 @@ # These steps are depended upon by ContinueSDK import os -import subprocess +import json import difflib from textwrap import dedent -from typing import Coroutine, List, Literal, Union +import traceback +from typing import Any, Coroutine, List, Union +import difflib + +from pydantic import validator from ....libs.llm.ggml import GGML from ....models.main import Range -from ....libs.llm.prompt_utils import MarkdownStyleEncoderDecoder from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents -from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation -from ....core.main import ChatMessage, ContinueCustomException, Step, SequentialStep +from ....core.observation import Observation, TextObservation, UserInputObservation +from ....core.main import ChatMessage, ContinueCustomException, Step from ....libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS from ....libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes -import difflib class ContinueSDK: @@ -41,6 +43,25 @@ class MessageStep(Step): return TextObservation(text=self.message) +class DisplayErrorStep(Step): + name: str = "Error in the Continue server" + e: Any + + class Config: + arbitrary_types_allowed = True + + @validator("e", pre=True, always=True) + def validate_e(cls, v): + if isinstance(v, Exception): + return '\n'.join(traceback.format_exception(v)) + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + return self.e + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + raise ContinueCustomException(message=self.e, title=self.name) + + class FileSystemEditStep(ReversibleStep): edit: FileSystemEdit _diff: Union[EditDiff, None] = None diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py index 7d02d6fa..07b50473 100644 --- a/continuedev/src/continuedev/plugins/steps/search_directory.py +++ b/continuedev/src/continuedev/plugins/steps/search_directory.py @@ -65,5 +65,5 @@ class EditAllMatchesStep(Step): range=range_in_file.range, filename=range_in_file.filepath, prompt=self.user_request - ), sdk.ide.unique_id) for range_in_file in range_in_files] + )) for range_in_file in range_in_files] await asyncio.gather(*tasks) diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 2adb680e..98a5aea0 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -8,6 +8,7 @@ import traceback from uvicorn.main import Server from .session_manager import session_manager, Session +from ..plugins.steps.core.core import DisplayErrorStep, MessageStep from .gui_protocol import AbstractGUIProtocolServer from ..libs.util.queue import AsyncSubscriptionQueue from ..libs.util.telemetry import posthog_logger @@ -70,94 +71,88 @@ class GUIProtocolServer(AbstractGUIProtocolServer): resp = await self._receive_json(message_type) return resp_model.parse_obj(resp) + def on_error(self, e: Exception): + return self.session.autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) + def handle_json(self, message_type: str, data: Any): - try: - if message_type == "main_input": - self.on_main_input(data["input"]) - elif message_type == "step_user_input": - self.on_step_user_input(data["input"], data["index"]) - elif message_type == "refinement_input": - self.on_refinement_input(data["input"], data["index"]) - elif message_type == "reverse_to_index": - self.on_reverse_to_index(data["index"]) - elif message_type == "retry_at_index": - self.on_retry_at_index(data["index"]) - elif message_type == "clear_history": - self.on_clear_history() - elif message_type == "delete_at_index": - self.on_delete_at_index(data["index"]) - elif message_type == "delete_context_with_ids": - self.on_delete_context_with_ids(data["ids"]) - elif message_type == "toggle_adding_highlighted_code": - self.on_toggle_adding_highlighted_code() - elif message_type == "set_editing_at_indices": - self.on_set_editing_at_indices(data["indices"]) - elif message_type == "show_logs_at_index": - self.on_show_logs_at_index(data["index"]) - elif message_type == "select_context_item": - self.select_context_item(data["id"], data["query"]) - except Exception as e: - logger.debug(e) + if message_type == "main_input": + self.on_main_input(data["input"]) + elif message_type == "step_user_input": + self.on_step_user_input(data["input"], data["index"]) + elif message_type == "refinement_input": + self.on_refinement_input(data["input"], data["index"]) + elif message_type == "reverse_to_index": + self.on_reverse_to_index(data["index"]) + elif message_type == "retry_at_index": + self.on_retry_at_index(data["index"]) + elif message_type == "clear_history": + self.on_clear_history() + elif message_type == "delete_at_index": + self.on_delete_at_index(data["index"]) + elif message_type == "delete_context_with_ids": + self.on_delete_context_with_ids(data["ids"]) + elif message_type == "toggle_adding_highlighted_code": + self.on_toggle_adding_highlighted_code() + elif message_type == "set_editing_at_indices": + self.on_set_editing_at_indices(data["indices"]) + elif message_type == "show_logs_at_index": + self.on_show_logs_at_index(data["index"]) + elif message_type == "select_context_item": + self.select_context_item(data["id"], data["query"]) def on_main_input(self, input: str): # Do something with user input - create_async_task(self.session.autopilot.accept_user_input( - input), self.session.autopilot.continue_sdk.ide.unique_id) + create_async_task( + self.session.autopilot.accept_user_input(input), self.on_error) def on_reverse_to_index(self, index: int): # Reverse the history to the given index - create_async_task(self.session.autopilot.reverse_to_index( - index), self.session.autopilot.continue_sdk.ide.unique_id) + create_async_task( + self.session.autopilot.reverse_to_index(index), self.on_error) def on_step_user_input(self, input: str, index: int): create_async_task( - self.session.autopilot.give_user_input(input, index), self.session.autopilot.continue_sdk.ide.unique_id) + self.session.autopilot.give_user_input(input, index), self.on_error) def on_refinement_input(self, input: str, index: int): create_async_task( - self.session.autopilot.accept_refinement_input(input, index), self.session.autopilot.continue_sdk.ide.unique_id) + self.session.autopilot.accept_refinement_input(input, index), self.on_error) def on_retry_at_index(self, index: int): create_async_task( - self.session.autopilot.retry_at_index(index), self.session.autopilot.continue_sdk.ide.unique_id) + self.session.autopilot.retry_at_index(index), self.on_error) def on_clear_history(self): - create_async_task(self.session.autopilot.clear_history( - ), self.session.autopilot.continue_sdk.ide.unique_id) + create_async_task( + self.session.autopilot.clear_history(), self.on_error) def on_delete_at_index(self, index: int): - create_async_task(self.session.autopilot.delete_at_index( - index), self.session.autopilot.continue_sdk.ide.unique_id) + create_async_task( + self.session.autopilot.delete_at_index(index), self.on_error) def on_delete_context_with_ids(self, ids: List[str]): create_async_task( - self.session.autopilot.delete_context_with_ids( - ids), self.session.autopilot.continue_sdk.ide.unique_id - ) + self.session.autopilot.delete_context_with_ids(ids), self.on_error) def on_toggle_adding_highlighted_code(self): create_async_task( - self.session.autopilot.toggle_adding_highlighted_code( - ), self.session.autopilot.continue_sdk.ide.unique_id - ) + self.session.autopilot.toggle_adding_highlighted_code(), self.on_error) def on_set_editing_at_indices(self, indices: List[int]): create_async_task( - self.session.autopilot.set_editing_at_indices( - indices), self.session.autopilot.continue_sdk.ide.unique_id - ) + self.session.autopilot.set_editing_at_indices(indices), self.on_error) def on_show_logs_at_index(self, index: int): name = f"continue_logs.txt" logs = "\n\n############################################\n\n".join( ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs) create_async_task( - self.session.autopilot.ide.showVirtualFile(name, logs), self.session.autopilot.continue_sdk.ide.unique_id) + self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error) def select_context_item(self, id: str, query: str): """Called when user selects an item from the dropdown""" create_async_task( - self.session.autopilot.select_context_item(id, query), self.session.autopilot.continue_sdk.ide.unique_id) + self.session.autopilot.select_context_item(id, query), self.on_error) @router.websocket("/ws") @@ -189,9 +184,14 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we except WebSocketDisconnect as e: logger.debug("GUI websocket disconnected") except Exception as e: + # Log, send to PostHog, and send to GUI logger.debug(f"ERROR in gui websocket: {e}") + err_msg = '\n'.join(traceback.format_exception(e)) posthog_logger.capture_event("gui_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))}) + "error_title": e.__str__() or e.__repr__(), "error_message": err_msg}) + + await protocol.session.autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) + raise e finally: logger.debug("Closing gui websocket") diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 8a269cb7..e4c07029 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -10,6 +10,7 @@ from pydantic import BaseModel import traceback import asyncio +from ..plugins.steps.core.core import DisplayErrorStep, MessageStep from .meilisearch_server import start_meilisearch from ..libs.util.telemetry import posthog_logger from ..libs.util.queue import AsyncSubscriptionQueue @@ -279,6 +280,9 @@ class IdeProtocolServer(AbstractIdeProtocolServer): # This is where you might have triggers: plugins can subscribe to certian events # like file changes, tracebacks, etc... + def on_error(self, e: Exception): + return self.session_manager.sessions[self.session_id].autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) + def onAcceptRejectSuggestion(self, accepted: bool): posthog_logger.capture_event("accept_reject_suggestion", { "accepted": accepted @@ -311,22 +315,22 @@ class IdeProtocolServer(AbstractIdeProtocolServer): def onDeleteAtIndex(self, index: int): if autopilot := self.__get_autopilot(): - create_async_task(autopilot.delete_at_index(index), self.unique_id) + create_async_task(autopilot.delete_at_index(index), self.on_error) def onCommandOutput(self, output: str): if autopilot := self.__get_autopilot(): create_async_task( - autopilot.handle_command_output(output), self.unique_id) + autopilot.handle_command_output(output), self.on_error) def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]): if autopilot := self.__get_autopilot(): create_async_task(autopilot.handle_highlighted_code( - range_in_files), self.unique_id) + range_in_files), self.on_error) def onMainUserInput(self, input: str): if autopilot := self.__get_autopilot(): create_async_task( - autopilot.accept_user_input(input), self.unique_id) + autopilot.accept_user_input(input), self.on_error) # Request information. Session doesn't matter. async def getOpenFiles(self) -> List[str]: @@ -459,7 +463,7 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None): logger.debug(f"Received IDE message: {message_type}") create_async_task( - ideProtocolServer.handle_json(message_type, data)) + ideProtocolServer.handle_json(message_type, data), ideProtocolServer.on_error) ideProtocolServer = IdeProtocolServer(session_manager, websocket) if session_id is not None: @@ -480,8 +484,12 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None): logger.debug("IDE wbsocket disconnected") except Exception as e: logger.debug(f"Error in ide websocket: {e}") + err_msg = '\n'.join(traceback.format_exception(e)) posthog_logger.capture_event("gui_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))}) + "error_title": e.__str__() or e.__repr__(), "error_message": err_msg}) + + await session_manager.sessions[session_id].autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) + raise e finally: logger.debug("Closing ide websocket") diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index d30411cd..cf46028f 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -6,6 +6,7 @@ import json from fastapi.websockets import WebSocketState +from ..plugins.steps.core.core import DisplayErrorStep from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath from ..models.filesystem_edit import FileEditWithFullContents from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER @@ -83,7 +84,8 @@ class SessionManager: }) autopilot.on_update(on_update) - create_async_task(autopilot.run_policy()) + create_async_task(autopilot.run_policy( + ), lambda e: autopilot.continue_sdk.run_step(DisplayErrorStep(e=e))) return session async def remove_session(self, session_id: str): diff --git a/extension/package-lock.json b/extension/package-lock.json index 30b9952c..18c80fed 100644 --- a/extension/package-lock.json +++ b/extension/package-lock.json @@ -1,12 +1,12 @@ { "name": "continue", - "version": "0.0.224", + "version": "0.0.225", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "continue", - "version": "0.0.224", + "version": "0.0.225", "license": "Apache-2.0", "dependencies": { "@electron/rebuild": "^3.2.10", diff --git a/extension/package.json b/extension/package.json index 9b9daef6..b2c94593 100644 --- a/extension/package.json +++ b/extension/package.json @@ -14,7 +14,7 @@ "displayName": "Continue", "pricing": "Free", "description": "The open-source coding autopilot", - "version": "0.0.224", + "version": "0.0.225", "publisher": "Continue", "engines": { "vscode": "^1.67.0" |