# This is a separate server from server/main.py
import json
import os
from typing import Any, Coroutine, List, Type, TypeVar, Union
import uuid
from fastapi import WebSocket, APIRouter
from starlette.websockets import WebSocketState, WebSocketDisconnect
from uvicorn.main import Server
from pydantic import BaseModel
import traceback
import asyncio

from ..plugins.steps.core.core import DisplayErrorStep, MessageStep
from .meilisearch_server import start_meilisearch
from ..libs.util.telemetry import posthog_logger
from ..libs.util.queue import AsyncSubscriptionQueue
from ..models.filesystem import (
    FileSystem, RangeInFile, EditDiff, RangeInFileWithContents, RealFileSystem)
from ..models.filesystem_edit import (
    AddDirectory, AddFile, DeleteDirectory, DeleteFile, FileSystemEdit, FileEdit,
    FileEditWithFullContents, RenameDirectory, RenameFile, SequentialFileSystemEdit)
from .gui import session_manager
from .ide_protocol import AbstractIdeProtocolServer
from ..libs.util.create_async_task import create_async_task
from .session_manager import SessionManager
from ..libs.util.logging import logger

import nest_asyncio
nest_asyncio.apply()

router = APIRouter(prefix="/ide", tags=["ide"])


# Graceful shutdown by closing websockets
original_handler = Server.handle_exit


class AppStatus:
    """Process-wide shutdown flag, set when uvicorn's exit handler fires.

    The websocket loop in ``websocket_endpoint`` polls ``should_exit`` so it
    stops reading once the server is shutting down.
    """
    should_exit = False

    @staticmethod
    def handle_exit(*args, **kwargs):
        """Record the shutdown, then delegate to uvicorn's original handler."""
        AppStatus.should_exit = True
        logger.debug("Shutting down")
        original_handler(*args, **kwargs)


Server.handle_exit = AppStatus.handle_exit


# TYPES #

class FileEditsUpdate(BaseModel):
    fileEdits: List[FileEditWithFullContents]


class OpenFilesResponse(BaseModel):
    openFiles: List[str]


class VisibleFilesResponse(BaseModel):
    visibleFiles: List[str]


class HighlightedCodeResponse(BaseModel):
    highlightedCode: List[RangeInFile]


class ShowSuggestionRequest(BaseModel):
    suggestion: FileEdit


class ShowSuggestionResponse(BaseModel):
    suggestion: FileEdit
    accepted: bool


class ReadFileResponse(BaseModel):
    contents: str


class EditFileResponse(BaseModel):
    fileEdit: FileEditWithFullContents


class WorkspaceDirectoryResponse(BaseModel):
    workspaceDirectory: str


class GetUserSecretResponse(BaseModel):
    value: str


class RunCommandResponse(BaseModel):
    output: str = ""


class UniqueIdResponse(BaseModel):
    uniqueId: str


T = TypeVar("T", bound=BaseModel)


class cached_property_no_none:
    """Descriptor like ``functools.cached_property`` that never caches ``None``.

    A non-None result is stored on the instance under the function's name.
    Because this descriptor only defines ``__get__``, that instance attribute
    shadows it on later lookups. A ``None`` result is returned but not stored,
    so the function runs again on the next access.
    """

    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner):
        if instance is None:
            return self
        value = self.func(instance)
        if value is not None:
            setattr(instance, self.func.__name__, value)
        return value

    def __repr__(self):
        return f"<cached_property_no_none {self.func}>"


