summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/server/gui.py
blob: 5589284a1692f9cf7850e237efdd50fa0da8a3b5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
import asyncio
import json
import traceback
from typing import Any, List, Optional, Type, TypeVar

from fastapi import APIRouter, Depends, WebSocket
from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect, WebSocketState
from uvicorn.main import Server

from ..core.main import ContextItem
from ..libs.util.create_async_task import create_async_task
from ..libs.util.logging import logger
from ..libs.util.queue import AsyncSubscriptionQueue
from ..libs.util.telemetry import posthog_logger
from ..plugins.steps.core.core import DisplayErrorStep
from .gui_protocol import AbstractGUIProtocolServer
from .session_manager import Session, session_manager

# All GUI routes are mounted on this router under the "/gui" prefix.
router = APIRouter(tags=["gui"], prefix="/gui")

# Graceful shutdown by closing websockets: remember uvicorn's own exit handler
# so a replacement handler can set a shutdown flag and then delegate to it.
original_handler = getattr(Server, "handle_exit")


class AppStatus:
    """Process-wide record of whether the server has begun shutting down."""

    # Flipped to True by handle_exit; the GUI websocket loop polls this flag.
    should_exit = False

    @staticmethod
    def handle_exit(*args, **kwargs):
        """Record the shutdown, log it, then forward every argument unchanged
        to uvicorn's original exit handler."""
        setattr(AppStatus, "should_exit", True)
        logger.debug("Shutting down")
        delegate = original_handler
        delegate(*args, **kwargs)


# Swap in the flag-setting wrapper for uvicorn's exit handler so open GUI
# websockets can notice that shutdown has started.
setattr(Server, "handle_exit", AppStatus.handle_exit)


async def websocket_session(session_id: str) -> Session:
    """Dependency (used via ``Depends`` by the GUI websocket endpoint) that
    fetches the Session identified by ``session_id`` from the session manager."""
    resolved = await session_manager.get_session(session_id)
    return resolved


# Type variable bounded to pydantic models: lets a helper that parses a reply
# into ``resp_model`` be typed as returning that same model class.
T = TypeVar("T", bound=BaseModel)

# TODO: consider abstracting the websocket plumbing into a separate class.


