summaryrefslogtreecommitdiff
path: root/continuedev/src
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev/src')
-rw-r--r--continuedev/src/continuedev/core/autopilot.py214
-rw-r--r--continuedev/src/continuedev/core/sdk.py159
-rw-r--r--continuedev/src/continuedev/server/gui.py112
3 files changed, 331 insertions, 154 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 2c58c6f4..f7808335 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -1,29 +1,44 @@
-from functools import cached_property
-import traceback
import time
-from typing import Callable, Coroutine, Dict, List, Optional, Union
+import traceback
+from functools import cached_property
+from typing import Callable, Coroutine, Dict, List, Optional
+
from aiohttp import ClientPayloadError
+from openai import error as openai_errors
from pydantic import root_validator
+from ..libs.util.create_async_task import create_async_task
+from ..libs.util.logging import logger
+from ..libs.util.queue import AsyncSubscriptionQueue
from ..libs.util.strings import remove_quotes_and_escapes
+from ..libs.util.telemetry import posthog_logger
+from ..libs.util.traceback_parsers import get_javascript_traceback, get_python_traceback
from ..models.filesystem import RangeInFileWithContents
from ..models.filesystem_edit import FileEditWithFullContents
-from .observation import Observation, InternalErrorObservation
-from .context import ContextManager
-from ..plugins.policies.default import DefaultPolicy
+from ..models.main import ContinueBaseModel
from ..plugins.context_providers.file import FileContextProvider
from ..plugins.context_providers.highlighted_code import HighlightedCodeContextProvider
+from ..plugins.policies.default import DefaultPolicy
+from ..plugins.steps.core.core import (
+ DisplayErrorStep,
+ ManualEditStep,
+ ReversibleStep,
+ UserInputStep,
+)
from ..server.ide_protocol import AbstractIdeProtocolServer
-from ..libs.util.queue import AsyncSubscriptionQueue
-from ..models.main import ContinueBaseModel
-from .main import Context, ContinueCustomException, Policy, History, FullState, SessionInfo, Step, HistoryNode
-from ..plugins.steps.core.core import DisplayErrorStep, ReversibleStep, ManualEditStep, UserInputStep
+from .context import ContextManager
+from .main import (
+ Context,
+ ContinueCustomException,
+ FullState,
+ History,
+ HistoryNode,
+ Policy,
+ SessionInfo,
+ Step,
+)
+from .observation import InternalErrorObservation, Observation
from .sdk import ContinueSDK
-from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback
-from openai import error as openai_errors
-from ..libs.util.create_async_task import create_async_task
-from ..libs.util.telemetry import posthog_logger
-from ..libs.util.logging import logger
def get_error_title(e: Exception) -> str:
@@ -33,18 +48,23 @@ def get_error_title(e: Exception) -> str:
return "This OpenAI API key has been rate limited. Please try again."
elif isinstance(e, openai_errors.Timeout):
return "OpenAI timed out. Please try again."
- elif isinstance(e, openai_errors.InvalidRequestError) and e.code == "context_length_exceeded":
+ elif (
+ isinstance(e, openai_errors.InvalidRequestError)
+ and e.code == "context_length_exceeded"
+ ):
return e._message
elif isinstance(e, ClientPayloadError):
return "The request to OpenAI failed. Please try again."
elif isinstance(e, openai_errors.APIConnectionError):
- return "The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to \"\""
+ return 'The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to ""'
elif isinstance(e, openai_errors.InvalidRequestError):
- return 'Invalid request sent to OpenAI. Please try again.'
+ return "Invalid request sent to OpenAI. Please try again."
elif "rate_limit_ip_middleware" in e.__str__():
- return 'You have reached your limit for free usage of our token. You can continue using Continue by entering your own OpenAI API key in VS Code settings.'
+ return "You have reached your limit for free usage of our token. You can continue using Continue by entering your own OpenAI API key in VS Code settings."
elif e.__str__().startswith("Cannot connect to host"):
- return "The request failed. Please check your internet connection and try again."
+ return (
+ "The request failed. Please check your internet connection and try again."
+ )
return e.__str__() or e.__repr__()
@@ -78,10 +98,13 @@ class Autopilot(ContinueBaseModel):
# Load documents into the search index
logger.debug("Starting context manager")
await self.context_manager.start(
- self.continue_sdk.config.context_providers + [
+ self.continue_sdk.config.context_providers
+ + [
HighlightedCodeContextProvider(ide=self.ide),
- FileContextProvider(workspace_dir=self.ide.workspace_directory)
- ], self.ide.workspace_directory)
+ FileContextProvider(workspace_dir=self.ide.workspace_directory),
+ ],
+ self.ide.workspace_directory,
+ )
if full_state is not None:
self.history = full_state.history
@@ -95,9 +118,9 @@ class Autopilot(ContinueBaseModel):
@root_validator(pre=True)
def fill_in_values(cls, values):
- full_state: FullState = values.get('full_state')
+ full_state: FullState = values.get("full_state")
if full_state is not None:
- values['history'] = full_state.history
+ values["history"] = full_state.history
return values
async def get_full_state(self) -> FullState:
@@ -107,18 +130,37 @@ class Autopilot(ContinueBaseModel):
user_input_queue=self._main_user_input_queue,
slash_commands=self.get_available_slash_commands(),
adding_highlighted_code=self.context_manager.context_providers[
- "code"].adding_highlighted_code if "code" in self.context_manager.context_providers else False,
- selected_context_items=await self.context_manager.get_selected_items() if self.context_manager is not None else [],
- session_info=self.session_info
+ "code"
+ ].adding_highlighted_code
+ if "code" in self.context_manager.context_providers
+ else False,
+ selected_context_items=await self.context_manager.get_selected_items()
+ if self.context_manager is not None
+ else [],
+ session_info=self.session_info,
)
self.full_state = full_state
return full_state
def get_available_slash_commands(self) -> List[Dict]:
- custom_commands = list(map(lambda x: {
- "name": x.name, "description": x.description}, self.continue_sdk.config.custom_commands)) or []
- slash_commands = list(map(lambda x: {
- "name": x.name, "description": x.description}, self.continue_sdk.config.slash_commands)) or []
+ custom_commands = (
+ list(
+ map(
+ lambda x: {"name": x.name, "description": x.description},
+ self.continue_sdk.config.custom_commands,
+ )
+ )
+ or []
+ )
+ slash_commands = (
+ list(
+ map(
+ lambda x: {"name": x.name, "description": x.description},
+ self.continue_sdk.config.slash_commands,
+ )
+ )
+ or []
+ )
return custom_commands + slash_commands
async def clear_history(self):
@@ -182,13 +224,16 @@ class Autopilot(ContinueBaseModel):
step = tb_step.step({"output": output, **tb_step.params})
await self._run_singular_step(step)
- async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]):
+ async def handle_highlighted_code(
+ self, range_in_files: List[RangeInFileWithContents]
+ ):
if "code" not in self.context_manager.context_providers:
return
# Add to context manager
await self.context_manager.context_providers["code"].handle_highlighted_code(
- range_in_files)
+ range_in_files
+ )
await self.update_subscribers()
@@ -205,6 +250,23 @@ class Autopilot(ContinueBaseModel):
await self.update_subscribers()
+ async def edit_step_at_index(self, user_input: str, index: int):
+ step_to_rerun = self.history.timeline[index].step.copy()
+ step_to_rerun.user_input = user_input
+
+ # Halt the agent's currently running jobs (delete them)
+ while len(self.history.timeline) > index:
+ # Remove from timeline
+ node_to_delete = self.history.timeline.pop()
+ # Delete so it is stopped if in the middle of running
+ node_to_delete.deleted = True
+
+ self.history.current_index = index - 1
+ await self.update_subscribers()
+
+ # Rerun from the current step
+ await self.run_from_step(step_to_rerun)
+
async def delete_context_with_ids(self, ids: List[str]):
await self.context_manager.delete_context_with_ids(ids)
await self.update_subscribers()
@@ -213,7 +275,11 @@ class Autopilot(ContinueBaseModel):
if "code" not in self.context_manager.context_providers:
return
- self.context_manager.context_providers["code"].adding_highlighted_code = not self.context_manager.context_providers["code"].adding_highlighted_code
+ self.context_manager.context_providers[
+ "code"
+ ].adding_highlighted_code = not self.context_manager.context_providers[
+ "code"
+ ].adding_highlighted_code
await self.update_subscribers()
async def set_editing_at_ids(self, ids: List[str]):
@@ -223,7 +289,9 @@ class Autopilot(ContinueBaseModel):
await self.context_manager.context_providers["code"].set_editing_at_ids(ids)
await self.update_subscribers()
- async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]:
+ async def _run_singular_step(
+ self, step: "Step", is_future_step: bool = False
+ ) -> Coroutine[Observation, None, None]:
# Allow config to set disallowed steps
if step.__class__.__name__ in self.continue_sdk.config.disallowed_steps:
return None
@@ -239,19 +307,22 @@ class Autopilot(ContinueBaseModel):
# i -= 1
posthog_logger.capture_event(
- 'step run', {'step_name': step.name, 'params': step.dict()})
+ "step run", {"step_name": step.name, "params": step.dict()}
+ )
if not is_future_step:
# Check manual edits buffer, clear out if needed by creating a ManualEditStep
if len(self._manual_edits_buffer) > 0:
manualEditsStep = ManualEditStep.from_sequence(
- self._manual_edits_buffer)
+ self._manual_edits_buffer
+ )
self._manual_edits_buffer = []
await self._run_singular_step(manualEditsStep)
# Update history - do this first so we get top-first tree ordering
- index_of_history_node = self.history.add_node(HistoryNode(
- step=step, observation=None, depth=self._step_depth))
+ index_of_history_node = self.history.add_node(
+ HistoryNode(step=step, observation=None, depth=self._step_depth)
+ )
# Call all subscribed callbacks
await self.update_subscribers()
@@ -263,28 +334,43 @@ class Autopilot(ContinueBaseModel):
try:
observation = await step(self.continue_sdk)
except Exception as e:
- if index_of_history_node >= len(self.history.timeline) or self.history.timeline[index_of_history_node].deleted:
+ if (
+ index_of_history_node >= len(self.history.timeline)
+ or self.history.timeline[index_of_history_node].deleted
+ ):
# If step was deleted/cancelled, don't show error or allow retry
return None
caught_error = True
is_continue_custom_exception = issubclass(
- e.__class__, ContinueCustomException)
-
- error_string = e.message if is_continue_custom_exception else '\n'.join(
- traceback.format_exception(e))
- error_title = e.title if is_continue_custom_exception else get_error_title(
- e)
+ e.__class__, ContinueCustomException
+ )
+
+ error_string = (
+ e.message
+ if is_continue_custom_exception
+ else "\n".join(traceback.format_exception(e))
+ )
+ error_title = (
+ e.title if is_continue_custom_exception else get_error_title(e)
+ )
# Attach an InternalErrorObservation to the step and unhide it.
- logger.error(
- f"Error while running step: \n{error_string}\n{error_title}")
- posthog_logger.capture_event('step error', {
- 'error_message': error_string, 'error_title': error_title, 'step_name': step.name, 'params': step.dict()})
+ logger.error(f"Error while running step: \n{error_string}\n{error_title}")
+ posthog_logger.capture_event(
+ "step error",
+ {
+ "error_message": error_string,
+ "error_title": error_title,
+ "step_name": step.name,
+ "params": step.dict(),
+ },
+ )
observation = InternalErrorObservation(
- error=error_string, title=error_title)
+ error=error_string, title=error_title
+ )
# Reveal this step, but hide all of the following steps (its substeps)
step_was_hidden = step.hide
@@ -331,8 +417,10 @@ class Autopilot(ContinueBaseModel):
# Update subscribers with new description
await self.update_subscribers()
- create_async_task(update_description(
- ), on_error=lambda e: self.continue_sdk.run_step(DisplayErrorStep(e=e)))
+ create_async_task(
+ update_description(),
+ on_error=lambda e: self.continue_sdk.run_step(DisplayErrorStep(e=e)),
+ )
return observation
@@ -384,17 +472,22 @@ class Autopilot(ContinueBaseModel):
# Use the first input to create title for session info, and make the session saveable
if self.session_info is None:
+
async def create_title():
- title = await self.continue_sdk.models.medium.complete(f"Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: \"{user_input}\". The title is: ")
+ title = await self.continue_sdk.models.medium.complete(
+ f'Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: "{user_input}". The title is: '
+ )
title = remove_quotes_and_escapes(title)
self.session_info = SessionInfo(
title=title,
session_id=self.ide.session_id,
- date_created=str(time.time())
+ date_created=str(time.time()),
)
- create_async_task(create_title(), on_error=lambda e: self.continue_sdk.run_step(
- DisplayErrorStep(e=e)))
+ create_async_task(
+ create_title(),
+ on_error=lambda e: self.continue_sdk.run_step(DisplayErrorStep(e=e)),
+ )
if len(self._main_user_input_queue) > 1:
return
@@ -407,8 +500,9 @@ class Autopilot(ContinueBaseModel):
await self.run_from_step(UserInputStep(user_input=user_input))
while len(self._main_user_input_queue) > 0:
- await self.run_from_step(UserInputStep(
- user_input=self._main_user_input_queue.pop(0)))
+ await self.run_from_step(
+ UserInputStep(user_input=self._main_user_input_queue.pop(0))
+ )
async def accept_refinement_input(self, user_input: str, index: int):
await self._request_halt()
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 024d5cea..778f81b3 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -1,23 +1,36 @@
+import os
import traceback
from typing import Coroutine, Union
-import os
-import importlib
-from ..plugins.steps.core.core import DefaultModelEditCodeStep
+from ..libs.llm import LLM
+from ..libs.util.logging import logger
+from ..libs.util.paths import getConfigFilePath
+from ..libs.util.telemetry import posthog_logger
+from ..models.filesystem import RangeInFile
+from ..models.filesystem_edit import (
+ AddDirectory,
+ AddFile,
+ DeleteDirectory,
+ DeleteFile,
+ FileEdit,
+ FileSystemEdit,
+)
from ..models.main import Range
+from ..plugins.steps.core.core import *
+from ..plugins.steps.core.core import DefaultModelEditCodeStep
+from ..server.ide_protocol import AbstractIdeProtocolServer
from .abstract_sdk import AbstractContinueSDK
from .config import ContinueConfig
-from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory
-from ..models.filesystem import RangeInFile
-from ..libs.llm import LLM
-from .observation import Observation
-from ..server.ide_protocol import AbstractIdeProtocolServer
-from .main import Context, ContinueCustomException, History, HistoryNode, Step, ChatMessage
-from ..plugins.steps.core.core import *
-from ..libs.util.telemetry import posthog_logger
-from ..libs.util.paths import getConfigFilePath
+from .main import (
+ ChatMessage,
+ Context,
+ ContinueCustomException,
+ History,
+ HistoryNode,
+ Step,
+)
from .models import Models
-from ..libs.util.logging import logger
+from .observation import Observation
class Autopilot:
@@ -26,6 +39,7 @@ class Autopilot:
class ContinueSDK(AbstractContinueSDK):
"""The SDK provided as parameters to a step"""
+
ide: AbstractIdeProtocolServer
models: Models
context: Context
@@ -46,30 +60,29 @@ class ContinueSDK(AbstractContinueSDK):
config = sdk._load_config_dot_py()
sdk.config = config
except Exception as e:
- logger.error(
- f"Failed to load config.py: {traceback.format_exception(e)}")
+ logger.error(f"Failed to load config.py: {traceback.format_exception(e)}")
- sdk.config = ContinueConfig(
- ) if sdk._last_valid_config is None else sdk._last_valid_config
+ sdk.config = (
+ ContinueConfig()
+ if sdk._last_valid_config is None
+ else sdk._last_valid_config
+ )
- formatted_err = '\n'.join(traceback.format_exception(e))
+ formatted_err = "\n".join(traceback.format_exception(e))
msg_step = MessageStep(
- name="Invalid Continue Config File", message=formatted_err)
+ name="Invalid Continue Config File", message=formatted_err
+ )
msg_step.description = f"Falling back to default config settings due to the following error in `~/.continue/config.py`.\n```\n{formatted_err}\n```\n\nIt's possible this was caused by an update to the Continue config format. If you'd like to see the new recommended default `config.py`, check [here](https://github.com/continuedev/continue/blob/main/continuedev/src/continuedev/libs/constants/default_config.py)."
- sdk.history.add_node(HistoryNode(
- step=msg_step,
- observation=None,
- depth=0,
- active=False
- ))
+ sdk.history.add_node(
+ HistoryNode(step=msg_step, observation=None, depth=0, active=False)
+ )
await sdk.ide.setFileOpen(getConfigFilePath())
sdk.models = sdk.config.models
await sdk.models.start(sdk)
# When the config is loaded, setup posthog logger
- posthog_logger.setup(
- sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry)
+ posthog_logger.setup(sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry)
return sdk
@@ -109,8 +122,14 @@ class ContinueSDK(AbstractContinueSDK):
async def run_step(self, step: Step) -> Coroutine[Observation, None, None]:
return await self.__autopilot._run_singular_step(step)
- async def apply_filesystem_edit(self, edit: FileSystemEdit, name: str = None, description: str = None):
- return await self.run_step(FileSystemEditStep(edit=edit, description=description, **({'name': name} if name else {})))
+ async def apply_filesystem_edit(
+ self, edit: FileSystemEdit, name: str = None, description: str = None
+ ):
+ return await self.run_step(
+ FileSystemEditStep(
+ edit=edit, description=description, **({"name": name} if name else {})
+ )
+ )
async def wait_for_user_input(self) -> str:
return await self.__autopilot.wait_for_user_input()
@@ -118,22 +137,51 @@ class ContinueSDK(AbstractContinueSDK):
async def wait_for_user_confirmation(self, prompt: str):
return await self.run_step(WaitForUserConfirmationStep(prompt=prompt))
- async def run(self, commands: Union[List[str], str], cwd: str = None, name: str = None, description: str = None, handle_error: bool = True) -> Coroutine[str, None, None]:
+ async def run(
+ self,
+ commands: Union[List[str], str],
+ cwd: str = None,
+ name: str = None,
+ description: str = None,
+ handle_error: bool = True,
+ ) -> Coroutine[str, None, None]:
commands = commands if isinstance(commands, List) else [commands]
- return (await self.run_step(ShellCommandsStep(cmds=commands, cwd=cwd, description=description, handle_error=handle_error, **({'name': name} if name else {})))).text
-
- async def edit_file(self, filename: str, prompt: str, name: str = None, description: str = "", range: Range = None):
+ return (
+ await self.run_step(
+ ShellCommandsStep(
+ cmds=commands,
+ cwd=cwd,
+ description=description,
+ handle_error=handle_error,
+ **({"name": name} if name else {}),
+ )
+ )
+ ).text
+
+ async def edit_file(
+ self,
+ filename: str,
+ prompt: str,
+ name: str = None,
+ description: str = "",
+ range: Range = None,
+ ):
filepath = await self._ensure_absolute_path(filename)
await self.ide.setFileOpen(filepath)
contents = await self.ide.readFile(filepath)
- await self.run_step(DefaultModelEditCodeStep(
- range_in_files=[RangeInFile(filepath=filepath, range=range) if range is not None else RangeInFile.from_entire_file(
- filepath, contents)],
- user_input=prompt,
- description=description,
- **({'name': name} if name else {})
- ))
+ await self.run_step(
+ DefaultModelEditCodeStep(
+ range_in_files=[
+ RangeInFile(filepath=filepath, range=range)
+ if range is not None
+ else RangeInFile.from_entire_file(filepath, contents)
+ ],
+ user_input=prompt,
+ description=description,
+ **({"name": name} if name else {}),
+ )
+ )
async def append_to_file(self, filename: str, content: str):
filepath = await self._ensure_absolute_path(filename)
@@ -145,11 +193,15 @@ class ContinueSDK(AbstractContinueSDK):
filepath = await self._ensure_absolute_path(filename)
dir_name = os.path.dirname(filepath)
os.makedirs(dir_name, exist_ok=True)
- return await self.run_step(FileSystemEditStep(edit=AddFile(filepath=filepath, content=content)))
+ return await self.run_step(
+ FileSystemEditStep(edit=AddFile(filepath=filepath, content=content))
+ )
async def delete_file(self, filename: str):
filename = await self._ensure_absolute_path(filename)
- return await self.run_step(FileSystemEditStep(edit=DeleteFile(filepath=filename)))
+ return await self.run_step(
+ FileSystemEditStep(edit=DeleteFile(filepath=filename))
+ )
async def add_directory(self, path: str):
path = await self._ensure_absolute_path(path)
@@ -170,6 +222,7 @@ class ContinueSDK(AbstractContinueSDK):
path = getConfigFilePath()
import importlib.util
+
spec = importlib.util.spec_from_file_location("config", path)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)
@@ -177,24 +230,34 @@ class ContinueSDK(AbstractContinueSDK):
return config.config
- def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]:
+ def get_code_context(
+ self, only_editing: bool = False
+ ) -> List[RangeInFileWithContents]:
highlighted_ranges = self.__autopilot.context_manager.context_providers[
- "code"].highlighted_ranges
- context = list(filter(lambda x: x.item.editing, highlighted_ranges)
- ) if only_editing else highlighted_ranges
+ "code"
+ ].highlighted_ranges
+ context = (
+ list(filter(lambda x: x.item.editing, highlighted_ranges))
+ if only_editing
+ else highlighted_ranges
+ )
return [c.rif for c in context]
def set_loading_message(self, message: str):
# self.__autopilot.set_loading_message(message)
raise NotImplementedError()
- def raise_exception(self, message: str, title: str, with_step: Union[Step, None] = None):
+ def raise_exception(
+ self, message: str, title: str, with_step: Union[Step, None] = None
+ ):
raise ContinueCustomException(message, title, with_step)
async def get_chat_context(self) -> List[ChatMessage]:
history_context = self.history.to_chat_history()
- context_messages: List[ChatMessage] = await self.__autopilot.context_manager.get_chat_messages()
+ context_messages: List[
+ ChatMessage
+ ] = await self.__autopilot.context_manager.get_chat_messages()
# Insert at the end, but don't insert after latest user message or function call
for msg in context_messages:
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 49d46be3..7497e777 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -1,19 +1,20 @@
import asyncio
import json
-from fastapi import Depends, Header, WebSocket, APIRouter
-from starlette.websockets import WebSocketState, WebSocketDisconnect
+import traceback
from typing import Any, List, Optional, Type, TypeVar
+
+from fastapi import APIRouter, Depends, WebSocket
from pydantic import BaseModel
-import traceback
+from starlette.websockets import WebSocketDisconnect, WebSocketState
from uvicorn.main import Server
-from .session_manager import session_manager, Session
-from ..plugins.steps.core.core import DisplayErrorStep, MessageStep
-from .gui_protocol import AbstractGUIProtocolServer
-from ..libs.util.queue import AsyncSubscriptionQueue
-from ..libs.util.telemetry import posthog_logger
from ..libs.util.create_async_task import create_async_task
from ..libs.util.logging import logger
+from ..libs.util.queue import AsyncSubscriptionQueue
+from ..libs.util.telemetry import posthog_logger
+from ..plugins.steps.core.core import DisplayErrorStep
+from .gui_protocol import AbstractGUIProtocolServer
+from .session_manager import Session, session_manager
router = APIRouter(prefix="/gui", tags=["gui"])
@@ -54,19 +55,19 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
async def _send_json(self, message_type: str, data: Any):
if self.websocket.application_state == WebSocketState.DISCONNECTED:
return
- await self.websocket.send_json({
- "messageType": message_type,
- "data": data
- })
+ await self.websocket.send_json({"messageType": message_type, "data": data})
async def _receive_json(self, message_type: str, timeout: int = 20) -> Any:
try:
- return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)
+ return await asyncio.wait_for(
+ self.sub_queue.get(message_type), timeout=timeout
+ )
except asyncio.TimeoutError:
- raise Exception(
- "GUI Protocol _receive_json timed out after 20 seconds")
+ raise Exception("GUI Protocol _receive_json timed out after 20 seconds")
- async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
+ async def _send_and_receive_json(
+ self, data: Any, resp_model: Type[T], message_type: str
+ ) -> T:
await self._send_json(message_type, data)
resp = await self._receive_json(message_type)
return resp_model.parse_obj(resp)
@@ -101,78 +102,95 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.select_context_item(data["id"], data["query"])
elif message_type == "load_session":
self.load_session(data.get("session_id", None))
+ elif message_type == "edit_step_at_index":
+ self.edit_step_at_index(data.get("user_input", ""), data["index"])
def on_main_input(self, input: str):
# Do something with user input
create_async_task(
- self.session.autopilot.accept_user_input(input), self.on_error)
+ self.session.autopilot.accept_user_input(input), self.on_error
+ )
def on_reverse_to_index(self, index: int):
# Reverse the history to the given index
- create_async_task(
- self.session.autopilot.reverse_to_index(index), self.on_error)
+ create_async_task(self.session.autopilot.reverse_to_index(index), self.on_error)
def on_step_user_input(self, input: str, index: int):
create_async_task(
- self.session.autopilot.give_user_input(input, index), self.on_error)
+ self.session.autopilot.give_user_input(input, index), self.on_error
+ )
def on_refinement_input(self, input: str, index: int):
create_async_task(
- self.session.autopilot.accept_refinement_input(input, index), self.on_error)
+ self.session.autopilot.accept_refinement_input(input, index), self.on_error
+ )
def on_retry_at_index(self, index: int):
- create_async_task(
- self.session.autopilot.retry_at_index(index), self.on_error)
+ create_async_task(self.session.autopilot.retry_at_index(index), self.on_error)
def on_clear_history(self):
- create_async_task(
- self.session.autopilot.clear_history(), self.on_error)
+ create_async_task(self.session.autopilot.clear_history(), self.on_error)
def on_delete_at_index(self, index: int):
+ create_async_task(self.session.autopilot.delete_at_index(index), self.on_error)
+
+ def edit_step_at_index(self, user_input: str, index: int):
create_async_task(
- self.session.autopilot.delete_at_index(index), self.on_error)
+ self.session.autopilot.edit_step_at_index(user_input, index),
+ self.on_error,
+ )
def on_delete_context_with_ids(self, ids: List[str]):
create_async_task(
- self.session.autopilot.delete_context_with_ids(ids), self.on_error)
+ self.session.autopilot.delete_context_with_ids(ids), self.on_error
+ )
def on_toggle_adding_highlighted_code(self):
create_async_task(
- self.session.autopilot.toggle_adding_highlighted_code(), self.on_error)
+ self.session.autopilot.toggle_adding_highlighted_code(), self.on_error
+ )
posthog_logger.capture_event("toggle_adding_highlighted_code", {})
def on_set_editing_at_ids(self, ids: List[str]):
- create_async_task(
- self.session.autopilot.set_editing_at_ids(ids), self.on_error)
+ create_async_task(self.session.autopilot.set_editing_at_ids(ids), self.on_error)
def on_show_logs_at_index(self, index: int):
- name = f"continue_logs.txt"
+ name = "continue_logs.txt"
logs = "\n\n############################################\n\n".join(
- ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs)
+ [
+ "This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"
+ ]
+ + self.session.autopilot.continue_sdk.history.timeline[index].logs
+ )
create_async_task(
- self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error)
+ self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error
+ )
posthog_logger.capture_event("show_logs_at_index", {})
def select_context_item(self, id: str, query: str):
"""Called when user selects an item from the dropdown"""
create_async_task(
- self.session.autopilot.select_context_item(id, query), self.on_error)
+ self.session.autopilot.select_context_item(id, query), self.on_error
+ )
def load_session(self, session_id: Optional[str] = None):
async def load_and_tell_to_reconnect():
- new_session_id = await session_manager.load_session(self.session.session_id, session_id)
- await self._send_json("reconnect_at_session", {"session_id": new_session_id})
+ new_session_id = await session_manager.load_session(
+ self.session.session_id, session_id
+ )
+ await self._send_json(
+ "reconnect_at_session", {"session_id": new_session_id}
+ )
- create_async_task(
- load_and_tell_to_reconnect(), self.on_error)
+ create_async_task(load_and_tell_to_reconnect(), self.on_error)
- posthog_logger.capture_event("load_session", {
- "session_id": session_id
- })
+ posthog_logger.capture_event("load_session", {"session_id": session_id})
@router.websocket("/ws")
-async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
+async def websocket_endpoint(
+ websocket: WebSocket, session: Session = Depends(websocket_session)
+):
try:
logger.debug(f"Received websocket connection at url: {websocket.url}")
await websocket.accept()
@@ -197,14 +215,16 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
data = message["data"]
protocol.handle_json(message_type, data)
- except WebSocketDisconnect as e:
+ except WebSocketDisconnect:
logger.debug("GUI websocket disconnected")
except Exception as e:
# Log, send to PostHog, and send to GUI
logger.debug(f"ERROR in gui websocket: {e}")
- err_msg = '\n'.join(traceback.format_exception(e))
- posthog_logger.capture_event("gui_error", {
- "error_title": e.__str__() or e.__repr__(), "error_message": err_msg})
+ err_msg = "\n".join(traceback.format_exception(e))
+ posthog_logger.capture_event(
+ "gui_error",
+ {"error_title": e.__str__() or e.__repr__(), "error_message": err_msg},
+ )
await session.autopilot.ide.showMessage(err_msg)