class IdeProtocolServer(AbstractIdeProtocolServer):
    """Server side of the IDE websocket protocol for a single IDE connection.

    Requests to the IDE are sent with ``_send_json``. Replies come back through
    ``handle_json``, which posts them to ``sub_queue`` keyed by message type,
    where ``_receive_json`` picks them up. Events pushed by the IDE (edits,
    user input, command output, ...) are forwarded to the autopilot of the
    session identified by ``session_id``.
    """

    websocket: WebSocket
    session_manager: SessionManager
    sub_queue: AsyncSubscriptionQueue
    session_id: Union[str, None] = None
    # Filled from the IDE's replies in initialize() / handle_json().
    workspace_directory: Union[str, None] = None
    unique_id: Union[str, None] = None

    def __init__(self, session_manager: SessionManager, websocket: WebSocket):
        self.websocket = websocket
        self.session_manager = session_manager
        # Created per instance so concurrent IDE connections do not share one
        # reply-queue object. (Assumes AsyncSubscriptionQueue keeps its state
        # per instance; verify in libs/util/queue.py.)
        self.sub_queue = AsyncSubscriptionQueue()

    async def initialize(self, session_id: str) -> List[str]:
        """Handshake with the IDE: request the workspace directory and unique id.

        The method reads from the websocket until both values have arrived.
        Well-formed messages of any other type received meanwhile are returned
        (as raw strings) so the caller can process them afterwards. Messages
        missing "messageType" or "data" are dropped.
        """
        self.session_id = session_id
        await self._send_json("workspaceDirectory", {})
        await self._send_json("uniqueId", {})
        other_msgs = []
        while True:
            msg_string = await self.websocket.receive_text()
            message = json.loads(msg_string)
            if "messageType" not in message or "data" not in message:
                continue  # <-- hey that's the name of this repo!
            message_type = message["messageType"]
            data = message["data"]
            logger.debug(f"Received message while initializing {message_type}")
            if message_type == "workspaceDirectory":
                self.workspace_directory = data["workspaceDirectory"]
            elif message_type == "uniqueId":
                self.unique_id = data["uniqueId"]
            else:
                other_msgs.append(msg_string)

            if self.workspace_directory is not None and self.unique_id is not None:
                break
        return other_msgs

    async def _send_json(self, message_type: str, data: Any):
        """Send ``{"messageType", "data"}`` to the IDE.

        If the websocket is already disconnected, log and skip the send.
        """
        if self.websocket.application_state == WebSocketState.DISCONNECTED:
            logger.debug(
                f"Tried to send message, but websocket is disconnected: {message_type}")
            return
        logger.debug(f"Sending IDE message: {message_type}")
        await self.websocket.send_json({
            "messageType": message_type,
            "data": data
        })

    async def _receive_json(self, message_type: str, timeout: int = 20) -> Any:
        """Wait up to ``timeout`` seconds for the next reply of ``message_type``.

        Raises:
            Exception: if no reply arrives in time.
        """
        try:
            return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)
        except asyncio.TimeoutError:
            raise Exception(
                f"IDE Protocol _receive_json timed out after {timeout} seconds: {message_type}")

    async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
        """Send a request and parse the matching reply into ``resp_model``."""
        await self._send_json(message_type, data)
        resp = await self._receive_json(message_type)
        return resp_model.parse_obj(resp)

    async def handle_json(self, message_type: str, data: Any):
        """Dispatch one message received from the IDE.

        Replies to requests made through ``_receive_json`` are posted to
        ``sub_queue``. All other known types trigger the matching handler.

        Raises:
            ValueError: for an unrecognised message type.
        """
        if message_type == "getSessionId":
            await self.getSessionId()
        elif message_type == "setFileOpen":
            await self.setFileOpen(data["filepath"], data["open"])
        elif message_type == "setSuggestionsLocked":
            await self.setSuggestionsLocked(data["filepath"], data["locked"])
        elif message_type == "fileEdits":
            fileEdits = [FileEditWithFullContents.parse_obj(d) for d in data["fileEdits"]]
            self.onFileEdits(fileEdits)
        elif message_type == "highlightedCodePush":
            self.onHighlightedCodeUpdate(
                [RangeInFileWithContents(**rif) for rif in data["highlightedCode"]])
        elif message_type == "commandOutput":
            output = data["output"]
            self.onCommandOutput(output)
        elif message_type == "acceptRejectSuggestion":
            self.onAcceptRejectSuggestion(data["accepted"])
        elif message_type == "acceptRejectDiff":
            self.onAcceptRejectDiff(data["accepted"])
        elif message_type == "mainUserInput":
            self.onMainUserInput(data["input"])
        elif message_type == "deleteAtIndex":
            self.onDeleteAtIndex(data["index"])
        elif message_type in ["highlightedCode", "openFiles", "visibleFiles", "readFile",
                              "editFile", "getUserSecret", "runCommand", "showSuggestion"]:
            # Replies to outstanding requests; the waiting _receive_json takes them.
            self.sub_queue.post(message_type, data)
        elif message_type == "workspaceDirectory":
            self.workspace_directory = data["workspaceDirectory"]
        elif message_type == "uniqueId":
            self.unique_id = data["uniqueId"]
        else:
            raise ValueError("Unknown message type", message_type)

    async def showSuggestion(self, file_edit: FileEdit):
        await self._send_json("showSuggestion", {
            "edit": file_edit.dict()
        })

    async def showDiff(self, filepath: str, replacement: str, step_index: int):
        await self._send_json("showDiff", {
            "filepath": filepath,
            "replacement": replacement,
            "step_index": step_index
        })

    async def setFileOpen(self, filepath: str, open: bool = True):
        # Autopilot needs access to this.
        await self._send_json("setFileOpen", {
            "filepath": filepath,
            "open": open
        })

    async def showMessage(self, message: str):
        await self._send_json("showMessage", {
            "message": message
        })

    async def showVirtualFile(self, name: str, contents: str):
        await self._send_json("showVirtualFile", {
            "name": name,
            "contents": contents
        })

    async def setSuggestionsLocked(self, filepath: str, locked: bool = True):
        # Lock suggestions in the file so they don't ruin the offset before others are inserted
        await self._send_json("setSuggestionsLocked", {
            "filepath": filepath,
            "locked": locked
        })

    async def getSessionId(self):
        """Obtain a session from the session manager and report its id to the IDE.

        The current ``session_id`` is passed along, and the request is given
        at most 5 seconds.
        """
        new_session = await asyncio.wait_for(self.session_manager.new_session(
            self, self.session_id), timeout=5)
        session_id = new_session.session_id
        logger.debug(f"Sending session id: {session_id}")
        await self._send_json("getSessionId", {
            "sessionId": session_id
        })

    async def highlightCode(self, range_in_file: RangeInFile, color: str = "#00ff0022"):
        await self._send_json("highlightCode", {
            "rangeInFile": range_in_file.dict(),
            "color": color
        })

    async def runCommand(self, command: str) -> str:
        """Run ``command`` in the IDE and return its output."""
        return (await self._send_and_receive_json(
            {"command": command}, RunCommandResponse, "runCommand")).output

    async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool:
        """Show each suggestion in the IDE and wait for one reply per suggestion.

        Returns:
            True if any suggestion was accepted (False for an empty list).
        """
        ids = [str(uuid.uuid4()) for _ in suggestions]
        for suggestion, suggestion_id in zip(suggestions, ids):
            # _send_json is a coroutine: without await nothing is sent.
            await self._send_json("showSuggestion", {
                "suggestion": suggestion.dict(),
                "suggestionId": suggestion_id
            })
        responses = await asyncio.gather(*[
            self._receive_json("showSuggestion") for _ in suggestions
        ])
        # WORKING ON THIS FLOW HERE. Fine now to just await for response, instead of doing something fancy with a "waiting" state on the autopilot.
        # Just need connect the suggestionId to the IDE (and the gui)
        return any(ShowSuggestionResponse.parse_obj(r).accepted for r in responses)

    def on_error(self, e: Exception) -> Coroutine:
        """Build (not await) a coroutine that shows ``e`` and its traceback in the IDE."""
        # 3-argument form works on every Python 3 version; the single-argument
        # form only exists from 3.10. Lines already end in "\n".
        err_msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
        e_title = e.__str__() or e.__repr__()
        return self.showMessage(f"Error in Continue server: {e_title}\n {err_msg}")

    def onAcceptRejectSuggestion(self, accepted: bool):
        posthog_logger.capture_event("accept_reject_suggestion", {
            "accepted": accepted
        })

    def onAcceptRejectDiff(self, accepted: bool):
        posthog_logger.capture_event("accept_reject_diff", {
            "accepted": accepted
        })

    def onFileSystemUpdate(self, update: FileSystemEdit):
        # Access to Autopilot (so SessionManager)
        pass

    def onCloseGUI(self, session_id: str):
        # Access to SessionManager
        pass

    def onOpenGUIRequest(self):
        pass

    def __get_autopilot(self):
        """Return this session's autopilot, or None if missing or not started."""
        if self.session_id not in self.session_manager.sessions:
            return None
        autopilot = self.session_manager.sessions[self.session_id].autopilot
        return autopilot if autopilot.started else None

    def onFileEdits(self, edits: List[FileEditWithFullContents]):
        # File edits are not forwarded to the autopilot yet.
        if autopilot := self.__get_autopilot():
            pass

    def onDeleteAtIndex(self, index: int):
        if autopilot := self.__get_autopilot():
            create_async_task(autopilot.delete_at_index(index), self.on_error)

    def onCommandOutput(self, output: str):
        if autopilot := self.__get_autopilot():
            create_async_task(
                autopilot.handle_command_output(output), self.on_error)

    def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]):
        if autopilot := self.__get_autopilot():
            create_async_task(autopilot.handle_highlighted_code(
                range_in_files), self.on_error)

    def onMainUserInput(self, input: str):
        if autopilot := self.__get_autopilot():
            create_async_task(
                autopilot.accept_user_input(input), self.on_error)

    # Request information. Session doesn't matter.
    async def getOpenFiles(self) -> List[str]:
        resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles")
        return resp.openFiles

    async def getVisibleFiles(self) -> List[str]:
        resp = await self._send_and_receive_json({}, VisibleFilesResponse, "visibleFiles")
        return resp.visibleFiles

    async def getHighlightedCode(self) -> List[RangeInFile]:
        resp = await self._send_and_receive_json({}, HighlightedCodeResponse, "highlightedCode")
        return resp.highlightedCode

    async def readFile(self, filepath: str) -> str:
        """Read a file"""
        resp = await self._send_and_receive_json({
            "filepath": filepath
        }, ReadFileResponse, "readFile")
        return resp.contents

    async def getUserSecret(self, key: str) -> str:
        """Get a user secret; returns "" (and logs) if the request fails."""
        try:
            resp = await self._send_and_receive_json({
                "key": key
            }, GetUserSecretResponse, "getUserSecret")
            return resp.value
        except Exception as e:
            logger.debug(f"Error getting user secret: {e}")
            return ""

    async def saveFile(self, filepath: str):
        """Save a file"""
        await self._send_json("saveFile", {
            "filepath": filepath
        })

    async def readRangeInFile(self, range_in_file: RangeInFile) -> str:
        """Read a range in a file"""
        full_contents = await self.readFile(range_in_file.filepath)
        return FileSystem.read_range_in_str(full_contents, range_in_file.range)

    async def editFile(self, edit: FileEdit) -> FileEditWithFullContents:
        """Edit a file"""
        resp = await self._send_and_receive_json({
            "edit": edit.dict()
        }, EditFileResponse, "editFile")
        return resp.fileEdit

    async def applyFileSystemEdit(self, edit: FileSystemEdit) -> EditDiff:
        """Apply a filesystem edit and return it paired with its inverse.

        File edits go through the IDE; other operations act on the real
        filesystem.

        Raises:
            TypeError: if ``edit`` is not a FileSystemEdit.
        """
        backward = None
        fs = RealFileSystem()
        if isinstance(edit, FileEdit):
            file_edit = await self.editFile(edit)
            _, diff = FileSystem.apply_edit_to_str(
                file_edit.fileContents, file_edit.fileEdit)
            backward = diff.backward
        elif isinstance(edit, AddFile):
            fs.write(edit.filepath, edit.content)
            backward = DeleteFile(filepath=edit.filepath)
        elif isinstance(edit, DeleteFile):
            contents = await self.readFile(edit.filepath)
            backward = AddFile(filepath=edit.filepath, content=contents)
            fs.delete_file(edit.filepath)
        elif isinstance(edit, RenameFile):
            fs.rename_file(edit.filepath, edit.new_filepath)
            backward = RenameFile(filepath=edit.new_filepath,
                                  new_filepath=edit.filepath)
        elif isinstance(edit, AddDirectory):
            fs.add_directory(edit.path)
            backward = DeleteDirectory(path=edit.path)
        elif isinstance(edit, DeleteDirectory):
            # This isn't atomic! A failure part-way leaves the tree partly deleted.
            backward_edits = []
            # Bottom-up walk: each subdirectory is yielded (and removed) before
            # its parent, so every directory is empty by the time it is rmdir'd.
            for root, _dirs, files in os.walk(edit.path, topdown=False):
                for f in files:
                    file_diff = await self.applyFileSystemEdit(
                        DeleteFile(filepath=os.path.join(root, f)))
                    backward_edits.append(file_diff.backward)
                os.rmdir(root)
                backward_edits.append(AddDirectory(path=root))
            # Undo must recreate parents before their contents: reverse order.
            backward_edits.reverse()
            backward = SequentialFileSystemEdit(edits=backward_edits)
        elif isinstance(edit, RenameDirectory):
            fs.rename_directory(edit.path, edit.new_path)
            backward = RenameDirectory(path=edit.new_path, new_path=edit.path)
        elif isinstance(edit, FileSystemEdit):
            diffs = []
            # Distinct loop name so `edit` still refers to the composite edit
            # when it is returned as `forward` below.
            for sub_edit in edit.next_edit():
                diffs.append(await self.applyFileSystemEdit(sub_edit))
            backward = EditDiff.from_sequence(diffs=diffs).backward
        else:
            raise TypeError("Unknown FileSystemEdit type: " + str(type(edit)))

        return EditDiff(
            forward=edit,
            backward=backward
        )


