summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/server
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-05-31 16:13:01 -0400
committerNate Sesti <sestinj@gmail.com>2023-05-31 16:13:01 -0400
commitb2ddcd0e96aaf7604d197809de7f47dd51072ff2 (patch)
tree6036f665f401d86b13d910f91aadacb7411806e5 /continuedev/src/continuedev/server
parent8d59100b3194cc8d122708523226968899efb5e1 (diff)
downloadsncontinue-b2ddcd0e96aaf7604d197809de7f47dd51072ff2.tar.gz
sncontinue-b2ddcd0e96aaf7604d197809de7f47dd51072ff2.tar.bz2
sncontinue-b2ddcd0e96aaf7604d197809de7f47dd51072ff2.zip
checkpoint! protocol reform and it works now
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r--continuedev/src/continuedev/server/ide.py50
-rw-r--r--continuedev/src/continuedev/server/main.py2
-rw-r--r--continuedev/src/continuedev/server/notebook.py207
-rw-r--r--continuedev/src/continuedev/server/notebook_protocol.py28
-rw-r--r--continuedev/src/continuedev/server/session_manager.py101
5 files changed, 228 insertions, 160 deletions
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index dd1dc463..50296841 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -1,5 +1,6 @@
# This is a separate server from server/main.py
import asyncio
+import json
import os
from typing import Any, Dict, List, Type, TypeVar, Union
import uuid
@@ -90,31 +91,33 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
def __init__(self, session_manager: SessionManager):
self.session_manager = session_manager
- async def _send_json(self, data: Any):
- await self.websocket.send_json(data)
+ async def _send_json(self, message_type: str, data: Any):
+ await self.websocket.send_json({
+ "messageType": message_type,
+ "data": data
+ })
async def _receive_json(self, message_type: str) -> Any:
return await self.sub_queue.get(message_type)
async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
- await self._send_json(data)
+ await self._send_json(message_type, data)
resp = await self._receive_json(message_type)
return resp_model.parse_obj(resp)
- async def handle_json(self, data: Any):
- t = data["messageType"]
- if t == "openNotebook":
+ async def handle_json(self, message_type: str, data: Any):
+ if message_type == "openNotebook":
await self.openNotebook()
- elif t == "setFileOpen":
+ elif message_type == "setFileOpen":
await self.setFileOpen(data["filepath"], data["open"])
- elif t == "fileEdits":
+ elif message_type == "fileEdits":
fileEdits = list(
map(lambda d: FileEditWithFullContents.parse_obj(d), data["fileEdits"]))
self.onFileEdits(fileEdits)
- elif t in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory"]:
- self.sub_queue.post(t, data)
+ elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory"]:
+ self.sub_queue.post(message_type, data)
else:
- raise ValueError("Unknown message type", t)
+ raise ValueError("Unknown message type", message_type)
# ------------------------------- #
# Request actions in IDE, doesn't matter which Session
@@ -123,24 +126,21 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
async def setFileOpen(self, filepath: str, open: bool = True):
# Agent needs access to this.
- await self.websocket.send_json({
- "messageType": "setFileOpen",
+ await self._send_json("setFileOpen", {
"filepath": filepath,
"open": open
})
async def openNotebook(self):
session_id = self.session_manager.new_session(self)
- await self._send_json({
- "messageType": "openNotebook",
+ await self._send_json("openNotebook", {
"sessionId": session_id
})
async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool:
ids = [str(uuid.uuid4()) for _ in suggestions]
for i in range(len(suggestions)):
- self._send_json({
- "messageType": "showSuggestion",
+ self._send_json("showSuggestion", {
"suggestion": suggestions[i],
"suggestionId": ids[i]
})
@@ -210,8 +210,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
async def saveFile(self, filepath: str):
"""Save a file"""
- await self._send_json({
- "messageType": "saveFile",
+ await self._send_json("saveFile", {
"filepath": filepath
})
@@ -293,10 +292,17 @@ ideProtocolServer = IdeProtocolServer(session_manager)
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
print("Accepted websocket connection from, ", websocket.client)
- await websocket.send_json({"messageType": "connected"})
+ await websocket.send_json({"messageType": "connected", "data": {}})
ideProtocolServer.websocket = websocket
while True:
- data = await websocket.receive_json()
- await ideProtocolServer.handle_json(data)
+ message = await websocket.receive_text()
+ message = json.loads(message)
+
+ if "messageType" not in message or "data" not in message:
+ continue
+ message_type = message["messageType"]
+ data = message["data"]
+
+ await ideProtocolServer.handle_json(message_type, data)
await websocket.close()
diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py
index 11ad1d8f..e87d5fa9 100644
--- a/continuedev/src/continuedev/server/main.py
+++ b/continuedev/src/continuedev/server/main.py
@@ -32,7 +32,7 @@ args = parser.parse_args()
def run_server():
- uvicorn.run(app, host="0.0.0.0", port=args.port, log_config="logging.ini")
+ uvicorn.run(app, host="0.0.0.0", port=args.port)
if __name__ == "__main__":
diff --git a/continuedev/src/continuedev/server/notebook.py b/continuedev/src/continuedev/server/notebook.py
index c5dcea31..edb61a45 100644
--- a/continuedev/src/continuedev/server/notebook.py
+++ b/continuedev/src/continuedev/server/notebook.py
@@ -1,18 +1,12 @@
-from fastapi import FastAPI, Depends, Header, WebSocket, APIRouter
-from typing import Any, Dict, List, Union
-from uuid import uuid4
+import json
+from fastapi import Depends, Header, WebSocket, APIRouter
+from typing import Any, Type, TypeVar, Union
from pydantic import BaseModel
from uvicorn.main import Server
-from ..models.filesystem_edit import FileEditWithFullContents
-from ..core.policy import DemoPolicy
-from ..core.main import FullState, History, Step
-from ..core.agent import Agent
-from ..libs.steps.nate import ImplementAbstractMethodStep
-from ..core.observation import Observation
-from ..libs.llm.openai import OpenAI
-from .ide_protocol import AbstractIdeProtocolServer
-from ..core.env import get_env_var
+from .session_manager import SessionManager, session_manager, Session
+from .notebook_protocol import AbstractNotebookProtocolServer
+from ..libs.util.queue import AsyncSubscriptionQueue
import asyncio
import nest_asyncio
nest_asyncio.apply()
@@ -36,160 +30,99 @@ class AppStatus:
Server.handle_exit = AppStatus.handle_exit
-class Session:
- session_id: str
- agent: Agent
- ws: Union[WebSocket, None]
-
- def __init__(self, session_id: str, agent: Agent):
- self.session_id = session_id
- self.agent = agent
- self.ws = None
-
-
-class DemoAgent(Agent):
- first_seen: bool = False
- cumulative_edit_string = ""
-
- def handle_manual_edits(self, edits: List[FileEditWithFullContents]):
- for edit in edits:
- self.cumulative_edit_string += edit.fileEdit.replacement
- self._manual_edits_buffer.append(edit)
- # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge.
- # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit)
- # FOR DEMO PURPOSES
- if edit.fileEdit.filepath.endswith("filesystem.py") and "List" in self.cumulative_edit_string and ":" in edit.fileEdit.replacement:
- self.cumulative_edit_string = ""
- asyncio.create_task(self.run_from_step(
- ImplementAbstractMethodStep()))
-
-
-class SessionManager:
- sessions: Dict[str, Session] = {}
- _event_loop: Union[asyncio.BaseEventLoop, None] = None
-
- def get_session(self, session_id: str) -> Session:
- if session_id not in self.sessions:
- raise KeyError("Session ID not recognized")
- return self.sessions[session_id]
-
- def new_session(self, ide: AbstractIdeProtocolServer) -> str:
- cmd = "python3 /Users/natesesti/Desktop/continue/extension/examples/python/main.py"
- agent = DemoAgent(llm=OpenAI(api_key=get_env_var("OPENAI_API_KEY")),
- policy=DemoPolicy(cmd=cmd), ide=ide)
- session_id = str(uuid4())
- session = Session(session_id=session_id, agent=agent)
- self.sessions[session_id] = session
+def session(x_continue_session_id: str = Header("anonymous")) -> Session:
+ return session_manager.get_session(x_continue_session_id)
- def on_update(state: FullState):
- session_manager.send_ws_data(session_id, {
- "messageType": "state",
- "state": agent.get_full_state().dict()
- })
- agent.on_update(on_update)
- asyncio.create_task(agent.run_policy())
- return session_id
+def websocket_session(session_id: str) -> Session:
+ return session_manager.get_session(session_id)
- def remove_session(self, session_id: str):
- del self.sessions[session_id]
- def register_websocket(self, session_id: str, ws: WebSocket):
- self.sessions[session_id].ws = ws
- print("Registered websocket for session", session_id)
+T = TypeVar("T", bound=BaseModel)
- def send_ws_data(self, session_id: str, data: Any):
- if self.sessions[session_id].ws is None:
- print(f"Session {session_id} has no websocket")
- return
+# You should probably abstract away the websocket stuff into a separate class
- async def a():
- await self.sessions[session_id].ws.send_json(data)
- # Run coroutine in background
- if self._event_loop is None or self._event_loop.is_closed():
- self._event_loop = asyncio.new_event_loop()
- self._event_loop.run_until_complete(a())
- self._event_loop.close()
- else:
- self._event_loop.run_until_complete(a())
- self._event_loop.close()
+class NotebookProtocolServer(AbstractNotebookProtocolServer):
+ websocket: WebSocket
+ session: Session
+ sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue()
+ def __init__(self, session: Session):
+ self.session = session
-session_manager = SessionManager()
+ async def _send_json(self, data: Any):
+ await self.websocket.send_json(data)
+ async def _receive_json(self, message_type: str) -> Any:
+ return await self.sub_queue.get(message_type)
-def session(x_continue_session_id: str = Header("anonymous")) -> Session:
- return session_manager.get_session(x_continue_session_id)
+ async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
+ await self._send_json(data)
+ resp = await self._receive_json(message_type)
+ return resp_model.parse_obj(resp)
+ def handle_json(self, message_type: str, data: Any):
+ try:
+ if message_type == "main_input":
+ self.on_main_input(data["input"])
+ elif message_type == "step_user_input":
+ self.on_step_user_input(data["input"], data["index"])
+ elif message_type == "refinement_input":
+ self.on_refinement_input(data["input"], data["index"])
+ elif message_type == "reverse_to_index":
+ self.on_reverse_to_index(data["index"])
+ except Exception as e:
+ print(e)
-def websocket_session(session_id: str) -> Session:
- return session_manager.get_session(session_id)
+ async def send_state_update(self):
+ state = self.session.agent.get_full_state().dict()
+ await self._send_json({
+ "messageType": "state_update",
+ "state": state
+ })
+ def on_main_input(self, input: str):
+ # Do something with user input
+ asyncio.create_task(self.session.agent.accept_user_input(input))
-class StartSessionBody(BaseModel):
- config_file_path: Union[str, None]
+ def on_reverse_to_index(self, index: int):
+ # Reverse the history to the given index
+ asyncio.create_task(self.session.agent.reverse_to_index(index))
+ def on_step_user_input(self, input: str, index: int):
+ asyncio.create_task(
+ self.session.agent.give_user_input(input, index))
-class StartSessionResp(BaseModel):
- session_id: str
+ def on_refinement_input(self, input: str, index: int):
+ asyncio.create_task(
+ self.session.agent.accept_refinement_input(input, index))
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
await websocket.accept()
+ print("Session started")
session_manager.register_websocket(session.session_id, websocket)
- data = await websocket.receive_text()
+ protocol = NotebookProtocolServer(session)
+ protocol.websocket = websocket
+
# Update any history that may have happened before connection
- await websocket.send_json({
- "messageType": "state",
- "state": session_manager.get_session(session.session_id).agent.get_full_state().dict()
- })
- print("Session started", data)
+ await protocol.send_state_update()
+
while AppStatus.should_exit is False:
- data = await websocket.receive_json()
- print("Received data", data)
+ message = await websocket.receive_json()
+ print("Received message", message)
+ if type(message) is str:
+ message = json.loads(message)
- if "messageType" not in data:
+ if "messageType" not in message or "data" not in message:
continue
- messageType = data["messageType"]
+ message_type = message["messageType"]
+ data = message["data"]
- try:
- if messageType == "main_input":
- # Do something with user input
- asyncio.create_task(
- session.agent.accept_user_input(data["value"]))
- elif messageType == "step_user_input":
- asyncio.create_task(
- session.agent.give_user_input(data["value"], data["index"]))
- elif messageType == "refinement_input":
- asyncio.create_task(
- session.agent.accept_refinement_input(data["value"], data["index"]))
- elif messageType == "reverse":
- # Reverse the history to the given index
- asyncio.create_task(
- session.agent.reverse_to_index(data["index"]))
- except Exception as e:
- print(e)
+ protocol.handle_json(message_type, data)
print("Closing websocket")
await websocket.close()
-
-
-@router.post("/run")
-def request_run(step: Step, session=Depends(session)):
- """Tell an agent to take a specific action."""
- asyncio.create_task(session.agent.run_from_step(step))
- return "Success"
-
-
-@router.get("/history")
-def get_history(session=Depends(session)) -> History:
- return session.agent.history
-
-
-@router.post("/observation")
-def post_observation(observation: Observation, session=Depends(session)):
- asyncio.create_task(session.agent.run_from_observation(observation))
- return "Success"
diff --git a/continuedev/src/continuedev/server/notebook_protocol.py b/continuedev/src/continuedev/server/notebook_protocol.py
new file mode 100644
index 00000000..c2be82e0
--- /dev/null
+++ b/continuedev/src/continuedev/server/notebook_protocol.py
@@ -0,0 +1,28 @@
+from typing import Any
+from abc import ABC, abstractmethod
+
+
+class AbstractNotebookProtocolServer(ABC):
+ @abstractmethod
+ async def handle_json(self, data: Any):
+ """Handle a json message"""
+
+ @abstractmethod
+ def on_main_input(self, input: str):
+ """Called when the user inputs something"""
+
+ @abstractmethod
+ def on_reverse_to_index(self, index: int):
+ """Called when the user requests reverse to a previous index"""
+
+ @abstractmethod
+ def on_refinement_input(self, input: str, index: int):
+ """Called when the user inputs a refinement"""
+
+ @abstractmethod
+ def on_step_user_input(self, input: str, index: int):
+ """Called when the user inputs a step"""
+
+ @abstractmethod
+ async def send_state_update(self, state: dict):
+ """Send a state update to the client"""
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
new file mode 100644
index 00000000..b48c21b7
--- /dev/null
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -0,0 +1,101 @@
+from fastapi import WebSocket
+from typing import Any, Dict, List, Union
+from uuid import uuid4
+
+from ..models.filesystem_edit import FileEditWithFullContents
+from ..core.policy import DemoPolicy
+from ..core.main import FullState
+from ..core.agent import Agent
+from ..libs.steps.nate import ImplementAbstractMethodStep
+from .ide_protocol import AbstractIdeProtocolServer
+import asyncio
+import nest_asyncio
+nest_asyncio.apply()
+
+
+class Session:
+ session_id: str
+ agent: Agent
+ ws: Union[WebSocket, None]
+
+ def __init__(self, session_id: str, agent: Agent):
+ self.session_id = session_id
+ self.agent = agent
+ self.ws = None
+
+
+class DemoAgent(Agent):
+ first_seen: bool = False
+ cumulative_edit_string = ""
+
+ def handle_manual_edits(self, edits: List[FileEditWithFullContents]):
+ for edit in edits:
+ self.cumulative_edit_string += edit.fileEdit.replacement
+ self._manual_edits_buffer.append(edit)
+ # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge.
+ # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit)
+ # FOR DEMO PURPOSES
+ if edit.fileEdit.filepath.endswith("filesystem.py") and "List" in self.cumulative_edit_string and ":" in edit.fileEdit.replacement:
+ self.cumulative_edit_string = ""
+ asyncio.create_task(self.run_from_step(
+ ImplementAbstractMethodStep()))
+
+
+class SessionManager:
+ sessions: Dict[str, Session] = {}
+ _event_loop: Union[asyncio.BaseEventLoop, None] = None
+
+ def get_session(self, session_id: str) -> Session:
+ if session_id not in self.sessions:
+ raise KeyError("Session ID not recognized")
+ return self.sessions[session_id]
+
+ def new_session(self, ide: AbstractIdeProtocolServer) -> str:
+ agent = DemoAgent(policy=DemoPolicy(), ide=ide)
+ session_id = str(uuid4())
+ session = Session(session_id=session_id, agent=agent)
+ self.sessions[session_id] = session
+
+ async def on_update(state: FullState):
+ await session_manager.send_ws_data(session_id, "state_update", {
+ "state": agent.get_full_state().dict()
+ })
+
+ agent.on_update(on_update)
+ asyncio.create_task(agent.run_policy())
+ return session_id
+
+ def remove_session(self, session_id: str):
+ del self.sessions[session_id]
+
+ def register_websocket(self, session_id: str, ws: WebSocket):
+ self.sessions[session_id].ws = ws
+ print("Registered websocket for session", session_id)
+
+ async def send_ws_data(self, session_id: str, message_type: str, data: Any):
+ if self.sessions[session_id].ws is None:
+ print(f"Session {session_id} has no websocket")
+ return
+
+ async def a():
+ await self.sessions[session_id].ws.send_json({
+ "messageType": message_type,
+ "data": data
+ })
+
+ # Run coroutine in background
+ await self.sessions[session_id].ws.send_json({
+ "messageType": message_type,
+ "data": data
+ })
+ return
+ if self._event_loop is None or self._event_loop.is_closed():
+ self._event_loop = asyncio.new_event_loop()
+ self._event_loop.run_until_complete(a())
+ self._event_loop.close()
+ else:
+ self._event_loop.run_until_complete(a())
+ self._event_loop.close()
+
+
+session_manager = SessionManager()