diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-06-02 14:09:50 -0400 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-06-02 14:09:50 -0400 |
commit | 3157898f13a6990aff909e8e4af7f58515a28e8c (patch) | |
tree | 773ec943df45a679fa4a0f221b871b7248be50d7 /continuedev/src/continuedev/server | |
parent | 6b909caa8dcbd4bf3d1078ded1c12146944ab349 (diff) | |
parent | aea318b48dd7e15df16eca12ba59c677671869aa (diff) | |
download | sncontinue-3157898f13a6990aff909e8e4af7f58515a28e8c.tar.gz sncontinue-3157898f13a6990aff909e8e4af7f58515a28e8c.tar.bz2 sncontinue-3157898f13a6990aff909e8e4af7f58515a28e8c.zip |
Merge branch 'main' into docs
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 130 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/gui_protocol.py | 28 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide.py | 62 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide_protocol.py | 12 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/main.py | 11 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/notebook.py | 198 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/session_manager.py | 86 |
7 files changed, 292 insertions, 235 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py new file mode 100644 index 00000000..3d1a5a82 --- /dev/null +++ b/continuedev/src/continuedev/server/gui.py @@ -0,0 +1,130 @@ +import json +from fastapi import Depends, Header, WebSocket, APIRouter +from typing import Any, Type, TypeVar, Union +from pydantic import BaseModel +from uvicorn.main import Server + +from .session_manager import SessionManager, session_manager, Session +from .gui_protocol import AbstractGUIProtocolServer +from ..libs.util.queue import AsyncSubscriptionQueue +import asyncio +import nest_asyncio +nest_asyncio.apply() + +router = APIRouter(prefix="/gui", tags=["gui"]) + +# Graceful shutdown by closing websockets +original_handler = Server.handle_exit + + +class AppStatus: + should_exit = False + + @staticmethod + def handle_exit(*args, **kwargs): + AppStatus.should_exit = True + print("Shutting down") + original_handler(*args, **kwargs) + + +Server.handle_exit = AppStatus.handle_exit + + +def session(x_continue_session_id: str = Header("anonymous")) -> Session: + return session_manager.get_session(x_continue_session_id) + + +def websocket_session(session_id: str) -> Session: + return session_manager.get_session(session_id) + + +T = TypeVar("T", bound=BaseModel) + +# You should probably abstract away the websocket stuff into a separate class + + +class GUIProtocolServer(AbstractGUIProtocolServer): + websocket: WebSocket + session: Session + sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue() + + def __init__(self, session: Session): + self.session = session + + async def _send_json(self, message_type: str, data: Any): + await self.websocket.send_json({ + "messageType": message_type, + "data": data + }) + + async def _receive_json(self, message_type: str) -> Any: + return await self.sub_queue.get(message_type) + + async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: + await self._send_json(message_type, data) + resp = await self._receive_json(message_type) + return resp_model.parse_obj(resp) + + def handle_json(self, message_type: str, data: Any): + try: + if message_type == "main_input": + self.on_main_input(data["input"]) + elif message_type == "step_user_input": + self.on_step_user_input(data["input"], data["index"]) + elif message_type == "refinement_input": + self.on_refinement_input(data["input"], data["index"]) + elif message_type == "reverse_to_index": + self.on_reverse_to_index(data["index"]) + except Exception as e: + print(e) + + async def send_state_update(self): + state = self.session.autopilot.get_full_state().dict() + await self._send_json("state_update", { + "state": state + }) + + def on_main_input(self, input: str): + # Do something with user input + asyncio.create_task(self.session.autopilot.accept_user_input(input)) + + def on_reverse_to_index(self, index: int): + # Reverse the history to the given index + asyncio.create_task(self.session.autopilot.reverse_to_index(index)) + + def on_step_user_input(self, input: str, index: int): + asyncio.create_task( + self.session.autopilot.give_user_input(input, index)) + + def on_refinement_input(self, input: str, index: int): + asyncio.create_task( + self.session.autopilot.accept_refinement_input(input, index)) + + +@router.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): + await websocket.accept() + + print("Session started") + session_manager.register_websocket(session.session_id, websocket) + protocol = GUIProtocolServer(session) + protocol.websocket = websocket + + # Update any history that may have happened before connection + await protocol.send_state_update() + + while AppStatus.should_exit is False: + message = await websocket.receive_text() + print("Received message", message) + if type(message) is str: + message = json.loads(message) + + if "messageType" not in message or "data" not in message: + continue + message_type = message["messageType"] + data = message["data"] + + protocol.handle_json(message_type, data) + + print("Closing websocket") + await websocket.close() diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py new file mode 100644 index 00000000..e32d80ef --- /dev/null +++ b/continuedev/src/continuedev/server/gui_protocol.py @@ -0,0 +1,28 @@ +from typing import Any +from abc import ABC, abstractmethod + + +class AbstractGUIProtocolServer(ABC): + @abstractmethod + async def handle_json(self, data: Any): + """Handle a json message""" + + @abstractmethod + def on_main_input(self, input: str): + """Called when the user inputs something""" + + @abstractmethod + def on_reverse_to_index(self, index: int): + """Called when the user requests reverse to a previous index""" + + @abstractmethod + def on_refinement_input(self, input: str, index: int): + """Called when the user inputs a refinement""" + + @abstractmethod + def on_step_user_input(self, input: str, index: int): + """Called when the user inputs a step""" + + @abstractmethod + async def send_state_update(self, state: dict): + """Send a state update to the client""" diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 167d9483..71017ce0 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -1,5 +1,6 @@ # This is a separate server from server/main.py import asyncio +import json import os from typing import Any, Dict, List, Type, TypeVar, Union import uuid @@ -11,7 +12,7 @@ from ..models.filesystem import FileSystem, RangeInFile, EditDiff, RealFileSyste from ..models.main import Traceback from ..models.filesystem_edit import AddDirectory, AddFile, DeleteDirectory, DeleteFile, FileSystemEdit, FileEdit, FileEditWithFullContents, RenameDirectory, RenameFile, SequentialFileSystemEdit from pydantic import BaseModel -from .notebook import SessionManager, session_manager +from .gui import SessionManager, session_manager from .ide_protocol import AbstractIdeProtocolServer @@ -90,31 +91,33 @@ class IdeProtocolServer(AbstractIdeProtocolServer): def __init__(self, session_manager: SessionManager): self.session_manager = session_manager - async def _send_json(self, data: Any): - await self.websocket.send_json(data) + async def _send_json(self, message_type: str, data: Any): + await self.websocket.send_json({ + "messageType": message_type, + "data": data + }) async def _receive_json(self, message_type: str) -> Any: return await self.sub_queue.get(message_type) async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: - await self._send_json(data) + await self._send_json(message_type, data) resp = await self._receive_json(message_type) return resp_model.parse_obj(resp) - async def handle_json(self, data: Any): - t = data["messageType"] - if t == "openNotebook": - await self.openNotebook() - elif t == "setFileOpen": + async def handle_json(self, message_type: str, data: Any): + if message_type == "openGUI": + await self.openGUI() + elif message_type == "setFileOpen": await self.setFileOpen(data["filepath"], data["open"]) - elif t == "fileEdits": + elif message_type == "fileEdits": fileEdits = list( map(lambda d: FileEditWithFullContents.parse_obj(d), data["fileEdits"])) self.onFileEdits(fileEdits) - elif t in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory"]: - self.sub_queue.post(t, data) + elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory"]: + self.sub_queue.post(message_type, data) else: - raise ValueError("Unknown message type", t) + raise ValueError("Unknown message type", message_type) # ------------------------------- # # Request actions in IDE, doesn't matter which Session @@ -123,24 +126,21 @@ class IdeProtocolServer(AbstractIdeProtocolServer): async def setFileOpen(self, filepath: str, open: bool = True): # Autopilot needs access to this. - await self.websocket.send_json({ - "messageType": "setFileOpen", + await self._send_json("setFileOpen", { "filepath": filepath, "open": open }) - async def openNotebook(self): + async def openGUI(self): session_id = self.session_manager.new_session(self) - await self._send_json({ - "messageType": "openNotebook", + await self._send_json("openGUI", { "sessionId": session_id }) async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool: ids = [str(uuid.uuid4()) for _ in suggestions] for i in range(len(suggestions)): - self._send_json({ - "messageType": "showSuggestion", + self._send_json("showSuggestion", { "suggestion": suggestions[i], "suggestionId": ids[i] }) @@ -148,7 +148,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): self._receive_json(ShowSuggestionResponse) for i in range(len(suggestions)) ]) # WORKING ON THIS FLOW HERE. Fine now to just await for response, instead of doing something fancy with a "waiting" state on the autopilot. - # Just need connect the suggestionId to the IDE (and the notebook) + # Just need connect the suggestionId to the IDE (and the gui) return any([r.accepted for r in responses]) # ------------------------------- # @@ -168,11 +168,11 @@ class IdeProtocolServer(AbstractIdeProtocolServer): # Access to Autopilot (so SessionManager) pass - def onCloseNotebook(self, session_id: str): + def onCloseGUI(self, session_id: str): # Accesss to SessionManager pass - def onOpenNotebookRequest(self): + def onOpenGUIRequest(self): pass def onFileEdits(self, edits: List[FileEditWithFullContents]): @@ -210,8 +210,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): async def saveFile(self, filepath: str): """Save a file""" - await self._send_json({ - "messageType": "saveFile", + await self._send_json("saveFile", { "filepath": filepath }) @@ -293,10 +292,17 @@ ideProtocolServer = IdeProtocolServer(session_manager) async def websocket_endpoint(websocket: WebSocket): await websocket.accept() print("Accepted websocket connection from, ", websocket.client) - await websocket.send_json({"messageType": "connected"}) + await websocket.send_json({"messageType": "connected", "data": {}}) ideProtocolServer.websocket = websocket while True: - data = await websocket.receive_json() - await ideProtocolServer.handle_json(data) + message = await websocket.receive_text() + message = json.loads(message) + + if "messageType" not in message or "data" not in message: + continue + message_type = message["messageType"] + data = message["data"] + + await ideProtocolServer.handle_json(message_type, data) await websocket.close() diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index 15d019b4..4f505e80 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -24,8 +24,8 @@ class AbstractIdeProtocolServer(ABC): """Set whether a file is open""" @abstractmethod - async def openNotebook(self): - """Open a notebook""" + async def openGUI(self): + """Open a GUI""" @abstractmethod async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool: @@ -44,12 +44,12 @@ class AbstractIdeProtocolServer(ABC): """Called when a file system update is received""" @abstractmethod - def onCloseNotebook(self, session_id: str): - """Called when a notebook is closed""" + def onCloseGUI(self, session_id: str): + """Called when a GUI is closed""" @abstractmethod - def onOpenNotebookRequest(self): - """Called when a notebook is requested to be opened""" + def onOpenGUIRequest(self): + """Called when a GUI is requested to be opened""" @abstractmethod async def getOpenFiles(self) -> List[str]: diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index 11ad1d8f..7b7124de 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,14 +1,15 @@ +import os from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from .ide import router as ide_router -from .notebook import router as notebook_router +from .gui import router as gui_router import uvicorn import argparse app = FastAPI() app.include_router(ide_router) -app.include_router(notebook_router) +app.include_router(gui_router) # Add CORS support app.add_middleware( @@ -32,7 +33,11 @@ args = parser.parse_args() def run_server(): - uvicorn.run(app, host="0.0.0.0", port=args.port, log_config="logging.ini") + if os.path.exists("logging.yaml"): + uvicorn.run(app, host="0.0.0.0", port=args.port, + log_config="logging.yaml") + else: + uvicorn.run(app, host="0.0.0.0", port=args.port) if __name__ == "__main__": diff --git a/continuedev/src/continuedev/server/notebook.py b/continuedev/src/continuedev/server/notebook.py deleted file mode 100644 index c26920f5..00000000 --- a/continuedev/src/continuedev/server/notebook.py +++ /dev/null @@ -1,198 +0,0 @@ -from fastapi import FastAPI, Depends, Header, WebSocket, APIRouter -from typing import Any, Dict, List, Union -from uuid import uuid4 -from pydantic import BaseModel -from uvicorn.main import Server - -from ..models.filesystem_edit import FileEditWithFullContents -from ..libs.policy import DemoPolicy -from ..libs.core import Autopilot, FullState, History, Step -from ..libs.steps.nate import ImplementAbstractMethodStep -from ..libs.observation import Observation -from dotenv import load_dotenv -from ..libs.llm.openai import OpenAI -from .ide_protocol import AbstractIdeProtocolServer -import os -import asyncio -import nest_asyncio -nest_asyncio.apply() - -load_dotenv() -openai_api_key = os.getenv("OPENAI_API_KEY") - -router = APIRouter(prefix="/notebook", tags=["notebook"]) - -# Graceful shutdown by closing websockets -original_handler = Server.handle_exit - - -class AppStatus: - should_exit = False - - @staticmethod - def handle_exit(*args, **kwargs): - AppStatus.should_exit = True - print("Shutting down") - original_handler(*args, **kwargs) - - -Server.handle_exit = AppStatus.handle_exit - - -class Session: - session_id: str - autopilot: Autopilot - ws: Union[WebSocket, None] - - def __init__(self, session_id: str, autopilot: Autopilot): - self.session_id = session_id - self.autopilot = autopilot - self.ws = None - - -class DemoAutopilot(Autopilot): - first_seen: bool = False - cumulative_edit_string = "" - - def handle_manual_edits(self, edits: List[FileEditWithFullContents]): - for edit in edits: - self.cumulative_edit_string += edit.fileEdit.replacement - self._manual_edits_buffer.append(edit) - # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge. - # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit) - # FOR DEMO PURPOSES - if edit.fileEdit.filepath.endswith("filesystem.py") and "List" in self.cumulative_edit_string and ":" in edit.fileEdit.replacement: - self.cumulative_edit_string = "" - asyncio.create_task(self.run_from_step( - ImplementAbstractMethodStep())) - - -class SessionManager: - sessions: Dict[str, Session] = {} - _event_loop: Union[asyncio.BaseEventLoop, None] = None - - def get_session(self, session_id: str) -> Session: - if session_id not in self.sessions: - raise KeyError("Session ID not recognized") - return self.sessions[session_id] - - def new_session(self, ide: AbstractIdeProtocolServer) -> str: - cmd = "python3 /Users/natesesti/Desktop/continue/extension/examples/python/main.py" - autopilot = DemoAutopilot(llm=OpenAI(api_key=openai_api_key), - policy=DemoPolicy(cmd=cmd), ide=ide) - session_id = str(uuid4()) - session = Session(session_id=session_id, autopilot=autopilot) - self.sessions[session_id] = session - - def on_update(state: FullState): - session_manager.send_ws_data(session_id, { - "messageType": "state", - "state": autopilot.get_full_state().dict() - }) - - autopilot.on_update(on_update) - asyncio.create_task(autopilot.run_policy()) - return session_id - - def remove_session(self, session_id: str): - del self.sessions[session_id] - - def register_websocket(self, session_id: str, ws: WebSocket): - self.sessions[session_id].ws = ws - print("Registered websocket for session", session_id) - - def send_ws_data(self, session_id: str, data: Any): - if self.sessions[session_id].ws is None: - print(f"Session {session_id} has no websocket") - return - - async def a(): - await self.sessions[session_id].ws.send_json(data) - - # Run coroutine in background - if self._event_loop is None or self._event_loop.is_closed(): - self._event_loop = asyncio.new_event_loop() - self._event_loop.run_until_complete(a()) - self._event_loop.close() - else: - self._event_loop.run_until_complete(a()) - self._event_loop.close() - - -session_manager = SessionManager() - - -def session(x_continue_session_id: str = Header("anonymous")) -> Session: - return session_manager.get_session(x_continue_session_id) - - -def websocket_session(session_id: str) -> Session: - return session_manager.get_session(session_id) - - -class StartSessionBody(BaseModel): - config_file_path: Union[str, None] - - -class StartSessionResp(BaseModel): - session_id: str - - -@router.websocket("/ws") -async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): - await websocket.accept() - - session_manager.register_websocket(session.session_id, websocket) - data = await websocket.receive_text() - # Update any history that may have happened before connection - await websocket.send_json({ - "messageType": "state", - "state": session_manager.get_session(session.session_id).autopilot.get_full_state().dict() - }) - print("Session started", data) - while AppStatus.should_exit is False: - data = await websocket.receive_json() - print("Received data", data) - - if "messageType" not in data: - continue - messageType = data["messageType"] - - try: - if messageType == "main_input": - # Do something with user input - asyncio.create_task( - session.autopilot.accept_user_input(data["value"])) - elif messageType == "step_user_input": - asyncio.create_task( - session.autopilot.give_user_input(data["value"], data["index"])) - elif messageType == "refinement_input": - asyncio.create_task( - session.autopilot.accept_refinement_input(data["value"], data["index"])) - elif messageType == "reverse": - # Reverse the history to the given index - asyncio.create_task( - session.autopilot.reverse_to_index(data["index"])) - except Exception as e: - print(e) - - print("Closing websocket") - await websocket.close() - - -@router.post("/run") -def request_run(step: Step, session=Depends(session)): - """Tell an autopilot to take a specific action.""" - asyncio.create_task(session.autopilot.run_from_step(step)) - return "Success" - - -@router.get("/history") -def get_history(session=Depends(session)) -> History: - return session.autopilot.history - - -@router.post("/observation") -def post_observation(observation: Observation, session=Depends(session)): - asyncio.create_task(session.autopilot.run_from_observation(observation)) - return "Success" diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py new file mode 100644 index 00000000..5598e140 --- /dev/null +++ b/continuedev/src/continuedev/server/session_manager.py @@ -0,0 +1,86 @@ +from fastapi import WebSocket +from typing import Any, Dict, List, Union +from uuid import uuid4 + +from ..models.filesystem_edit import FileEditWithFullContents +from ..core.policy import DemoPolicy +from ..core.main import FullState +from ..core.autopilot import Autopilot +from ..libs.steps.nate import ImplementAbstractMethodStep +from .ide_protocol import AbstractIdeProtocolServer +import asyncio +import nest_asyncio +nest_asyncio.apply() + + +class Session: + session_id: str + autopilot: Autopilot + ws: Union[WebSocket, None] + + def __init__(self, session_id: str, autopilot: Autopilot): + self.session_id = session_id + self.autopilot = autopilot + self.ws = None + + +class DemoAutopilot(Autopilot): + first_seen: bool = False + cumulative_edit_string = "" + + def handle_manual_edits(self, edits: List[FileEditWithFullContents]): + for edit in edits: + self.cumulative_edit_string += edit.fileEdit.replacement + self._manual_edits_buffer.append(edit) + # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge. + # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit) + # FOR DEMO PURPOSES + if edit.fileEdit.filepath.endswith("filesystem.py") and "List" in self.cumulative_edit_string and ":" in edit.fileEdit.replacement: + self.cumulative_edit_string = "" + asyncio.create_task(self.run_from_step( + ImplementAbstractMethodStep())) + + +class SessionManager: + sessions: Dict[str, Session] = {} + _event_loop: Union[asyncio.BaseEventLoop, None] = None + + def get_session(self, session_id: str) -> Session: + if session_id not in self.sessions: + raise KeyError("Session ID not recognized") + return self.sessions[session_id] + + def new_session(self, ide: AbstractIdeProtocolServer) -> str: + autopilot = DemoAutopilot(policy=DemoPolicy(), ide=ide) + session_id = str(uuid4()) + session = Session(session_id=session_id, autopilot=autopilot) + self.sessions[session_id] = session + + async def on_update(state: FullState): + await session_manager.send_ws_data(session_id, "state_update", { + "state": autopilot.get_full_state().dict() + }) + + autopilot.on_update(on_update) + asyncio.create_task(autopilot.run_policy()) + return session_id + + def remove_session(self, session_id: str): + del self.sessions[session_id] + + def register_websocket(self, session_id: str, ws: WebSocket): + self.sessions[session_id].ws = ws + print("Registered websocket for session", session_id) + + async def send_ws_data(self, session_id: str, message_type: str, data: Any): + if self.sessions[session_id].ws is None: + print(f"Session {session_id} has no websocket") + return + + await self.sessions[session_id].ws.send_json({ + "messageType": message_type, + "data": data + }) + + +session_manager = SessionManager() |