summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/server
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-08-17 22:19:55 -0700
committerNate Sesti <sestinj@gmail.com>2023-08-17 22:19:55 -0700
commitc6a3d8add014ddfe08c62b3ccb1b01dbc47495f5 (patch)
tree0bb5bb11cf185a575df520908078d310775f4a7c /continuedev/src/continuedev/server
parent98047dc32a5bfa525eff7089b2ac020ce761f9a9 (diff)
downloadsncontinue-c6a3d8add014ddfe08c62b3ccb1b01dbc47495f5.tar.gz
sncontinue-c6a3d8add014ddfe08c62b3ccb1b01dbc47495f5.tar.bz2
sncontinue-c6a3d8add014ddfe08c62b3ccb1b01dbc47495f5.zip
feat: :sparkles: edit previous inputs
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r--continuedev/src/continuedev/server/gui.py112
1 files changed, 66 insertions, 46 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 49d46be3..7497e777 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -1,19 +1,20 @@
import asyncio
import json
-from fastapi import Depends, Header, WebSocket, APIRouter
-from starlette.websockets import WebSocketState, WebSocketDisconnect
+import traceback
from typing import Any, List, Optional, Type, TypeVar
+
+from fastapi import APIRouter, Depends, WebSocket
from pydantic import BaseModel
-import traceback
+from starlette.websockets import WebSocketDisconnect, WebSocketState
from uvicorn.main import Server
-from .session_manager import session_manager, Session
-from ..plugins.steps.core.core import DisplayErrorStep, MessageStep
-from .gui_protocol import AbstractGUIProtocolServer
-from ..libs.util.queue import AsyncSubscriptionQueue
-from ..libs.util.telemetry import posthog_logger
from ..libs.util.create_async_task import create_async_task
from ..libs.util.logging import logger
+from ..libs.util.queue import AsyncSubscriptionQueue
+from ..libs.util.telemetry import posthog_logger
+from ..plugins.steps.core.core import DisplayErrorStep
+from .gui_protocol import AbstractGUIProtocolServer
+from .session_manager import Session, session_manager
router = APIRouter(prefix="/gui", tags=["gui"])
@@ -54,19 +55,19 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
async def _send_json(self, message_type: str, data: Any):
if self.websocket.application_state == WebSocketState.DISCONNECTED:
return
- await self.websocket.send_json({
- "messageType": message_type,
- "data": data
- })
+ await self.websocket.send_json({"messageType": message_type, "data": data})
async def _receive_json(self, message_type: str, timeout: int = 20) -> Any:
try:
- return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)
+ return await asyncio.wait_for(
+ self.sub_queue.get(message_type), timeout=timeout
+ )
except asyncio.TimeoutError:
- raise Exception(
- "GUI Protocol _receive_json timed out after 20 seconds")
+ raise Exception("GUI Protocol _receive_json timed out after 20 seconds")
- async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
+ async def _send_and_receive_json(
+ self, data: Any, resp_model: Type[T], message_type: str
+ ) -> T:
await self._send_json(message_type, data)
resp = await self._receive_json(message_type)
return resp_model.parse_obj(resp)
@@ -101,78 +102,95 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.select_context_item(data["id"], data["query"])
elif message_type == "load_session":
self.load_session(data.get("session_id", None))
+ elif message_type == "edit_step_at_index":
+ self.edit_step_at_index(data.get("user_input", ""), data["index"])
def on_main_input(self, input: str):
# Do something with user input
create_async_task(
- self.session.autopilot.accept_user_input(input), self.on_error)
+ self.session.autopilot.accept_user_input(input), self.on_error
+ )
def on_reverse_to_index(self, index: int):
# Reverse the history to the given index
- create_async_task(
- self.session.autopilot.reverse_to_index(index), self.on_error)
+ create_async_task(self.session.autopilot.reverse_to_index(index), self.on_error)
def on_step_user_input(self, input: str, index: int):
create_async_task(
- self.session.autopilot.give_user_input(input, index), self.on_error)
+ self.session.autopilot.give_user_input(input, index), self.on_error
+ )
def on_refinement_input(self, input: str, index: int):
create_async_task(
- self.session.autopilot.accept_refinement_input(input, index), self.on_error)
+ self.session.autopilot.accept_refinement_input(input, index), self.on_error
+ )
def on_retry_at_index(self, index: int):
- create_async_task(
- self.session.autopilot.retry_at_index(index), self.on_error)
+ create_async_task(self.session.autopilot.retry_at_index(index), self.on_error)
def on_clear_history(self):
- create_async_task(
- self.session.autopilot.clear_history(), self.on_error)
+ create_async_task(self.session.autopilot.clear_history(), self.on_error)
def on_delete_at_index(self, index: int):
+ create_async_task(self.session.autopilot.delete_at_index(index), self.on_error)
+
+ def edit_step_at_index(self, user_input: str, index: int):
create_async_task(
- self.session.autopilot.delete_at_index(index), self.on_error)
+ self.session.autopilot.edit_step_at_index(user_input, index),
+ self.on_error,
+ )
def on_delete_context_with_ids(self, ids: List[str]):
create_async_task(
- self.session.autopilot.delete_context_with_ids(ids), self.on_error)
+ self.session.autopilot.delete_context_with_ids(ids), self.on_error
+ )
def on_toggle_adding_highlighted_code(self):
create_async_task(
- self.session.autopilot.toggle_adding_highlighted_code(), self.on_error)
+ self.session.autopilot.toggle_adding_highlighted_code(), self.on_error
+ )
posthog_logger.capture_event("toggle_adding_highlighted_code", {})
def on_set_editing_at_ids(self, ids: List[str]):
- create_async_task(
- self.session.autopilot.set_editing_at_ids(ids), self.on_error)
+ create_async_task(self.session.autopilot.set_editing_at_ids(ids), self.on_error)
def on_show_logs_at_index(self, index: int):
- name = f"continue_logs.txt"
+ name = "continue_logs.txt"
logs = "\n\n############################################\n\n".join(
- ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs)
+ [
+ "This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"
+ ]
+ + self.session.autopilot.continue_sdk.history.timeline[index].logs
+ )
create_async_task(
- self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error)
+ self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error
+ )
posthog_logger.capture_event("show_logs_at_index", {})
def select_context_item(self, id: str, query: str):
"""Called when user selects an item from the dropdown"""
create_async_task(
- self.session.autopilot.select_context_item(id, query), self.on_error)
+ self.session.autopilot.select_context_item(id, query), self.on_error
+ )
def load_session(self, session_id: Optional[str] = None):
async def load_and_tell_to_reconnect():
- new_session_id = await session_manager.load_session(self.session.session_id, session_id)
- await self._send_json("reconnect_at_session", {"session_id": new_session_id})
+ new_session_id = await session_manager.load_session(
+ self.session.session_id, session_id
+ )
+ await self._send_json(
+ "reconnect_at_session", {"session_id": new_session_id}
+ )
- create_async_task(
- load_and_tell_to_reconnect(), self.on_error)
+ create_async_task(load_and_tell_to_reconnect(), self.on_error)
- posthog_logger.capture_event("load_session", {
- "session_id": session_id
- })
+ posthog_logger.capture_event("load_session", {"session_id": session_id})
@router.websocket("/ws")
-async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
+async def websocket_endpoint(
+ websocket: WebSocket, session: Session = Depends(websocket_session)
+):
try:
logger.debug(f"Received websocket connection at url: {websocket.url}")
await websocket.accept()
@@ -197,14 +215,16 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
data = message["data"]
protocol.handle_json(message_type, data)
- except WebSocketDisconnect as e:
+ except WebSocketDisconnect:
logger.debug("GUI websocket disconnected")
except Exception as e:
# Log, send to PostHog, and send to GUI
logger.debug(f"ERROR in gui websocket: {e}")
- err_msg = '\n'.join(traceback.format_exception(e))
- posthog_logger.capture_event("gui_error", {
- "error_title": e.__str__() or e.__repr__(), "error_message": err_msg})
+ err_msg = "\n".join(traceback.format_exception(e))
+ posthog_logger.capture_event(
+ "gui_error",
+ {"error_title": e.__str__() or e.__repr__(), "error_message": err_msg},
+ )
await session.autopilot.ide.showMessage(err_msg)