diff options
| author | Nate Sesti <sestinj@gmail.com> | 2023-08-17 22:19:55 -0700 | 
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-08-17 22:19:55 -0700 | 
| commit | c6a3d8add014ddfe08c62b3ccb1b01dbc47495f5 (patch) | |
| tree | 0bb5bb11cf185a575df520908078d310775f4a7c | |
| parent | 98047dc32a5bfa525eff7089b2ac020ce761f9a9 (diff) | |
| download | sncontinue-c6a3d8add014ddfe08c62b3ccb1b01dbc47495f5.tar.gz sncontinue-c6a3d8add014ddfe08c62b3ccb1b01dbc47495f5.tar.bz2 sncontinue-c6a3d8add014ddfe08c62b3ccb1b01dbc47495f5.zip | |
feat: :sparkles: edit previous inputs
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 214 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 159 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/gui.py | 112 | ||||
| -rw-r--r-- | extension/react-app/src/components/UserInputContainer.tsx | 147 | ||||
| -rw-r--r-- | extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts | 2 | ||||
| -rw-r--r-- | extension/react-app/src/hooks/ContinueGUIClientProtocol.ts | 7 | ||||
| -rw-r--r-- | extension/react-app/src/pages/gui.tsx | 1 | 
7 files changed, 469 insertions, 173 deletions
| diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 2c58c6f4..f7808335 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -1,29 +1,44 @@ -from functools import cached_property -import traceback  import time -from typing import Callable, Coroutine, Dict, List, Optional, Union +import traceback +from functools import cached_property +from typing import Callable, Coroutine, Dict, List, Optional +  from aiohttp import ClientPayloadError +from openai import error as openai_errors  from pydantic import root_validator +from ..libs.util.create_async_task import create_async_task +from ..libs.util.logging import logger +from ..libs.util.queue import AsyncSubscriptionQueue  from ..libs.util.strings import remove_quotes_and_escapes +from ..libs.util.telemetry import posthog_logger +from ..libs.util.traceback_parsers import get_javascript_traceback, get_python_traceback  from ..models.filesystem import RangeInFileWithContents  from ..models.filesystem_edit import FileEditWithFullContents -from .observation import Observation, InternalErrorObservation -from .context import ContextManager -from ..plugins.policies.default import DefaultPolicy +from ..models.main import ContinueBaseModel  from ..plugins.context_providers.file import FileContextProvider  from ..plugins.context_providers.highlighted_code import HighlightedCodeContextProvider +from ..plugins.policies.default import DefaultPolicy +from ..plugins.steps.core.core import ( +    DisplayErrorStep, +    ManualEditStep, +    ReversibleStep, +    UserInputStep, +)  from ..server.ide_protocol import AbstractIdeProtocolServer -from ..libs.util.queue import AsyncSubscriptionQueue -from ..models.main import ContinueBaseModel -from .main import Context, ContinueCustomException, Policy, History, FullState, SessionInfo, Step, HistoryNode -from ..plugins.steps.core.core import DisplayErrorStep, ReversibleStep, ManualEditStep, UserInputStep +from .context import ContextManager +from .main import ( +    Context, +    ContinueCustomException, +    FullState, +    History, +    HistoryNode, +    Policy, +    SessionInfo, +    Step, +) +from .observation import InternalErrorObservation, Observation  from .sdk import ContinueSDK -from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback -from openai import error as openai_errors -from ..libs.util.create_async_task import create_async_task -from ..libs.util.telemetry import posthog_logger -from ..libs.util.logging import logger  def get_error_title(e: Exception) -> str: @@ -33,18 +48,23 @@ def get_error_title(e: Exception) -> str:          return "This OpenAI API key has been rate limited. Please try again."      elif isinstance(e, openai_errors.Timeout):          return "OpenAI timed out. Please try again." -    elif isinstance(e, openai_errors.InvalidRequestError) and e.code == "context_length_exceeded": +    elif ( +        isinstance(e, openai_errors.InvalidRequestError) +        and e.code == "context_length_exceeded" +    ):          return e._message      elif isinstance(e, ClientPayloadError):          return "The request to OpenAI failed. Please try again."      elif isinstance(e, openai_errors.APIConnectionError): -        return "The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to \"\"" +        return 'The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to ""'      elif isinstance(e, openai_errors.InvalidRequestError): -        return 'Invalid request sent to OpenAI. Please try again.' +        return "Invalid request sent to OpenAI. Please try again."      elif "rate_limit_ip_middleware" in e.__str__(): -        return 'You have reached your limit for free usage of our token. You can continue using Continue by entering your own OpenAI API key in VS Code settings.' +        return "You have reached your limit for free usage of our token. You can continue using Continue by entering your own OpenAI API key in VS Code settings."      elif e.__str__().startswith("Cannot connect to host"): -        return "The request failed. Please check your internet connection and try again." +        return ( +            "The request failed. Please check your internet connection and try again." +        )      return e.__str__() or e.__repr__() @@ -78,10 +98,13 @@ class Autopilot(ContinueBaseModel):          # Load documents into the search index          logger.debug("Starting context manager")          await self.context_manager.start( -            self.continue_sdk.config.context_providers + [ +            self.continue_sdk.config.context_providers +            + [                  HighlightedCodeContextProvider(ide=self.ide), -                FileContextProvider(workspace_dir=self.ide.workspace_directory) -            ], self.ide.workspace_directory) +                FileContextProvider(workspace_dir=self.ide.workspace_directory), +            ], +            self.ide.workspace_directory, +        )          if full_state is not None:              self.history = full_state.history @@ -95,9 +118,9 @@ class Autopilot(ContinueBaseModel):      @root_validator(pre=True)      def fill_in_values(cls, values): -        full_state: FullState = values.get('full_state') +        full_state: FullState = values.get("full_state")          if full_state is not None: -            values['history'] = full_state.history +            values["history"] = full_state.history          return values      async def get_full_state(self) -> FullState: @@ -107,18 +130,37 @@ class Autopilot(ContinueBaseModel):              user_input_queue=self._main_user_input_queue,              slash_commands=self.get_available_slash_commands(),              adding_highlighted_code=self.context_manager.context_providers[ -                "code"].adding_highlighted_code if "code" in self.context_manager.context_providers else False, -            selected_context_items=await self.context_manager.get_selected_items() if self.context_manager is not None else [], -            session_info=self.session_info +                "code" +            ].adding_highlighted_code +            if "code" in self.context_manager.context_providers +            else False, +            selected_context_items=await self.context_manager.get_selected_items() +            if self.context_manager is not None +            else [], +            session_info=self.session_info,          )          self.full_state = full_state          return full_state      def get_available_slash_commands(self) -> List[Dict]: -        custom_commands = list(map(lambda x: { -                               "name": x.name, "description": x.description}, self.continue_sdk.config.custom_commands)) or [] -        slash_commands = list(map(lambda x: { -                              "name": x.name, "description": x.description}, self.continue_sdk.config.slash_commands)) or [] +        custom_commands = ( +            list( +                map( +                    lambda x: {"name": x.name, "description": x.description}, +                    self.continue_sdk.config.custom_commands, +                ) +            ) +            or [] +        ) +        slash_commands = ( +            list( +                map( +                    lambda x: {"name": x.name, "description": x.description}, +                    self.continue_sdk.config.slash_commands, +                ) +            ) +            or [] +        )          return custom_commands + slash_commands      async def clear_history(self): @@ -182,13 +224,16 @@ class Autopilot(ContinueBaseModel):                      step = tb_step.step({"output": output, **tb_step.params})                      await self._run_singular_step(step) -    async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]): +    async def handle_highlighted_code( +        self, range_in_files: List[RangeInFileWithContents] +    ):          if "code" not in self.context_manager.context_providers:              return          # Add to context manager          await self.context_manager.context_providers["code"].handle_highlighted_code( -            range_in_files) +            range_in_files +        )          await self.update_subscribers() @@ -205,6 +250,23 @@ class Autopilot(ContinueBaseModel):          await self.update_subscribers() +    async def edit_step_at_index(self, user_input: str, index: int): +        step_to_rerun = self.history.timeline[index].step.copy() +        step_to_rerun.user_input = user_input + +        # Halt the agent's currently running jobs (delete them) +        while len(self.history.timeline) > index: +            # Remove from timeline +            node_to_delete = self.history.timeline.pop() +            # Delete so it is stopped if in the middle of running +            node_to_delete.deleted = True + +        self.history.current_index = index - 1 +        await self.update_subscribers() + +        # Rerun from the current step +        await self.run_from_step(step_to_rerun) +      async def delete_context_with_ids(self, ids: List[str]):          await self.context_manager.delete_context_with_ids(ids)          await self.update_subscribers() @@ -213,7 +275,11 @@ class Autopilot(ContinueBaseModel):          if "code" not in self.context_manager.context_providers:              return -        self.context_manager.context_providers["code"].adding_highlighted_code = not self.context_manager.context_providers["code"].adding_highlighted_code +        self.context_manager.context_providers[ +            "code" +        ].adding_highlighted_code = not self.context_manager.context_providers[ +            "code" +        ].adding_highlighted_code          await self.update_subscribers()      async def set_editing_at_ids(self, ids: List[str]): @@ -223,7 +289,9 @@ class Autopilot(ContinueBaseModel):          await self.context_manager.context_providers["code"].set_editing_at_ids(ids)          await self.update_subscribers() -    async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]: +    async def _run_singular_step( +        self, step: "Step", is_future_step: bool = False +    ) -> Coroutine[Observation, None, None]:          # Allow config to set disallowed steps          if step.__class__.__name__ in self.continue_sdk.config.disallowed_steps:              return None @@ -239,19 +307,22 @@ class Autopilot(ContinueBaseModel):          #     i -= 1          posthog_logger.capture_event( -            'step run', {'step_name': step.name, 'params': step.dict()}) +            "step run", {"step_name": step.name, "params": step.dict()} +        )          if not is_future_step:              # Check manual edits buffer, clear out if needed by creating a ManualEditStep              if len(self._manual_edits_buffer) > 0:                  manualEditsStep = ManualEditStep.from_sequence( -                    self._manual_edits_buffer) +                    self._manual_edits_buffer +                )                  self._manual_edits_buffer = []                  await self._run_singular_step(manualEditsStep)          # Update history - do this first so we get top-first tree ordering -        index_of_history_node = self.history.add_node(HistoryNode( -            step=step, observation=None, depth=self._step_depth)) +        index_of_history_node = self.history.add_node( +            HistoryNode(step=step, observation=None, depth=self._step_depth) +        )          # Call all subscribed callbacks          await self.update_subscribers() @@ -263,28 +334,43 @@ class Autopilot(ContinueBaseModel):          try:              observation = await step(self.continue_sdk)          except Exception as e: -            if index_of_history_node >= len(self.history.timeline) or self.history.timeline[index_of_history_node].deleted: +            if ( +                index_of_history_node >= len(self.history.timeline) +                or self.history.timeline[index_of_history_node].deleted +            ):                  # If step was deleted/cancelled, don't show error or allow retry                  return None              caught_error = True              is_continue_custom_exception = issubclass( -                e.__class__, ContinueCustomException) - -            error_string = e.message if is_continue_custom_exception else '\n'.join( -                traceback.format_exception(e)) -            error_title = e.title if is_continue_custom_exception else get_error_title( -                e) +                e.__class__, ContinueCustomException +            ) + +            error_string = ( +                e.message +                if is_continue_custom_exception +                else "\n".join(traceback.format_exception(e)) +            ) +            error_title = ( +                e.title if is_continue_custom_exception else get_error_title(e) +            )              # Attach an InternalErrorObservation to the step and unhide it. -            logger.error( -                f"Error while running step: \n{error_string}\n{error_title}") -            posthog_logger.capture_event('step error', { -                                         'error_message': error_string, 'error_title': error_title, 'step_name': step.name, 'params': step.dict()}) +            logger.error(f"Error while running step: \n{error_string}\n{error_title}") +            posthog_logger.capture_event( +                "step error", +                { +                    "error_message": error_string, +                    "error_title": error_title, +                    "step_name": step.name, +                    "params": step.dict(), +                }, +            )              observation = InternalErrorObservation( -                error=error_string, title=error_title) +                error=error_string, title=error_title +            )              # Reveal this step, but hide all of the following steps (its substeps)              step_was_hidden = step.hide @@ -331,8 +417,10 @@ class Autopilot(ContinueBaseModel):              # Update subscribers with new description              await self.update_subscribers() -        create_async_task(update_description( -        ), on_error=lambda e: self.continue_sdk.run_step(DisplayErrorStep(e=e))) +        create_async_task( +            update_description(), +            on_error=lambda e: self.continue_sdk.run_step(DisplayErrorStep(e=e)), +        )          return observation @@ -384,17 +472,22 @@ class Autopilot(ContinueBaseModel):          # Use the first input to create title for session info, and make the session saveable          if self.session_info is None: +              async def create_title(): -                title = await self.continue_sdk.models.medium.complete(f"Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: \"{user_input}\". The title is: ") +                title = await self.continue_sdk.models.medium.complete( +                    f'Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: "{user_input}". The title is: ' +                )                  title = remove_quotes_and_escapes(title)                  self.session_info = SessionInfo(                      title=title,                      session_id=self.ide.session_id, -                    date_created=str(time.time()) +                    date_created=str(time.time()),                  ) -            create_async_task(create_title(), on_error=lambda e: self.continue_sdk.run_step( -                DisplayErrorStep(e=e))) +            create_async_task( +                create_title(), +                on_error=lambda e: self.continue_sdk.run_step(DisplayErrorStep(e=e)), +            )          if len(self._main_user_input_queue) > 1:              return @@ -407,8 +500,9 @@ class Autopilot(ContinueBaseModel):          await self.run_from_step(UserInputStep(user_input=user_input))          while len(self._main_user_input_queue) > 0: -            await self.run_from_step(UserInputStep( -                user_input=self._main_user_input_queue.pop(0))) +            await self.run_from_step( +                UserInputStep(user_input=self._main_user_input_queue.pop(0)) +            )      async def accept_refinement_input(self, user_input: str, index: int):          await self._request_halt() diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 024d5cea..778f81b3 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -1,23 +1,36 @@ +import os  import traceback  from typing import Coroutine, Union -import os -import importlib -from ..plugins.steps.core.core import DefaultModelEditCodeStep +from ..libs.llm import LLM +from ..libs.util.logging import logger +from ..libs.util.paths import getConfigFilePath +from ..libs.util.telemetry import posthog_logger +from ..models.filesystem import RangeInFile +from ..models.filesystem_edit import ( +    AddDirectory, +    AddFile, +    DeleteDirectory, +    DeleteFile, +    FileEdit, +    FileSystemEdit, +)  from ..models.main import Range +from ..plugins.steps.core.core import * +from ..plugins.steps.core.core import DefaultModelEditCodeStep +from ..server.ide_protocol import AbstractIdeProtocolServer  from .abstract_sdk import AbstractContinueSDK  from .config import ContinueConfig -from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory -from ..models.filesystem import RangeInFile -from ..libs.llm import LLM -from .observation import Observation -from ..server.ide_protocol import AbstractIdeProtocolServer -from .main import Context, ContinueCustomException, History, HistoryNode, Step, ChatMessage -from ..plugins.steps.core.core import * -from ..libs.util.telemetry import posthog_logger -from ..libs.util.paths import getConfigFilePath +from .main import ( +    ChatMessage, +    Context, +    ContinueCustomException, +    History, +    HistoryNode, +    Step, +)  from .models import Models -from ..libs.util.logging import logger +from .observation import Observation  class Autopilot: @@ -26,6 +39,7 @@ class Autopilot:  class ContinueSDK(AbstractContinueSDK):      """The SDK provided as parameters to a step""" +      ide: AbstractIdeProtocolServer      models: Models      context: Context @@ -46,30 +60,29 @@ class ContinueSDK(AbstractContinueSDK):              config = sdk._load_config_dot_py()              sdk.config = config          except Exception as e: -            logger.error( -                f"Failed to load config.py: {traceback.format_exception(e)}") +            logger.error(f"Failed to load config.py: {traceback.format_exception(e)}") -            sdk.config = ContinueConfig( -            ) if sdk._last_valid_config is None else sdk._last_valid_config +            sdk.config = ( +                ContinueConfig() +                if sdk._last_valid_config is None +                else sdk._last_valid_config +            ) -            formatted_err = '\n'.join(traceback.format_exception(e)) +            formatted_err = "\n".join(traceback.format_exception(e))              msg_step = MessageStep( -                name="Invalid Continue Config File", message=formatted_err) +                name="Invalid Continue Config File", message=formatted_err +            )              msg_step.description = f"Falling back to default config settings due to the following error in `~/.continue/config.py`.\n```\n{formatted_err}\n```\n\nIt's possible this was caused by an update to the Continue config format. If you'd like to see the new recommended default `config.py`, check [here](https://github.com/continuedev/continue/blob/main/continuedev/src/continuedev/libs/constants/default_config.py)." -            sdk.history.add_node(HistoryNode( -                step=msg_step, -                observation=None, -                depth=0, -                active=False -            )) +            sdk.history.add_node( +                HistoryNode(step=msg_step, observation=None, depth=0, active=False) +            )              await sdk.ide.setFileOpen(getConfigFilePath())          sdk.models = sdk.config.models          await sdk.models.start(sdk)          # When the config is loaded, setup posthog logger -        posthog_logger.setup( -            sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry) +        posthog_logger.setup(sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry)          return sdk @@ -109,8 +122,14 @@ class ContinueSDK(AbstractContinueSDK):      async def run_step(self, step: Step) -> Coroutine[Observation, None, None]:          return await self.__autopilot._run_singular_step(step) -    async def apply_filesystem_edit(self, edit: FileSystemEdit, name: str = None, description: str = None): -        return await self.run_step(FileSystemEditStep(edit=edit, description=description, **({'name': name} if name else {}))) +    async def apply_filesystem_edit( +        self, edit: FileSystemEdit, name: str = None, description: str = None +    ): +        return await self.run_step( +            FileSystemEditStep( +                edit=edit, description=description, **({"name": name} if name else {}) +            ) +        )      async def wait_for_user_input(self) -> str:          return await self.__autopilot.wait_for_user_input() @@ -118,22 +137,51 @@ class ContinueSDK(AbstractContinueSDK):      async def wait_for_user_confirmation(self, prompt: str):          return await self.run_step(WaitForUserConfirmationStep(prompt=prompt)) -    async def run(self, commands: Union[List[str], str], cwd: str = None, name: str = None, description: str = None, handle_error: bool = True) -> Coroutine[str, None, None]: +    async def run( +        self, +        commands: Union[List[str], str], +        cwd: str = None, +        name: str = None, +        description: str = None, +        handle_error: bool = True, +    ) -> Coroutine[str, None, None]:          commands = commands if isinstance(commands, List) else [commands] -        return (await self.run_step(ShellCommandsStep(cmds=commands, cwd=cwd, description=description, handle_error=handle_error, **({'name': name} if name else {})))).text - -    async def edit_file(self, filename: str, prompt: str, name: str = None, description: str = "", range: Range = None): +        return ( +            await self.run_step( +                ShellCommandsStep( +                    cmds=commands, +                    cwd=cwd, +                    description=description, +                    handle_error=handle_error, +                    **({"name": name} if name else {}), +                ) +            ) +        ).text + +    async def edit_file( +        self, +        filename: str, +        prompt: str, +        name: str = None, +        description: str = "", +        range: Range = None, +    ):          filepath = await self._ensure_absolute_path(filename)          await self.ide.setFileOpen(filepath)          contents = await self.ide.readFile(filepath) -        await self.run_step(DefaultModelEditCodeStep( -            range_in_files=[RangeInFile(filepath=filepath, range=range) if range is not None else RangeInFile.from_entire_file( -                filepath, contents)], -            user_input=prompt, -            description=description, -            **({'name': name} if name else {}) -        )) +        await self.run_step( +            DefaultModelEditCodeStep( +                range_in_files=[ +                    RangeInFile(filepath=filepath, range=range) +                    if range is not None +                    else RangeInFile.from_entire_file(filepath, contents) +                ], +                user_input=prompt, +                description=description, +                **({"name": name} if name else {}), +            ) +        )      async def append_to_file(self, filename: str, content: str):          filepath = await self._ensure_absolute_path(filename) @@ -145,11 +193,15 @@ class ContinueSDK(AbstractContinueSDK):          filepath = await self._ensure_absolute_path(filename)          dir_name = os.path.dirname(filepath)          os.makedirs(dir_name, exist_ok=True) -        return await self.run_step(FileSystemEditStep(edit=AddFile(filepath=filepath, content=content))) +        return await self.run_step( +            FileSystemEditStep(edit=AddFile(filepath=filepath, content=content)) +        )      async def delete_file(self, filename: str):          filename = await self._ensure_absolute_path(filename) -        return await self.run_step(FileSystemEditStep(edit=DeleteFile(filepath=filename))) +        return await self.run_step( +            FileSystemEditStep(edit=DeleteFile(filepath=filename)) +        )      async def add_directory(self, path: str):          path = await self._ensure_absolute_path(path) @@ -170,6 +222,7 @@ class ContinueSDK(AbstractContinueSDK):          path = getConfigFilePath()          import importlib.util +          spec = importlib.util.spec_from_file_location("config", path)          config = importlib.util.module_from_spec(spec)          spec.loader.exec_module(config) @@ -177,24 +230,34 @@ class ContinueSDK(AbstractContinueSDK):          return config.config -    def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]: +    def get_code_context( +        self, only_editing: bool = False +    ) -> List[RangeInFileWithContents]:          highlighted_ranges = self.__autopilot.context_manager.context_providers[ -            "code"].highlighted_ranges -        context = list(filter(lambda x: x.item.editing, highlighted_ranges) -                       ) if only_editing else highlighted_ranges +            "code" +        ].highlighted_ranges +        context = ( +            list(filter(lambda x: x.item.editing, highlighted_ranges)) +            if only_editing +            else highlighted_ranges +        )          return [c.rif for c in context]      def set_loading_message(self, message: str):          # self.__autopilot.set_loading_message(message)          raise NotImplementedError() -    def raise_exception(self, message: str, title: str, with_step: Union[Step, None] = None): +    def raise_exception( +        self, message: str, title: str, with_step: Union[Step, None] = None +    ):          raise ContinueCustomException(message, title, with_step)      async def get_chat_context(self) -> List[ChatMessage]:          history_context = self.history.to_chat_history() -        context_messages: List[ChatMessage] = await self.__autopilot.context_manager.get_chat_messages() +        context_messages: List[ +            ChatMessage +        ] = await self.__autopilot.context_manager.get_chat_messages()          # Insert at the end, but don't insert after latest user message or function call          for msg in context_messages: diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 49d46be3..7497e777 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -1,19 +1,20 @@  import asyncio  import json -from fastapi import Depends, Header, WebSocket, APIRouter -from starlette.websockets import WebSocketState, WebSocketDisconnect +import traceback  from typing import Any, List, Optional, Type, TypeVar + +from fastapi import APIRouter, Depends, WebSocket  from pydantic import BaseModel -import traceback +from starlette.websockets import WebSocketDisconnect, WebSocketState  from uvicorn.main import Server -from .session_manager import session_manager, Session -from ..plugins.steps.core.core import DisplayErrorStep, MessageStep -from .gui_protocol import AbstractGUIProtocolServer -from ..libs.util.queue import AsyncSubscriptionQueue -from ..libs.util.telemetry import posthog_logger  from ..libs.util.create_async_task import create_async_task  from ..libs.util.logging import logger +from ..libs.util.queue import AsyncSubscriptionQueue +from ..libs.util.telemetry import posthog_logger +from ..plugins.steps.core.core import DisplayErrorStep +from .gui_protocol import AbstractGUIProtocolServer +from .session_manager import Session, session_manager  router = APIRouter(prefix="/gui", tags=["gui"]) @@ -54,19 +55,19 @@ class GUIProtocolServer(AbstractGUIProtocolServer):      async def _send_json(self, message_type: str, data: Any):          if self.websocket.application_state == WebSocketState.DISCONNECTED:              return -        await self.websocket.send_json({ -            "messageType": message_type, -            "data": data -        }) +        await self.websocket.send_json({"messageType": message_type, "data": data})      async def _receive_json(self, message_type: str, timeout: int = 20) -> Any:          try: -            return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout) +            return await asyncio.wait_for( +                self.sub_queue.get(message_type), timeout=timeout +            )          except asyncio.TimeoutError: -            raise Exception( -                "GUI Protocol _receive_json timed out after 20 seconds") +            raise Exception("GUI Protocol _receive_json timed out after 20 seconds") -    async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: +    async def _send_and_receive_json( +        self, data: Any, resp_model: Type[T], message_type: str +    ) -> T:          await self._send_json(message_type, data)          resp = await self._receive_json(message_type)          return resp_model.parse_obj(resp) @@ -101,78 +102,95 @@ class GUIProtocolServer(AbstractGUIProtocolServer):              self.select_context_item(data["id"], data["query"])          elif message_type == "load_session":              self.load_session(data.get("session_id", None)) +        elif message_type == "edit_step_at_index": +            self.edit_step_at_index(data.get("user_input", ""), data["index"])      def on_main_input(self, input: str):          # Do something with user input          create_async_task( -            self.session.autopilot.accept_user_input(input), self.on_error) +            self.session.autopilot.accept_user_input(input), self.on_error +        )      def on_reverse_to_index(self, index: int):          # Reverse the history to the given index -        create_async_task( -            self.session.autopilot.reverse_to_index(index), self.on_error) +        create_async_task(self.session.autopilot.reverse_to_index(index), self.on_error)      def on_step_user_input(self, input: str, index: int):          create_async_task( -            self.session.autopilot.give_user_input(input, index), self.on_error) +            self.session.autopilot.give_user_input(input, index), self.on_error +        )      def on_refinement_input(self, input: str, index: int):          create_async_task( -            self.session.autopilot.accept_refinement_input(input, index), self.on_error) +            self.session.autopilot.accept_refinement_input(input, index), self.on_error +        )      def on_retry_at_index(self, index: int): -        create_async_task( -            self.session.autopilot.retry_at_index(index), self.on_error) +        create_async_task(self.session.autopilot.retry_at_index(index), self.on_error)      def on_clear_history(self): -        create_async_task( -            self.session.autopilot.clear_history(), self.on_error) +        create_async_task(self.session.autopilot.clear_history(), self.on_error)      def on_delete_at_index(self, index: int): +        create_async_task(self.session.autopilot.delete_at_index(index), self.on_error) + +    def edit_step_at_index(self, user_input: str, index: int):          create_async_task( -            self.session.autopilot.delete_at_index(index), self.on_error) +            self.session.autopilot.edit_step_at_index(user_input, index), +            self.on_error, +        )      def on_delete_context_with_ids(self, ids: List[str]):          create_async_task( -            self.session.autopilot.delete_context_with_ids(ids), self.on_error) +            self.session.autopilot.delete_context_with_ids(ids), self.on_error +        )      def on_toggle_adding_highlighted_code(self):          create_async_task( -            self.session.autopilot.toggle_adding_highlighted_code(), self.on_error) +            self.session.autopilot.toggle_adding_highlighted_code(), self.on_error +        )          posthog_logger.capture_event("toggle_adding_highlighted_code", {})      def on_set_editing_at_ids(self, ids: List[str]): -        create_async_task( -            self.session.autopilot.set_editing_at_ids(ids), self.on_error) +        create_async_task(self.session.autopilot.set_editing_at_ids(ids), self.on_error)      def on_show_logs_at_index(self, index: int): -        name = f"continue_logs.txt" +        name = "continue_logs.txt"          logs = "\n\n############################################\n\n".join( -            ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs) +            [ +                "This is a log of the exact prompt/completion pairs sent/received from the LLM during this step" +            ] +            + self.session.autopilot.continue_sdk.history.timeline[index].logs +        )          create_async_task( -            self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error) +            self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error +        )          posthog_logger.capture_event("show_logs_at_index", {})      def select_context_item(self, id: str, query: str):          """Called when user selects an item from the dropdown"""          create_async_task( -            self.session.autopilot.select_context_item(id, query), self.on_error) +            self.session.autopilot.select_context_item(id, query), self.on_error +        )      def load_session(self, session_id: Optional[str] = None):          async def load_and_tell_to_reconnect(): -            new_session_id = await session_manager.load_session(self.session.session_id, session_id) -            await self._send_json("reconnect_at_session", {"session_id": new_session_id}) +            new_session_id = await session_manager.load_session( +                self.session.session_id, session_id +            ) +            await self._send_json( +                "reconnect_at_session", {"session_id": new_session_id} +            ) -        create_async_task( -            load_and_tell_to_reconnect(), self.on_error) +        create_async_task(load_and_tell_to_reconnect(), self.on_error) -        posthog_logger.capture_event("load_session", { -            "session_id": session_id -        }) +        posthog_logger.capture_event("load_session", {"session_id": session_id})  @router.websocket("/ws") -async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): +async def websocket_endpoint( +    websocket: WebSocket, session: Session = Depends(websocket_session) +):      try:          logger.debug(f"Received websocket connection at url: {websocket.url}")          await websocket.accept() @@ -197,14 +215,16 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we              data = message["data"]              protocol.handle_json(message_type, data) -    except WebSocketDisconnect as e: +    except WebSocketDisconnect:          logger.debug("GUI websocket disconnected")      except Exception as e:          # Log, send to PostHog, and send to GUI          logger.debug(f"ERROR in gui websocket: {e}") -        err_msg = '\n'.join(traceback.format_exception(e)) -        posthog_logger.capture_event("gui_error", { -            "error_title": e.__str__() or e.__repr__(), "error_message": err_msg}) +        err_msg = "\n".join(traceback.format_exception(e)) +        posthog_logger.capture_event( +            "gui_error", +            {"error_title": e.__str__() or e.__repr__(), "error_message": err_msg}, +        )          await session.autopilot.ide.showMessage(err_msg) diff --git a/extension/react-app/src/components/UserInputContainer.tsx b/extension/react-app/src/components/UserInputContainer.tsx index 90cd549b..9784a615 100644 --- a/extension/react-app/src/components/UserInputContainer.tsx +++ b/extension/react-app/src/components/UserInputContainer.tsx @@ -1,16 +1,24 @@ -import React, { useState } from "react"; +import React, { useContext, useEffect, useRef, useState } from "react";  import ReactMarkdown from "react-markdown";  import styled from "styled-components"; -import { defaultBorderRadius, secondaryDark, vscBackground } from "."; +import { +  defaultBorderRadius, +  secondaryDark, +  vscBackground, +  vscForeground, +} from ".";  import HeaderButtonWithText from "./HeaderButtonWithText"; -import { XMarkIcon } from "@heroicons/react/24/outline"; +import { XMarkIcon, PencilIcon, CheckIcon } from "@heroicons/react/24/outline";  import { HistoryNode } from "../../../schema/HistoryNode";  import StyledMarkdownPreview from "./StyledMarkdownPreview"; +import { GUIClientContext } from "../App"; +import { text } from "stream/consumers";  interface UserInputContainerProps {    onDelete: () => void;    children: string;    historyNode: HistoryNode; +  index: number;  }  const StyledDiv = styled.div` @@ -40,8 +48,69 @@ const StyledPre = styled.pre`    font-size: 13px;  `; +const TextArea = styled.textarea` +  margin: 8px; +  margin-right: 22px; +  padding: 8px; +  white-space: pre-wrap; +  word-wrap: break-word; +  font-family: "Lexend", sans-serif; +  font-size: 13px; +  width: 100%; +  border-radius: ${defaultBorderRadius}; +  height: 100%; +  border: none; +  background-color: ${vscBackground}; +  resize: none; +  outline: none; +  border: none; +  color: ${vscForeground}; + +  &:focus { +    border: none; +    outline: none; +  } +`; +  const UserInputContainer = (props: UserInputContainerProps) => {    const [isHovered, setIsHovered] = useState(false); +  const [isEditing, setIsEditing] = useState(false); + +  const textAreaRef = useRef<HTMLTextAreaElement>(null); +  const client = useContext(GUIClientContext); + +  useEffect(() => { +    if (isEditing) { +      textAreaRef.current?.focus(); +      // Select all text +      textAreaRef.current?.setSelectionRange( +        0, +        textAreaRef.current.value.length +      ); +    } +  }, [isEditing]); + +  useEffect(() => { +    const handleKeyDown = (event: KeyboardEvent) => { +      if (event.key === "Escape") { +        setIsEditing(false); +      } +    }; +    document.addEventListener("keydown", handleKeyDown); +    return () => { +      document.removeEventListener("keydown", handleKeyDown); +    }; +  }, []); + +  const doneEditing = (e: any) => { +    if (!textAreaRef.current?.value) { +      return; +    } +    client?.editStepAtIndex(textAreaRef.current.value, props.index); +    setIsEditing(false); +    e.stopPropagation(); +  }; +    return (      <StyledDiv        onMouseEnter={() => { @@ -51,24 +120,64 @@ const UserInputContainer = (props: UserInputContainerProps) => {          setIsHovered(false);        }}      > -      {/* <StyledMarkdownPreview -        light={true} -        source={props.children} -        className="mr-6" -      /> */} -      <StyledPre className="mr-6">{props.children}</StyledPre> +      {isEditing ? ( +        <TextArea +          ref={textAreaRef} +          onKeyDown={(e) => { +            if (e.key === "Enter" && !e.shiftKey) { +              e.preventDefault(); +              doneEditing(e); +            } +          }} +          defaultValue={props.children} +        /> +      ) : ( +        <StyledPre +          onClick={() => { +            setIsEditing(true); +          }} +          className="mr-6 cursor-text w-full" +        > +          {props.children} +        </StyledPre> +      )}        {/* <ReactMarkdown children={props.children} className="w-fit mr-10" /> */}        <DeleteButtonDiv> -        {isHovered && ( -          <HeaderButtonWithText -            onClick={(e) => { -              props.onDelete(); -              e.stopPropagation(); -            }} -            text="Delete" -          > -            <XMarkIcon width="1.4em" height="1.4em" /> -          </HeaderButtonWithText> +        {(isHovered || isEditing) && ( +          <div className="flex"> +            {isEditing ? ( +              <HeaderButtonWithText +                onClick={(e) => { +                  doneEditing(e); +                }} +                text="Done" +              > +                <CheckIcon width="1.4em" height="1.4em" /> +              </HeaderButtonWithText> +            ) : ( +              <> +                <HeaderButtonWithText +                  onClick={(e) => { +                    setIsEditing((prev) => !prev); +                    e.stopPropagation(); +                  }} +                  text="Edit" +                > +                  <PencilIcon width="1.4em" height="1.4em" /> +                </HeaderButtonWithText> + +                <HeaderButtonWithText +                  onClick={(e) => { +                    props.onDelete(); +                    e.stopPropagation(); +                  }} +                  text="Delete" +                > +                  <XMarkIcon width="1.4em" height="1.4em" /> +                </HeaderButtonWithText> +              </> +            )} +          </div>          )}        </DeleteButtonDiv>      </StyledDiv> diff --git a/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts index e018c03c..804362aa 100644 --- a/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts +++ b/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts @@ -34,6 +34,8 @@ abstract class AbstractContinueGUIClientProtocol {    abstract loadSession(session_id?: string): void;    abstract onReconnectAtSession(session_id: string): void; + +  abstract editStepAtIndex(userInput: string, index: number): void;  }  export default AbstractContinueGUIClientProtocol; diff --git a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts index c2285f6d..82aeee28 100644 --- a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts +++ b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts @@ -125,6 +125,13 @@ class ContinueGUIClientProtocol extends AbstractContinueGUIClientProtocol {    selectContextItem(id: string, query: string): void {      this.messenger?.send("select_context_item", { id, query });    } + +  editStepAtIndex(userInput: string, index: number): void { +    this.messenger?.send("edit_step_at_index", { +      user_input: userInput, +      index, +    }); +  }  }  export default ContinueGUIClientProtocol; diff --git a/extension/react-app/src/pages/gui.tsx b/extension/react-app/src/pages/gui.tsx index 4c89bbaa..a4a3d379 100644 --- a/extension/react-app/src/pages/gui.tsx +++ b/extension/react-app/src/pages/gui.tsx @@ -412,6 +412,7 @@ function GUI(props: GUIProps) {          return node.step.name === "User Input" ? (            node.step.hide || (              <UserInputContainer +              index={index}                onDelete={() => {                  client?.deleteAtIndex(index);                }} | 
