diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-05-23 23:45:12 -0400 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-05-23 23:45:12 -0400 |
commit | f53768612b1e2268697b5444e502032ef9f3fb3c (patch) | |
tree | 4ed49b73e6bd3c2f8fceffa9643973033f87af95 /continuedev/src/continuedev/server | |
download | sncontinue-f53768612b1e2268697b5444e502032ef9f3fb3c.tar.gz sncontinue-f53768612b1e2268697b5444e502032ef9f3fb3c.tar.bz2 sncontinue-f53768612b1e2268697b5444e502032ef9f3fb3c.zip |
copying from old repo
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r-- | continuedev/src/continuedev/server/ide.py | 302 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide_protocol.py | 80 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/main.py | 39 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/notebook.py | 198 |
4 files changed, 619 insertions, 0 deletions
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py new file mode 100644 index 00000000..dd1dc463 --- /dev/null +++ b/continuedev/src/continuedev/server/ide.py @@ -0,0 +1,302 @@ +# This is a separate server from server/main.py +import asyncio +import os +from typing import Any, Dict, List, Type, TypeVar, Union +import uuid +from fastapi import WebSocket, Body, APIRouter +from uvicorn.main import Server + +from ..libs.util.queue import AsyncSubscriptionQueue +from ..models.filesystem import FileSystem, RangeInFile, EditDiff, RealFileSystem +from ..models.main import Traceback +from ..models.filesystem_edit import AddDirectory, AddFile, DeleteDirectory, DeleteFile, FileSystemEdit, FileEdit, FileEditWithFullContents, RenameDirectory, RenameFile, SequentialFileSystemEdit +from pydantic import BaseModel +from .notebook import SessionManager, session_manager +from .ide_protocol import AbstractIdeProtocolServer + + +router = APIRouter(prefix="/ide", tags=["ide"]) + + +# Graceful shutdown by closing websockets +original_handler = Server.handle_exit + + +class AppStatus: + should_exit = False + + @staticmethod + def handle_exit(*args, **kwargs): + AppStatus.should_exit = True + print("Shutting down") + original_handler(*args, **kwargs) + + +Server.handle_exit = AppStatus.handle_exit + + +# TYPES # + + +class FileEditsUpdate(BaseModel): + messageType: str = "fileEdits" + fileEdits: List[FileEditWithFullContents] + + +class OpenFilesResponse(BaseModel): + messageType: str = "openFiles" + openFiles: List[str] + + +class HighlightedCodeResponse(BaseModel): + messageType: str = "highlightedCode" + highlightedCode: List[RangeInFile] + + +class ShowSuggestionRequest(BaseModel): + messageType: str = "showSuggestion" + suggestion: FileEdit + + +class ShowSuggestionResponse(BaseModel): + messageType: str = "showSuggestion" + suggestion: FileEdit + accepted: bool + + +class ReadFileResponse(BaseModel): + messageType: str = "readFile" + contents: str + + +class EditFileResponse(BaseModel): + messageType: str = "editFile" + fileEdit: FileEditWithFullContents + + +class WorkspaceDirectoryResponse(BaseModel): + messageType: str = "workspaceDirectory" + workspaceDirectory: str + + +T = TypeVar("T", bound=BaseModel) + + +class IdeProtocolServer(AbstractIdeProtocolServer): + websocket: WebSocket + session_manager: SessionManager + sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue() + + def __init__(self, session_manager: SessionManager): + self.session_manager = session_manager + + async def _send_json(self, data: Any): + await self.websocket.send_json(data) + + async def _receive_json(self, message_type: str) -> Any: + return await self.sub_queue.get(message_type) + + async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: + await self._send_json(data) + resp = await self._receive_json(message_type) + return resp_model.parse_obj(resp) + + async def handle_json(self, data: Any): + t = data["messageType"] + if t == "openNotebook": + await self.openNotebook() + elif t == "setFileOpen": + await self.setFileOpen(data["filepath"], data["open"]) + elif t == "fileEdits": + fileEdits = list( + map(lambda d: FileEditWithFullContents.parse_obj(d), data["fileEdits"])) + self.onFileEdits(fileEdits) + elif t in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory"]: + self.sub_queue.post(t, data) + else: + raise ValueError("Unknown message type", t) + + # ------------------------------- # + # Request actions in IDE, doesn't matter which Session + def showSuggestion(): + pass + + async def setFileOpen(self, filepath: str, open: bool = True): + # Agent needs access to this. + await self.websocket.send_json({ + "messageType": "setFileOpen", + "filepath": filepath, + "open": open + }) + + async def openNotebook(self): + session_id = self.session_manager.new_session(self) + await self._send_json({ + "messageType": "openNotebook", + "sessionId": session_id + }) + + async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool: + ids = [str(uuid.uuid4()) for _ in suggestions] + for i in range(len(suggestions)): + self._send_json({ + "messageType": "showSuggestion", + "suggestion": suggestions[i], + "suggestionId": ids[i] + }) + responses = await asyncio.gather(*[ + self._receive_json(ShowSuggestionResponse) + for i in range(len(suggestions)) + ]) # WORKING ON THIS FLOW HERE. Fine now to just await for response, instead of doing something fancy with a "waiting" state on the agent. + # Just need connect the suggestionId to the IDE (and the notebook) + return any([r.accepted for r in responses]) + + # ------------------------------- # + # Here needs to pass message onto the Agent OR Agent just subscribes. + # This is where you might have triggers: plugins can subscribe to certian events + # like file changes, tracebacks, etc... + + def onAcceptRejectSuggestion(self, suggestionId: str, accepted: bool): + pass + + def onTraceback(self, traceback: Traceback): + # Same as below, maybe not every agent? + for _, session in self.session_manager.sessions.items(): + session.agent.handle_traceback(traceback) + + def onFileSystemUpdate(self, update: FileSystemEdit): + # Access to Agent (so SessionManager) + pass + + def onCloseNotebook(self, session_id: str): + # Accesss to SessionManager + pass + + def onOpenNotebookRequest(self): + pass + + def onFileEdits(self, edits: List[FileEditWithFullContents]): + # Send the file edits to ALL agents. + # Maybe not ideal behavior + for _, session in self.session_manager.sessions.items(): + session.agent.handle_manual_edits(edits) + + # Request information. Session doesn't matter. + async def getOpenFiles(self) -> List[str]: + resp = await self._send_and_receive_json({ + "messageType": "openFiles" + }, OpenFilesResponse, "openFiles") + return resp.openFiles + + async def getWorkspaceDirectory(self) -> str: + resp = await self._send_and_receive_json({ + "messageType": "workspaceDirectory" + }, WorkspaceDirectoryResponse, "workspaceDirectory") + return resp.workspaceDirectory + + async def getHighlightedCode(self) -> List[RangeInFile]: + resp = await self._send_and_receive_json({ + "messageType": "highlightedCode" + }, HighlightedCodeResponse, "highlightedCode") + return resp.highlightedCode + + async def readFile(self, filepath: str) -> str: + """Read a file""" + resp = await self._send_and_receive_json({ + "messageType": "readFile", + "filepath": filepath + }, ReadFileResponse, "readFile") + return resp.contents + + async def saveFile(self, filepath: str): + """Save a file""" + await self._send_json({ + "messageType": "saveFile", + "filepath": filepath + }) + + async def readRangeInFile(self, range_in_file: RangeInFile) -> str: + """Read a range in a file""" + full_contents = await self.readFile(range_in_file.filepath) + return FileSystem.read_range_in_str(full_contents, range_in_file.range) + + async def editFile(self, edit: FileEdit) -> FileEditWithFullContents: + """Edit a file""" + resp = await self._send_and_receive_json({ + "messageType": "editFile", + "edit": edit.dict() + }, EditFileResponse, "editFile") + return resp.fileEdit + + async def applyFileSystemEdit(self, edit: FileSystemEdit) -> EditDiff: + """Apply a file edit""" + backward = None + fs = RealFileSystem() + if isinstance(edit, FileEdit): + file_edit = await self.editFile(edit) + _, diff = FileSystem.apply_edit_to_str( + file_edit.fileContents, file_edit.fileEdit) + backward = diff.backward + elif isinstance(edit, AddFile): + fs.write(edit.filepath, edit.content) + backward = DeleteFile(filepath=edit.filepath) + elif isinstance(edit, DeleteFile): + contents = await self.readFile(edit.filepath) + backward = AddFile(filepath=edit.filepath, content=contents) + fs.delete_file(edit.filepath) + elif isinstance(edit, RenameFile): + fs.rename_file(edit.filepath, edit.new_filepath) + backward = RenameFile(filepath=edit.new_filepath, + new_filepath=edit.filepath) + elif isinstance(edit, AddDirectory): + fs.add_directory(edit.path) + backward = DeleteDirectory(path=edit.path) + elif isinstance(edit, DeleteDirectory): + # This isn't atomic! + backward_edits = [] + for root, dirs, files in os.walk(edit.path, topdown=False): + for f in files: + path = os.path.join(root, f) + edit_diff = await self.applyFileSystemEdit(DeleteFile(filepath=path)) + backward_edits.append(edit_diff) + for d in dirs: + path = os.path.join(root, d) + edit_diff = await self.applyFileSystemEdit(DeleteDirectory(path=path)) + backward_edits.append(edit_diff) + + edit_diff = await self.applyFileSystemEdit(DeleteDirectory(path=edit.path)) + backward_edits.append(edit_diff) + backward_edits.reverse() + backward = SequentialFileSystemEdit(edits=backward_edits) + elif isinstance(edit, RenameDirectory): + fs.rename_directory(edit.path, edit.new_path) + backward = RenameDirectory(path=edit.new_path, new_path=edit.path) + elif isinstance(edit, FileSystemEdit): + diffs = [] + for edit in edit.next_edit(): + edit_diff = await self.applyFileSystemEdit(edit) + diffs.append(edit_diff) + backward = EditDiff.from_sequence(diffs=diffs).backward + else: + raise TypeError("Unknown FileSystemEdit type: " + str(type(edit))) + + return EditDiff( + forward=edit, + backward=backward + ) + + +ideProtocolServer = IdeProtocolServer(session_manager) + + +@router.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + print("Accepted websocket connection from, ", websocket.client) + await websocket.send_json({"messageType": "connected"}) + ideProtocolServer.websocket = websocket + while True: + data = await websocket.receive_json() + await ideProtocolServer.handle_json(data) + + await websocket.close() diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py new file mode 100644 index 00000000..15d019b4 --- /dev/null +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -0,0 +1,80 @@ +from typing import Any, List +from abc import ABC, abstractmethod + +from ..models.main import Traceback +from ..models.filesystem_edit import FileEdit, FileSystemEdit, EditDiff +from ..models.filesystem import RangeInFile + + +class AbstractIdeProtocolServer(ABC): + @abstractmethod + async def handle_json(self, data: Any): + """Handle a json message""" + + @abstractmethod + def showSuggestion(): + """Show a suggestion to the user""" + + @abstractmethod + async def getWorkspaceDirectory(self): + """Get the workspace directory""" + + @abstractmethod + async def setFileOpen(self, filepath: str, open: bool = True): + """Set whether a file is open""" + + @abstractmethod + async def openNotebook(self): + """Open a notebook""" + + @abstractmethod + async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool: + """Show suggestions to the user and wait for a response""" + + @abstractmethod + def onAcceptRejectSuggestion(self, suggestionId: str, accepted: bool): + """Called when the user accepts or rejects a suggestion""" + + @abstractmethod + def onTraceback(self, traceback: Traceback): + """Called when a traceback is received""" + + @abstractmethod + def onFileSystemUpdate(self, update: FileSystemEdit): + """Called when a file system update is received""" + + @abstractmethod + def onCloseNotebook(self, session_id: str): + """Called when a notebook is closed""" + + @abstractmethod + def onOpenNotebookRequest(self): + """Called when a notebook is requested to be opened""" + + @abstractmethod + async def getOpenFiles(self) -> List[str]: + """Get a list of open files""" + + @abstractmethod + async def getHighlightedCode(self) -> List[RangeInFile]: + """Get a list of highlighted code""" + + @abstractmethod + async def readFile(self, filepath: str) -> str: + """Read a file""" + + @abstractmethod + async def readRangeInFile(self, range_in_file: RangeInFile) -> str: + """Read a range in a file""" + + @abstractmethod + async def editFile(self, edit: FileEdit): + """Edit a file""" + + @abstractmethod + async def applyFileSystemEdit(self, edit: FileSystemEdit) -> EditDiff: + """Apply a file edit""" + + @abstractmethod + async def saveFile(self, filepath: str): + """Save a file""" diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py new file mode 100644 index 00000000..11ad1d8f --- /dev/null +++ b/continuedev/src/continuedev/server/main.py @@ -0,0 +1,39 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from .ide import router as ide_router +from .notebook import router as notebook_router +import uvicorn +import argparse + +app = FastAPI() + +app.include_router(ide_router) +app.include_router(notebook_router) + +# Add CORS support +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.get("/health") +def health(): + return {"status": "ok"} + + +# add cli arg for server port +parser = argparse.ArgumentParser() +parser.add_argument("-p", "--port", help="server port", type=int, default=8000) +args = parser.parse_args() + + +def run_server(): + uvicorn.run(app, host="0.0.0.0", port=args.port, log_config="logging.ini") + + +if __name__ == "__main__": + run_server() diff --git a/continuedev/src/continuedev/server/notebook.py b/continuedev/src/continuedev/server/notebook.py new file mode 100644 index 00000000..c9d4edc5 --- /dev/null +++ b/continuedev/src/continuedev/server/notebook.py @@ -0,0 +1,198 @@ +from fastapi import FastAPI, Depends, Header, WebSocket, APIRouter +from typing import Any, Dict, List, Union +from uuid import uuid4 +from pydantic import BaseModel +from uvicorn.main import Server + +from ..models.filesystem_edit import FileEditWithFullContents +from ..libs.policy import DemoPolicy +from ..libs.core import Agent, FullState, History, Step +from ..libs.steps.nate import ImplementAbstractMethodStep +from ..libs.observation import Observation +from dotenv import load_dotenv +from ..libs.llm.openai import OpenAI +from .ide_protocol import AbstractIdeProtocolServer +import os +import asyncio +import nest_asyncio +nest_asyncio.apply() + +load_dotenv() +openai_api_key = os.getenv("OPENAI_API_KEY") + +router = APIRouter(prefix="/notebook", tags=["notebook"]) + +# Graceful shutdown by closing websockets +original_handler = Server.handle_exit + + +class AppStatus: + should_exit = False + + @staticmethod + def handle_exit(*args, **kwargs): + AppStatus.should_exit = True + print("Shutting down") + original_handler(*args, **kwargs) + + +Server.handle_exit = AppStatus.handle_exit + + +class Session: + session_id: str + agent: Agent + ws: Union[WebSocket, None] + + def __init__(self, session_id: str, agent: Agent): + self.session_id = session_id + self.agent = agent + self.ws = None + + +class DemoAgent(Agent): + first_seen: bool = False + cumulative_edit_string = "" + + def handle_manual_edits(self, edits: List[FileEditWithFullContents]): + for edit in edits: + self.cumulative_edit_string += edit.fileEdit.replacement + self._manual_edits_buffer.append(edit) + # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge. + # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit) + # FOR DEMO PURPOSES + if edit.fileEdit.filepath.endswith("filesystem.py") and "List" in self.cumulative_edit_string and ":" in edit.fileEdit.replacement: + self.cumulative_edit_string = "" + asyncio.create_task(self.run_from_step( + ImplementAbstractMethodStep())) + + +class SessionManager: + sessions: Dict[str, Session] = {} + _event_loop: Union[asyncio.BaseEventLoop, None] = None + + def get_session(self, session_id: str) -> Session: + if session_id not in self.sessions: + raise KeyError("Session ID not recognized") + return self.sessions[session_id] + + def new_session(self, ide: AbstractIdeProtocolServer) -> str: + cmd = "python3 /Users/natesesti/Desktop/continue/extension/examples/python/main.py" + agent = DemoAgent(llm=OpenAI(api_key=openai_api_key), + policy=DemoPolicy(cmd=cmd), ide=ide) + session_id = str(uuid4()) + session = Session(session_id=session_id, agent=agent) + self.sessions[session_id] = session + + def on_update(state: FullState): + session_manager.send_ws_data(session_id, { + "messageType": "state", + "state": agent.get_full_state().dict() + }) + + agent.on_update(on_update) + asyncio.create_task(agent.run_policy()) + return session_id + + def remove_session(self, session_id: str): + del self.sessions[session_id] + + def register_websocket(self, session_id: str, ws: WebSocket): + self.sessions[session_id].ws = ws + print("Registered websocket for session", session_id) + + def send_ws_data(self, session_id: str, data: Any): + if self.sessions[session_id].ws is None: + print(f"Session {session_id} has no websocket") + return + + async def a(): + await self.sessions[session_id].ws.send_json(data) + + # Run coroutine in background + if self._event_loop is None or self._event_loop.is_closed(): + self._event_loop = asyncio.new_event_loop() + self._event_loop.run_until_complete(a()) + self._event_loop.close() + else: + self._event_loop.run_until_complete(a()) + self._event_loop.close() + + +session_manager = SessionManager() + + +def session(x_continue_session_id: str = Header("anonymous")) -> Session: + return session_manager.get_session(x_continue_session_id) + + +def websocket_session(session_id: str) -> Session: + return session_manager.get_session(session_id) + + +class StartSessionBody(BaseModel): + config_file_path: Union[str, None] + + +class StartSessionResp(BaseModel): + session_id: str + + +@router.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): + await websocket.accept() + + session_manager.register_websocket(session.session_id, websocket) + data = await websocket.receive_text() + # Update any history that may have happened before connection + await websocket.send_json({ + "messageType": "state", + "state": session_manager.get_session(session.session_id).agent.get_full_state().dict() + }) + print("Session started", data) + while AppStatus.should_exit is False: + data = await websocket.receive_json() + print("Received data", data) + + if "messageType" not in data: + continue + messageType = data["messageType"] + + try: + if messageType == "main_input": + # Do something with user input + asyncio.create_task( + session.agent.accept_user_input(data["value"])) + elif messageType == "step_user_input": + asyncio.create_task( + session.agent.give_user_input(data["value"], data["index"])) + elif messageType == "refinement_input": + asyncio.create_task( + session.agent.accept_refinement_input(data["value"], data["index"])) + elif messageType == "reverse": + # Reverse the history to the given index + asyncio.create_task( + session.agent.reverse_to_index(data["index"])) + except Exception as e: + print(e) + + print("Closing websocket") + await websocket.close() + + +@router.post("/run") +def request_run(step: Step, session=Depends(session)): + """Tell an agent to take a specific action.""" + asyncio.create_task(session.agent.run_from_step(step)) + return "Success" + + +@router.get("/history") +def get_history(session=Depends(session)) -> History: + return session.agent.history + + +@router.post("/observation") +def post_observation(observation: Observation, session=Depends(session)): + asyncio.create_task(session.agent.run_from_observation(observation)) + return "Success" |