@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
    """Serve one IDE websocket connection until it closes or the app exits."""
    # Bound before `try` so the except/finally clauses can test it even when
    # the failure happens before the server object is created.
    ideProtocolServer: Union[IdeProtocolServer, None] = None
    try:
        # Accept the websocket connection
        await websocket.accept()
        logger.debug(f"Accepted websocket connection from {websocket.client}")
        await websocket.send_json({"messageType": "connected", "data": {}})

        # Start meilisearch (best effort: failures are only logged)
        try:
            await start_meilisearch()
        except Exception as e:
            logger.debug("Failed to start MeiliSearch")
            logger.debug(e)

        # Message handler
        def handle_msg(msg):
            message = json.loads(msg)

            if "messageType" not in message or "data" not in message:
                return
            message_type = message["messageType"]
            data = message["data"]

            logger.debug(f"Received IDE message: {message_type}")
            create_async_task(
                ideProtocolServer.handle_json(message_type, data), ideProtocolServer.on_error)

        # Initialize the IDE Protocol Server
        ideProtocolServer = IdeProtocolServer(session_manager, websocket)
        if session_id is not None:
            session_manager.registered_ides[session_id] = ideProtocolServer
        other_msgs = await ideProtocolServer.initialize(session_id)
        posthog_logger.capture_event("session_started", {
            "session_id": ideProtocolServer.session_id})

        for other_msg in other_msgs:
            handle_msg(other_msg)

        # Handle messages
        while not AppStatus.should_exit:
            message = await websocket.receive_text()
            handle_msg(message)
    except WebSocketDisconnect:
        logger.debug("IDE websocket disconnected")
    except Exception as e:
        logger.debug(f"Error in ide websocket: {e}")
        err_msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
        posthog_logger.capture_event("gui_error", {
            "error_title": e.__str__() or e.__repr__(), "error_message": err_msg})

        if session_id is not None and session_id in session_manager.sessions:
            await session_manager.sessions[session_id].autopilot.continue_sdk.run_step(DisplayErrorStep(e=e))
        elif ideProtocolServer is not None:
            await ideProtocolServer.showMessage(f"Error in Continue server: {err_msg}")

        raise
    finally:
        logger.debug("Closing ide websocket")
        if websocket.client_state != WebSocketState.DISCONNECTED:
            await websocket.close()

        if ideProtocolServer is not None:
            posthog_logger.capture_event("session_ended", {
                "session_id": ideProtocolServer.session_id})
            if ideProtocolServer.session_id in session_manager.registered_ides:
                session_manager.registered_ides.pop(ideProtocolServer.session_id)