class GUIProtocolServer(AbstractGUIProtocolServer):
    """Server side of the GUI websocket protocol for a single session.

    ``handle_json`` routes each decoded GUI message to a handler. Most handlers
    schedule a coroutine on ``self.session.autopilot`` via ``create_async_task``
    and pass ``on_error`` as the error callback.
    """

    # Assigned by the websocket endpoint right after construction.
    websocket: WebSocket
    session: Session
    # Received messages keyed by message type, read by _receive_json.
    # Created per instance in __init__: a class-level default instance would be
    # one queue shared by the servers of every session.
    sub_queue: AsyncSubscriptionQueue

    def __init__(self, session: Session):
        self.session = session
        self.sub_queue = AsyncSubscriptionQueue()

    async def _send_json(self, message_type: str, data: Any):
        """Send ``{"messageType": message_type, "data": data}`` to the GUI.

        Does nothing if the application side of the websocket is already
        DISCONNECTED.
        """
        if self.websocket.application_state == WebSocketState.DISCONNECTED:
            return
        await self.websocket.send_json({"messageType": message_type, "data": data})

    async def _receive_json(self, message_type: str, timeout: int = 20) -> Any:
        """Wait for the next queued message of ``message_type``.

        Raises:
            Exception: if nothing arrives within ``timeout`` seconds.
        """
        try:
            return await asyncio.wait_for(
                self.sub_queue.get(message_type), timeout=timeout
            )
        except asyncio.TimeoutError:
            # Report the timeout actually used rather than a hard-coded value.
            raise Exception(
                f"GUI Protocol _receive_json timed out after {timeout} seconds"
            ) from None

    async def _send_and_receive_json(
        self, data: Any, resp_model: Type[T], message_type: str
    ) -> T:
        """Send ``data`` as ``message_type`` and parse the next reply of the
        same type into ``resp_model``."""
        await self._send_json(message_type, data)
        resp = await self._receive_json(message_type)
        return resp_model.parse_obj(resp)

    def on_error(self, e: Exception):
        """Error callback for scheduled tasks: run a DisplayErrorStep for ``e``.

        Returns whatever ``run_step`` returns (presumably an awaitable that
        ``create_async_task`` handles — confirm).
        """
        return self.session.autopilot.continue_sdk.run_step(DisplayErrorStep(e=e))

    def handle_json(self, message_type: str, data: Any):
        """Dispatch one decoded GUI message to its handler.

        ``data`` is expected to be a dict holding the keys each message type
        reads (``input``, ``index``, ``ids``, ``id``, ...). A missing required
        key raises KeyError. Unrecognized message types are ignored.
        """
        if message_type == "main_input":
            self.on_main_input(data["input"])
        elif message_type == "step_user_input":
            self.on_step_user_input(data["input"], data["index"])
        elif message_type == "refinement_input":
            self.on_refinement_input(data["input"], data["index"])
        elif message_type == "reverse_to_index":
            self.on_reverse_to_index(data["index"])
        elif message_type == "retry_at_index":
            self.on_retry_at_index(data["index"])
        elif message_type == "clear_history":
            self.on_clear_history()
        elif message_type == "delete_at_index":
            self.on_delete_at_index(data["index"])
        elif message_type == "delete_context_with_ids":
            self.on_delete_context_with_ids(data["ids"])
        elif message_type == "toggle_adding_highlighted_code":
            self.on_toggle_adding_highlighted_code()
        elif message_type == "set_editing_at_ids":
            self.on_set_editing_at_ids(data["ids"])
        elif message_type == "show_logs_at_index":
            self.on_show_logs_at_index(data["index"])
        elif message_type == "select_context_item":
            self.select_context_item(data["id"], data["query"])
        elif message_type == "load_session":
            self.load_session(data.get("session_id", None))
        elif message_type == "edit_step_at_index":
            self.edit_step_at_index(data.get("user_input", ""), data["index"])
        elif message_type == "save_context_group":
            self.save_context_group(
                data["title"], [ContextItem(**item) for item in data["context_items"]]
            )
        elif message_type == "select_context_group":
            self.select_context_group(data["id"])
        elif message_type == "delete_context_group":
            self.delete_context_group(data["id"])

    def on_main_input(self, input: str):
        """Schedule the autopilot to accept new top-level user input."""
        create_async_task(
            self.session.autopilot.accept_user_input(input), self.on_error
        )

    def on_reverse_to_index(self, index: int):
        """Schedule reversing the history to ``index``."""
        create_async_task(self.session.autopilot.reverse_to_index(index), self.on_error)

    def on_step_user_input(self, input: str, index: int):
        """Schedule passing ``input`` to the step at ``index``."""
        create_async_task(
            self.session.autopilot.give_user_input(input, index), self.on_error
        )

    def on_refinement_input(self, input: str, index: int):
        """Schedule a refinement of the step at ``index`` using ``input``."""
        create_async_task(
            self.session.autopilot.accept_refinement_input(input, index), self.on_error
        )

    def on_retry_at_index(self, index: int):
        """Schedule a retry of the step at ``index``."""
        create_async_task(self.session.autopilot.retry_at_index(index), self.on_error)

    def on_clear_history(self):
        """Schedule clearing the session history."""
        create_async_task(self.session.autopilot.clear_history(), self.on_error)

    def on_delete_at_index(self, index: int):
        """Schedule deleting the step at ``index``."""
        create_async_task(self.session.autopilot.delete_at_index(index), self.on_error)

    def edit_step_at_index(self, user_input: str, index: int):
        """Schedule editing the step at ``index`` with ``user_input``."""
        create_async_task(
            self.session.autopilot.edit_step_at_index(user_input, index),
            self.on_error,
        )

    def on_delete_context_with_ids(self, ids: List[str]):
        """Schedule removal of the context items whose ids are in ``ids``."""
        create_async_task(
            self.session.autopilot.delete_context_with_ids(ids), self.on_error
        )

    def on_toggle_adding_highlighted_code(self):
        """Schedule toggling of highlighted-code capture and record a telemetry event."""
        create_async_task(
            self.session.autopilot.toggle_adding_highlighted_code(), self.on_error
        )
        posthog_logger.capture_event("toggle_adding_highlighted_code", {})

    def on_set_editing_at_ids(self, ids: List[str]):
        """Schedule marking the context items in ``ids`` as being edited."""
        create_async_task(self.session.autopilot.set_editing_at_ids(ids), self.on_error)

    def on_show_logs_at_index(self, index: int):
        """Open the LLM prompt/completion logs of timeline step ``index`` in the
        IDE as a virtual file named ``continue_logs.txt``.

        The logs are read synchronously, so an out-of-range ``index`` raises
        IndexError here rather than being routed to ``on_error``.
        """
        name = "continue_logs.txt"
        logs = "\n\n############################################\n\n".join(
            [
                "This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"
            ]
            + self.session.autopilot.continue_sdk.history.timeline[index].logs
        )
        create_async_task(
            self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error
        )
        posthog_logger.capture_event("show_logs_at_index", {})

    def select_context_item(self, id: str, query: str):
        """Called when user selects an item from the dropdown"""
        create_async_task(
            self.session.autopilot.select_context_item(id, query), self.on_error
        )

    def load_session(self, session_id: Optional[str] = None):
        """Ask the session manager to load ``session_id``, then tell the GUI to
        reconnect at the session id it returns.

        ``session_id`` may be None; what the session manager does in that case
        is not visible here (presumably starts a fresh session — confirm).
        """

        async def load_and_tell_to_reconnect():
            new_session_id = await session_manager.load_session(
                self.session.session_id, session_id
            )
            await self._send_json(
                "reconnect_at_session", {"session_id": new_session_id}
            )

        create_async_task(load_and_tell_to_reconnect(), self.on_error)

        posthog_logger.capture_event("load_session", {"session_id": session_id})

    def save_context_group(self, title: str, context_items: List[ContextItem]):
        """Schedule saving ``context_items`` as a context group named ``title``."""
        create_async_task(
            self.session.autopilot.save_context_group(title, context_items),
            self.on_error,
        )

    def select_context_group(self, id: str):
        """Schedule selecting the context group ``id``."""
        create_async_task(
            self.session.autopilot.select_context_group(id), self.on_error
        )

    def delete_context_group(self, id: str):
        """Schedule deleting the context group ``id``."""
        create_async_task(
            self.session.autopilot.delete_context_group(id), self.on_error
        )


