diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-08-17 22:19:55 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-08-17 22:19:55 -0700 |
commit | c6a3d8add014ddfe08c62b3ccb1b01dbc47495f5 (patch) | |
tree | 0bb5bb11cf185a575df520908078d310775f4a7c /continuedev/src/continuedev/server | |
parent | 98047dc32a5bfa525eff7089b2ac020ce761f9a9 (diff) | |
download | sncontinue-c6a3d8add014ddfe08c62b3ccb1b01dbc47495f5.tar.gz sncontinue-c6a3d8add014ddfe08c62b3ccb1b01dbc47495f5.tar.bz2 sncontinue-c6a3d8add014ddfe08c62b3ccb1b01dbc47495f5.zip |
feat: :sparkles: edit previous inputs
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 112 |
1 files changed, 66 insertions, 46 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 49d46be3..7497e777 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -1,19 +1,20 @@ import asyncio import json -from fastapi import Depends, Header, WebSocket, APIRouter -from starlette.websockets import WebSocketState, WebSocketDisconnect +import traceback from typing import Any, List, Optional, Type, TypeVar + +from fastapi import APIRouter, Depends, WebSocket from pydantic import BaseModel -import traceback +from starlette.websockets import WebSocketDisconnect, WebSocketState from uvicorn.main import Server -from .session_manager import session_manager, Session -from ..plugins.steps.core.core import DisplayErrorStep, MessageStep -from .gui_protocol import AbstractGUIProtocolServer -from ..libs.util.queue import AsyncSubscriptionQueue -from ..libs.util.telemetry import posthog_logger from ..libs.util.create_async_task import create_async_task from ..libs.util.logging import logger +from ..libs.util.queue import AsyncSubscriptionQueue +from ..libs.util.telemetry import posthog_logger +from ..plugins.steps.core.core import DisplayErrorStep +from .gui_protocol import AbstractGUIProtocolServer +from .session_manager import Session, session_manager router = APIRouter(prefix="/gui", tags=["gui"]) @@ -54,19 +55,19 @@ class GUIProtocolServer(AbstractGUIProtocolServer): async def _send_json(self, message_type: str, data: Any): if self.websocket.application_state == WebSocketState.DISCONNECTED: return - await self.websocket.send_json({ - "messageType": message_type, - "data": data - }) + await self.websocket.send_json({"messageType": message_type, "data": data}) async def _receive_json(self, message_type: str, timeout: int = 20) -> Any: try: - return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout) + return await asyncio.wait_for( + self.sub_queue.get(message_type), timeout=timeout + ) except asyncio.TimeoutError: - raise Exception( - "GUI Protocol _receive_json timed out after 20 seconds") + raise Exception("GUI Protocol _receive_json timed out after 20 seconds") - async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: + async def _send_and_receive_json( + self, data: Any, resp_model: Type[T], message_type: str + ) -> T: await self._send_json(message_type, data) resp = await self._receive_json(message_type) return resp_model.parse_obj(resp) @@ -101,78 +102,95 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.select_context_item(data["id"], data["query"]) elif message_type == "load_session": self.load_session(data.get("session_id", None)) + elif message_type == "edit_step_at_index": + self.edit_step_at_index(data.get("user_input", ""), data["index"]) def on_main_input(self, input: str): # Do something with user input create_async_task( - self.session.autopilot.accept_user_input(input), self.on_error) + self.session.autopilot.accept_user_input(input), self.on_error + ) def on_reverse_to_index(self, index: int): # Reverse the history to the given index - create_async_task( - self.session.autopilot.reverse_to_index(index), self.on_error) + create_async_task(self.session.autopilot.reverse_to_index(index), self.on_error) def on_step_user_input(self, input: str, index: int): create_async_task( - self.session.autopilot.give_user_input(input, index), self.on_error) + self.session.autopilot.give_user_input(input, index), self.on_error + ) def on_refinement_input(self, input: str, index: int): create_async_task( - self.session.autopilot.accept_refinement_input(input, index), self.on_error) + self.session.autopilot.accept_refinement_input(input, index), self.on_error + ) def on_retry_at_index(self, index: int): - create_async_task( - self.session.autopilot.retry_at_index(index), self.on_error) + create_async_task(self.session.autopilot.retry_at_index(index), self.on_error) def on_clear_history(self): - create_async_task( - self.session.autopilot.clear_history(), self.on_error) + create_async_task(self.session.autopilot.clear_history(), self.on_error) def on_delete_at_index(self, index: int): + create_async_task(self.session.autopilot.delete_at_index(index), self.on_error) + + def edit_step_at_index(self, user_input: str, index: int): create_async_task( - self.session.autopilot.delete_at_index(index), self.on_error) + self.session.autopilot.edit_step_at_index(user_input, index), + self.on_error, + ) def on_delete_context_with_ids(self, ids: List[str]): create_async_task( - self.session.autopilot.delete_context_with_ids(ids), self.on_error) + self.session.autopilot.delete_context_with_ids(ids), self.on_error + ) def on_toggle_adding_highlighted_code(self): create_async_task( - self.session.autopilot.toggle_adding_highlighted_code(), self.on_error) + self.session.autopilot.toggle_adding_highlighted_code(), self.on_error + ) posthog_logger.capture_event("toggle_adding_highlighted_code", {}) def on_set_editing_at_ids(self, ids: List[str]): - create_async_task( - self.session.autopilot.set_editing_at_ids(ids), self.on_error) + create_async_task(self.session.autopilot.set_editing_at_ids(ids), self.on_error) def on_show_logs_at_index(self, index: int): - name = f"continue_logs.txt" + name = "continue_logs.txt" logs = "\n\n############################################\n\n".join( - ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs) + [ + "This is a log of the exact prompt/completion pairs sent/received from the LLM during this step" + ] + + self.session.autopilot.continue_sdk.history.timeline[index].logs + ) create_async_task( - self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error) + self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error + ) posthog_logger.capture_event("show_logs_at_index", {}) def select_context_item(self, id: str, query: str): """Called when user selects an item from the dropdown""" create_async_task( - self.session.autopilot.select_context_item(id, query), self.on_error) + self.session.autopilot.select_context_item(id, query), self.on_error + ) def load_session(self, session_id: Optional[str] = None): async def load_and_tell_to_reconnect(): - new_session_id = await session_manager.load_session(self.session.session_id, session_id) - await self._send_json("reconnect_at_session", {"session_id": new_session_id}) + new_session_id = await session_manager.load_session( + self.session.session_id, session_id + ) + await self._send_json( + "reconnect_at_session", {"session_id": new_session_id} + ) - create_async_task( - load_and_tell_to_reconnect(), self.on_error) + create_async_task(load_and_tell_to_reconnect(), self.on_error) - posthog_logger.capture_event("load_session", { - "session_id": session_id - }) + posthog_logger.capture_event("load_session", {"session_id": session_id}) @router.websocket("/ws") -async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): +async def websocket_endpoint( + websocket: WebSocket, session: Session = Depends(websocket_session) +): try: logger.debug(f"Received websocket connection at url: {websocket.url}") await websocket.accept() @@ -197,14 +215,16 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we data = message["data"] protocol.handle_json(message_type, data) - except WebSocketDisconnect as e: + except WebSocketDisconnect: logger.debug("GUI websocket disconnected") except Exception as e: # Log, send to PostHog, and send to GUI logger.debug(f"ERROR in gui websocket: {e}") - err_msg = '\n'.join(traceback.format_exception(e)) - posthog_logger.capture_event("gui_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": err_msg}) + err_msg = "\n".join(traceback.format_exception(e)) + posthog_logger.capture_event( + "gui_error", + {"error_title": e.__str__() or e.__repr__(), "error_message": err_msg}, + ) await session.autopilot.ide.showMessage(err_msg) |