diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-07-12 00:32:51 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-07-12 00:32:51 -0700 |
commit | a6ac1fdae55ba79e2ffbd00f88fc39971ed47f58 (patch) | |
tree | 1284916971492ed4d4ac7c91295f6529b6209eaf /continuedev | |
parent | 67d6a3f0ea00e55aea47b4eeff4cdb0d8321ce2f (diff) | |
download | sncontinue-a6ac1fdae55ba79e2ffbd00f88fc39971ed47f58.tar.gz sncontinue-a6ac1fdae55ba79e2ffbd00f88fc39971ed47f58.tar.bz2 sncontinue-a6ac1fdae55ba79e2ffbd00f88fc39971ed47f58.zip |
finally found the snag
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/src/continuedev/core/abstract_sdk.py | 4 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 5 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 26 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/util/create_async_task.py | 3 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/util/errors.py | 2 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 11 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide.py | 67 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide_protocol.py | 8 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/session_manager.py | 5 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/state_manager.py | 21 | ||||
-rw-r--r-- | continuedev/src/continuedev/steps/chat.py | 25 |
11 files changed, 99 insertions, 78 deletions
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py index 7bd3da6c..94d7be10 100644 --- a/continuedev/src/continuedev/core/abstract_sdk.py +++ b/continuedev/src/continuedev/core/abstract_sdk.py @@ -76,9 +76,7 @@ class AbstractContinueSDK(ABC): async def get_user_secret(self, env_var: str, prompt: str) -> str: pass - @abstractproperty - def config(self) -> ContinueConfig: - pass + config: ContinueConfig @abstractmethod def set_loading_message(self, message: str): diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 3f07e270..ac00e4f0 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -251,7 +251,7 @@ class Autopilot(ContinueBaseModel): # i -= 1 capture_event(self.continue_sdk.ide.unique_id, 'step run', { - 'step_name': step.name, 'params': step.dict()}) + 'step_name': step.name, 'params': step.dict()}) if not is_future_step: # Check manual edits buffer, clear out if needed by creating a ManualEditStep @@ -290,7 +290,8 @@ class Autopilot(ContinueBaseModel): e) # Attach an InternalErrorObservation to the step and unhide it. - print(f"Error while running step: \n{error_string}\n{error_title}") + print( + f"Error while running step: \n{error_string}\n{error_title}") capture_event(self.continue_sdk.ide.unique_id, 'step error', { 'error_message': error_string, 'error_title': error_title, 'step_name': step.name, 'params': step.dict()}) diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 8649cd58..a3441ad9 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -89,6 +89,20 @@ class ContinueSDK(AbstractContinueSDK): self.__autopilot = autopilot self.models = Models(self) self.context = autopilot.context + self.config = self._load_config() + + config: ContinueConfig + + def _load_config(self) -> ContinueConfig: + dir = self.ide.workspace_directory + yaml_path = os.path.join(dir, '.continue', 'config.yaml') + json_path = os.path.join(dir, '.continue', 'config.json') + if os.path.exists(yaml_path): + return load_config(yaml_path) + elif os.path.exists(json_path): + return load_config(json_path) + else: + return load_global_config() @property def history(self) -> History: @@ -166,18 +180,6 @@ class ContinueSDK(AbstractContinueSDK): async def get_user_secret(self, env_var: str, prompt: str) -> str: return await self.ide.getUserSecret(env_var) - @property - def config(self) -> ContinueConfig: - dir = self.ide.workspace_directory - yaml_path = os.path.join(dir, '.continue', 'config.yaml') - json_path = os.path.join(dir, '.continue', 'config.json') - if os.path.exists(yaml_path): - return load_config(yaml_path) - elif os.path.exists(json_path): - return load_config(json_path) - else: - return load_global_config() - def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]: context = list(filter(lambda x: x.editing, self.__autopilot._highlighted_ranges) ) if only_editing else self.__autopilot._highlighted_ranges diff --git a/continuedev/src/continuedev/libs/util/create_async_task.py b/continuedev/src/continuedev/libs/util/create_async_task.py index 62ff30ec..354cea82 100644 --- a/continuedev/src/continuedev/libs/util/create_async_task.py +++ b/continuedev/src/continuedev/libs/util/create_async_task.py @@ -14,7 +14,8 @@ def create_async_task(coro: Coroutine, unique_id: Union[str, None] = None): try: future.result() except Exception as e: - print("Exception caught from async task: ", e) + print("Exception caught from async task: ", + '\n'.join(traceback.format_exception(e))) capture_event(unique_id or "None", "async_task_error", { "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e)) }) diff --git a/continuedev/src/continuedev/libs/util/errors.py b/continuedev/src/continuedev/libs/util/errors.py new file mode 100644 index 00000000..46074cfc --- /dev/null +++ b/continuedev/src/continuedev/libs/util/errors.py @@ -0,0 +1,2 @@ +class SessionNotFound(Exception): + pass diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index dbc063c8..21089f30 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -1,5 +1,6 @@ import json from fastapi import Depends, Header, WebSocket, APIRouter +from starlette.websockets import WebSocketState, WebSocketDisconnect from typing import Any, List, Type, TypeVar, Union from pydantic import BaseModel import traceback @@ -52,6 +53,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.session = session async def _send_json(self, message_type: str, data: Any): + if self.websocket.client_state == WebSocketState.DISCONNECTED: + return await self.websocket.send_json({ "messageType": message_type, "data": data @@ -171,7 +174,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we protocol.websocket = websocket # Update any history that may have happened before connection - await protocol.send_state_update() + # await protocol.send_state_update() while AppStatus.should_exit is False: message = await websocket.receive_text() @@ -185,7 +188,8 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we data = message["data"] protocol.handle_json(message_type, data) - + except WebSocketDisconnect as e: + print("GUI websocket disconnected") except Exception as e: print("ERROR in gui websocket: ", e) capture_event(session.autopilot.continue_sdk.ide.unique_id, "gui_error", { @@ -193,5 +197,6 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we raise e finally: print("Closing gui websocket") - await websocket.close() + if websocket.client_state != WebSocketState.DISCONNECTED: + await websocket.close() session_manager.remove_session(session.session_id) diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 782c2ba6..400ad740 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -5,6 +5,7 @@ import os from typing import Any, Dict, List, Type, TypeVar, Union import uuid from fastapi import WebSocket, Body, APIRouter +from starlette.websockets import WebSocketState, WebSocketDisconnect from uvicorn.main import Server import traceback @@ -17,6 +18,8 @@ from .gui import SessionManager, session_manager from .ide_protocol import AbstractIdeProtocolServer import asyncio from ..libs.util.create_async_task import create_async_task +import nest_asyncio +nest_asyncio.apply() router = APIRouter(prefix="/ide", tags=["ide"]) @@ -115,7 +118,30 @@ class IdeProtocolServer(AbstractIdeProtocolServer): self.websocket = websocket self.session_manager = session_manager + workspace_directory: str + + async def initialize(self) -> List[str]: + await self._send_json("workspaceDirectory", {}) + other_msgs = [] + while True: + msg_string = await self.websocket.receive_text() + message = json.loads(msg_string) + if "messageType" not in message or "data" not in message: + continue + message_type = message["messageType"] + data = message["data"] + # if message_type == "openGUI": + # await self.openGUI() + if message_type == "workspaceDirectory": + self.workspace_directory = data["workspaceDirectory"] + break + else: + other_msgs.append(msg_string) + return other_msgs + async def _send_json(self, message_type: str, data: Any): + if self.websocket.client_state == WebSocketState.DISCONNECTED: + return await self.websocket.send_json({ "messageType": message_type, "data": data @@ -154,8 +180,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer): self.onMainUserInput(data["input"]) elif message_type == "deleteAtIndex": self.onDeleteAtIndex(data["index"]) - elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory", "getUserSecret", "runCommand", "uniqueId"]: + elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "getUserSecret", "runCommand", "uniqueId"]: self.sub_queue.post(message_type, data) + elif message_type == "workspaceDirectory": + self.workspace_directory = data["workspaceDirectory"] else: raise ValueError("Unknown message type", message_type) @@ -275,18 +303,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer): resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles") return resp.openFiles - async def getWorkspaceDirectory(self) -> str: - resp = await self._send_and_receive_json({}, WorkspaceDirectoryResponse, "workspaceDirectory") - return resp.workspaceDirectory - async def get_unique_id(self) -> str: resp = await self._send_and_receive_json({}, UniqueIdResponse, "uniqueId") return resp.uniqueId - @property - def workspace_directory(self) -> str: - return asyncio.run(self.getWorkspaceDirectory()) - @cached_property_no_none def unique_id(self) -> str: return asyncio.run(self.get_unique_id()) @@ -396,24 +416,35 @@ async def websocket_endpoint(websocket: WebSocket): print("Accepted websocket connection from, ", websocket.client) await websocket.send_json({"messageType": "connected", "data": {}}) - ideProtocolServer = IdeProtocolServer(session_manager, websocket) - - while AppStatus.should_exit is False: - message = await websocket.receive_text() - message = json.loads(message) + def handle_msg(msg): + message = json.loads(msg) if "messageType" not in message or "data" not in message: - continue + return message_type = message["messageType"] data = message["data"] - await ideProtocolServer.handle_json(message_type, data) + create_async_task( + ideProtocolServer.handle_json(message_type, data)) + + ideProtocolServer = IdeProtocolServer(session_manager, websocket) + other_msgs = await ideProtocolServer.initialize() + + for other_msg in other_msgs: + handle_msg(other_msg) + + while AppStatus.should_exit is False: + message = await websocket.receive_text() + handle_msg(message) print("Closing ide websocket") - await websocket.close() + except WebSocketDisconnect as e: + print("IDE wbsocket disconnected") except Exception as e: print("Error in ide websocket: ", e) capture_event(ideProtocolServer.unique_id, "gui_error", { "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))}) - await websocket.close() raise e + finally: + if websocket.client_state != WebSocketState.DISCONNECTED: + await websocket.close() diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index dfdca504..69cb6c10 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -16,10 +16,6 @@ class AbstractIdeProtocolServer(ABC): """Show a suggestion to the user""" @abstractmethod - async def getWorkspaceDirectory(self): - """Get the workspace directory""" - - @abstractmethod async def setFileOpen(self, filepath: str, open: bool = True): """Set whether a file is open""" @@ -103,9 +99,7 @@ class AbstractIdeProtocolServer(ABC): async def showDiff(self, filepath: str, replacement: str, step_index: int): """Show a diff""" - @abstractproperty - def workspace_directory(self) -> str: - """Get the workspace directory""" + workspace_directory: str @abstractproperty def unique_id(self) -> str: diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 873a379e..7147dcfa 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -9,11 +9,13 @@ from ..core.main import FullState from ..core.autopilot import Autopilot from .ide_protocol import AbstractIdeProtocolServer from ..libs.util.create_async_task import create_async_task +from ..libs.util.errors import SessionNotFound class Session: session_id: str autopilot: Autopilot + # The GUI websocket for the session ws: Union[WebSocket, None] def __init__(self, session_id: str, autopilot: Autopilot): @@ -37,7 +39,6 @@ class DemoAutopilot(Autopilot): class SessionManager: sessions: Dict[str, Session] = {} - _event_loop: Union[BaseEventLoop, None] = None def get_session(self, session_id: str) -> Session: if session_id not in self.sessions: @@ -67,6 +68,8 @@ class SessionManager: print("Registered websocket for session", session_id) async def send_ws_data(self, session_id: str, message_type: str, data: Any): + if session_id not in self.sessions: + raise SessionNotFound(f"Session {session_id} not found") if self.sessions[session_id].ws is None: print(f"Session {session_id} has no websocket") return diff --git a/continuedev/src/continuedev/server/state_manager.py b/continuedev/src/continuedev/server/state_manager.py deleted file mode 100644 index c9bd760b..00000000 --- a/continuedev/src/continuedev/server/state_manager.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Any, List, Tuple, Union -from fastapi import WebSocket -from pydantic import BaseModel -from ..core.main import FullState - -# State updates represented as (path, replacement) pairs -StateUpdate = Tuple[List[Union[str, int]], Any] - - -class StateManager: - """ - A class that acts as the source of truth for state, ingesting changes to the entire object and streaming only the updated portions to client. - """ - - def __init__(self, ws: WebSocket): - self.ws = ws - - def _send_update(self, updates: List[StateUpdate]): - self.ws.send_json( - [update.dict() for update in updates] - ) diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py index e1e041d0..14a1cd41 100644 --- a/continuedev/src/continuedev/steps/chat.py +++ b/continuedev/src/continuedev/steps/chat.py @@ -27,16 +27,21 @@ class SimpleChatStep(Step): async def run(self, sdk: ContinueSDK): completion = "" messages = self.messages or await sdk.get_chat_context() - async for chunk in sdk.models.gpt4.stream_chat(messages, temperature=0.5): - if sdk.current_step_was_deleted(): - # So that the message doesn't disappear - self.hide = False - return - - if "content" in chunk: - self.description += chunk["content"] - completion += chunk["content"] - await sdk.update_ui() + + generator = sdk.models.gpt4.stream_chat(messages, temperature=0.5) + try: + async for chunk in generator: + if sdk.current_step_was_deleted(): + # So that the message doesn't disappear + self.hide = False + return + + if "content" in chunk: + self.description += chunk["content"] + completion += chunk["content"] + await sdk.update_ui() + finally: + await generator.aclose() self.name = (await sdk.models.gpt35.complete( f"Write a short title for the following chat message: {self.description}")).strip() |