@router.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket, session: Session = Depends(websocket_session)
):
    """Serve the GUI websocket for one session.

    Accepts the connection and registers it with the session manager. Pushes
    current autopilot state to subscribers. Then dispatches each incoming
    ``{"messageType": ..., "data": ...}`` JSON message to a GUIProtocolServer.
    This continues until the client disconnects or ``AppStatus.should_exit`` is
    set.

    Messages that are not valid JSON objects with both keys are skipped. On any
    other error the traceback is logged, sent to telemetry, shown in the IDE,
    and re-raised. In every case the socket is closed (if still open) and the
    session is persisted and removed.
    """
    try:
        logger.debug(f"Received websocket connection at url: {websocket.url}")
        await websocket.accept()

        logger.debug("Session started")
        session_manager.register_websocket(session.session_id, websocket)
        protocol = GUIProtocolServer(session)
        protocol.websocket = websocket

        # Update any history that may have happened before connection
        await protocol.session.autopilot.update_subscribers()

        while not AppStatus.should_exit:
            message = await websocket.receive_text()
            logger.debug(f"Received GUI message {message}")
            if isinstance(message, str):
                try:
                    message = json.loads(message)
                except json.JSONDecodeError:
                    # One malformed message should not tear down the session.
                    logger.debug("Ignoring GUI message that is not valid JSON")
                    continue

            # Skip anything that isn't a {"messageType": ..., "data": ...} object;
            # the isinstance check keeps `in` from raising on e.g. a bare number.
            if (
                not isinstance(message, dict)
                or "messageType" not in message
                or "data" not in message
            ):
                continue

            protocol.handle_json(message["messageType"], message["data"])
    except WebSocketDisconnect:
        logger.debug("GUI websocket disconnected")
    except Exception as e:
        # Log, send to PostHog, and send to GUI
        logger.debug(f"ERROR in gui websocket: {e}")
        # format_exception's lines already end in "\n", so join with "".
        # The 3-argument form works on every Python 3 version.
        err_msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
        posthog_logger.capture_event(
            "gui_error",
            {"error_title": e.__str__() or e.__repr__(), "error_message": err_msg},
        )

        try:
            await session.autopilot.ide.showMessage(err_msg)
        except Exception as show_err:
            # Failing to display the error must not mask the original one.
            logger.debug(f"Failed to show gui error in IDE: {show_err}")

        raise
    finally:
        logger.debug("Closing gui websocket")
        if websocket.client_state != WebSocketState.DISCONNECTED:
            try:
                await websocket.close()
            except Exception as close_err:
                # Best-effort close; the session must still be persisted below.
                logger.debug(f"Error while closing gui websocket: {close_err}")

        await session_manager.persist_session(session.session_id)
        await session_manager.remove_session(session.session_id)