From a6ac1fdae55ba79e2ffbd00f88fc39971ed47f58 Mon Sep 17 00:00:00 2001
From: Nate Sesti <sestinj@gmail.com>
Date: Wed, 12 Jul 2023 00:32:51 -0700
Subject: finally found the snag

---
 continuedev/src/continuedev/core/abstract_sdk.py   |  4 +-
 continuedev/src/continuedev/core/autopilot.py      |  5 +-
 continuedev/src/continuedev/core/sdk.py            | 26 +++++----
 .../src/continuedev/libs/util/create_async_task.py |  3 +-
 continuedev/src/continuedev/libs/util/errors.py    |  2 +
 continuedev/src/continuedev/server/gui.py          | 11 +++-
 continuedev/src/continuedev/server/ide.py          | 67 ++++++++++++++++------
 continuedev/src/continuedev/server/ide_protocol.py |  8 +--
 .../src/continuedev/server/session_manager.py      |  5 +-
 .../src/continuedev/server/state_manager.py        | 21 -------
 continuedev/src/continuedev/steps/chat.py          | 25 ++++----
 11 files changed, 99 insertions(+), 78 deletions(-)
 create mode 100644 continuedev/src/continuedev/libs/util/errors.py
 delete mode 100644 continuedev/src/continuedev/server/state_manager.py

(limited to 'continuedev')

diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py
index 7bd3da6c..94d7be10 100644
--- a/continuedev/src/continuedev/core/abstract_sdk.py
+++ b/continuedev/src/continuedev/core/abstract_sdk.py
@@ -76,9 +76,7 @@ class AbstractContinueSDK(ABC):
     async def get_user_secret(self, env_var: str, prompt: str) -> str:
         pass
 
-    @abstractproperty
-    def config(self) -> ContinueConfig:
-        pass
+    config: ContinueConfig
 
     @abstractmethod
     def set_loading_message(self, message: str):
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 3f07e270..ac00e4f0 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -251,7 +251,7 @@ class Autopilot(ContinueBaseModel):
         #     i -= 1
 
         capture_event(self.continue_sdk.ide.unique_id, 'step run', {
-                      'step_name': step.name, 'params': step.dict()})
+            'step_name': step.name, 'params': step.dict()})
 
         if not is_future_step:
             # Check manual edits buffer, clear out if needed by creating a ManualEditStep
@@ -290,7 +290,8 @@ class Autopilot(ContinueBaseModel):
                 e)
 
             # Attach an InternalErrorObservation to the step and unhide it.
-            print(f"Error while running step: \n{error_string}\n{error_title}")
+            print(
+                f"Error while running step: \n{error_string}\n{error_title}")
             capture_event(self.continue_sdk.ide.unique_id, 'step error', {
                 'error_message': error_string, 'error_title': error_title, 'step_name': step.name, 'params': step.dict()})
 
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 8649cd58..a3441ad9 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -89,6 +89,20 @@ class ContinueSDK(AbstractContinueSDK):
         self.__autopilot = autopilot
         self.models = Models(self)
         self.context = autopilot.context
+        self.config = self._load_config()
+
+    config: ContinueConfig
+
+    def _load_config(self) -> ContinueConfig:
+        dir = self.ide.workspace_directory
+        yaml_path = os.path.join(dir, '.continue', 'config.yaml')
+        json_path = os.path.join(dir, '.continue', 'config.json')
+        if os.path.exists(yaml_path):
+            return load_config(yaml_path)
+        elif os.path.exists(json_path):
+            return load_config(json_path)
+        else:
+            return load_global_config()
 
     @property
     def history(self) -> History:
@@ -166,18 +180,6 @@ class ContinueSDK(AbstractContinueSDK):
     async def get_user_secret(self, env_var: str, prompt: str) -> str:
         return await self.ide.getUserSecret(env_var)
 
-    @property
-    def config(self) -> ContinueConfig:
-        dir = self.ide.workspace_directory
-        yaml_path = os.path.join(dir, '.continue', 'config.yaml')
-        json_path = os.path.join(dir, '.continue', 'config.json')
-        if os.path.exists(yaml_path):
-            return load_config(yaml_path)
-        elif os.path.exists(json_path):
-            return load_config(json_path)
-        else:
-            return load_global_config()
-
     def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]:
         context = list(filter(lambda x: x.editing, self.__autopilot._highlighted_ranges)
                        ) if only_editing else self.__autopilot._highlighted_ranges
diff --git a/continuedev/src/continuedev/libs/util/create_async_task.py b/continuedev/src/continuedev/libs/util/create_async_task.py
index 62ff30ec..354cea82 100644
--- a/continuedev/src/continuedev/libs/util/create_async_task.py
+++ b/continuedev/src/continuedev/libs/util/create_async_task.py
@@ -14,7 +14,8 @@ def create_async_task(coro: Coroutine, unique_id: Union[str, None] = None):
         try:
             future.result()
         except Exception as e:
-            print("Exception caught from async task: ", e)
+            print("Exception caught from async task: ",
+                  '\n'.join(traceback.format_exception(e)))
             capture_event(unique_id or "None", "async_task_error", {
                 "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))
             })
diff --git a/continuedev/src/continuedev/libs/util/errors.py b/continuedev/src/continuedev/libs/util/errors.py
new file mode 100644
index 00000000..46074cfc
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/errors.py
@@ -0,0 +1,2 @@
+class SessionNotFound(Exception):
+    pass
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index dbc063c8..21089f30 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -1,5 +1,6 @@
 import json
 from fastapi import Depends, Header, WebSocket, APIRouter
+from starlette.websockets import WebSocketState, WebSocketDisconnect
 from typing import Any, List, Type, TypeVar, Union
 from pydantic import BaseModel
 import traceback
@@ -52,6 +53,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
         self.session = session
 
     async def _send_json(self, message_type: str, data: Any):
+        if self.websocket.client_state == WebSocketState.DISCONNECTED:
+            return
         await self.websocket.send_json({
             "messageType": message_type,
             "data": data
@@ -171,7 +174,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
         protocol.websocket = websocket
 
         # Update any history that may have happened before connection
-        await protocol.send_state_update()
+        # await protocol.send_state_update()
 
         while AppStatus.should_exit is False:
             message = await websocket.receive_text()
@@ -185,7 +188,8 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
             data = message["data"]
 
             protocol.handle_json(message_type, data)
-
+    except WebSocketDisconnect as e:
+        print("GUI websocket disconnected")
     except Exception as e:
         print("ERROR in gui websocket: ", e)
         capture_event(session.autopilot.continue_sdk.ide.unique_id, "gui_error", {
@@ -193,5 +197,6 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
         raise e
     finally:
         print("Closing gui websocket")
-        await websocket.close()
+        if websocket.client_state != WebSocketState.DISCONNECTED:
+            await websocket.close()
         session_manager.remove_session(session.session_id)
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index 782c2ba6..400ad740 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -5,6 +5,7 @@ import os
 from typing import Any, Dict, List, Type, TypeVar, Union
 import uuid
 from fastapi import WebSocket, Body, APIRouter
+from starlette.websockets import WebSocketState, WebSocketDisconnect
 from uvicorn.main import Server
 import traceback
 
@@ -17,6 +18,8 @@ from .gui import SessionManager, session_manager
 from .ide_protocol import AbstractIdeProtocolServer
 import asyncio
 from ..libs.util.create_async_task import create_async_task
+import nest_asyncio
+nest_asyncio.apply()
 
 
 router = APIRouter(prefix="/ide", tags=["ide"])
@@ -115,7 +118,30 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
         self.websocket = websocket
         self.session_manager = session_manager
 
+    workspace_directory: str
+
+    async def initialize(self) -> List[str]:
+        await self._send_json("workspaceDirectory", {})
+        other_msgs = []
+        while True:
+            msg_string = await self.websocket.receive_text()
+            message = json.loads(msg_string)
+            if "messageType" not in message or "data" not in message:
+                continue
+            message_type = message["messageType"]
+            data = message["data"]
+            # if message_type == "openGUI":
+            #     await self.openGUI()
+            if message_type == "workspaceDirectory":
+                self.workspace_directory = data["workspaceDirectory"]
+                break
+            else:
+                other_msgs.append(msg_string)
+        return other_msgs
+
     async def _send_json(self, message_type: str, data: Any):
+        if self.websocket.client_state == WebSocketState.DISCONNECTED:
+            return
         await self.websocket.send_json({
             "messageType": message_type,
             "data": data
@@ -154,8 +180,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
             self.onMainUserInput(data["input"])
         elif message_type == "deleteAtIndex":
             self.onDeleteAtIndex(data["index"])
-        elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory", "getUserSecret", "runCommand", "uniqueId"]:
+        elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "getUserSecret", "runCommand", "uniqueId"]:
             self.sub_queue.post(message_type, data)
+        elif message_type == "workspaceDirectory":
+            self.workspace_directory = data["workspaceDirectory"]
         else:
             raise ValueError("Unknown message type", message_type)
 
@@ -275,18 +303,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
         resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles")
         return resp.openFiles
 
-    async def getWorkspaceDirectory(self) -> str:
-        resp = await self._send_and_receive_json({}, WorkspaceDirectoryResponse, "workspaceDirectory")
-        return resp.workspaceDirectory
-
     async def get_unique_id(self) -> str:
         resp = await self._send_and_receive_json({}, UniqueIdResponse, "uniqueId")
         return resp.uniqueId
 
-    @property
-    def workspace_directory(self) -> str:
-        return asyncio.run(self.getWorkspaceDirectory())
-
     @cached_property_no_none
     def unique_id(self) -> str:
         return asyncio.run(self.get_unique_id())
@@ -396,24 +416,35 @@ async def websocket_endpoint(websocket: WebSocket):
         print("Accepted websocket connection from, ", websocket.client)
         await websocket.send_json({"messageType": "connected", "data": {}})
 
-        ideProtocolServer = IdeProtocolServer(session_manager, websocket)
-
-        while AppStatus.should_exit is False:
-            message = await websocket.receive_text()
-            message = json.loads(message)
+        def handle_msg(msg):
+            message = json.loads(msg)
 
             if "messageType" not in message or "data" not in message:
-                continue
+                return
             message_type = message["messageType"]
             data = message["data"]
 
-            await ideProtocolServer.handle_json(message_type, data)
+            create_async_task(
+                ideProtocolServer.handle_json(message_type, data))
+
+        ideProtocolServer = IdeProtocolServer(session_manager, websocket)
+        other_msgs = await ideProtocolServer.initialize()
+
+        for other_msg in other_msgs:
+            handle_msg(other_msg)
+
+        while AppStatus.should_exit is False:
+            message = await websocket.receive_text()
+            handle_msg(message)
 
         print("Closing ide websocket")
-        await websocket.close()
+    except WebSocketDisconnect as e:
+        print("IDE wbsocket disconnected")
     except Exception as e:
         print("Error in ide websocket: ", e)
         capture_event(ideProtocolServer.unique_id, "gui_error", {
                       "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))})
-        await websocket.close()
         raise e
+    finally:
+        if websocket.client_state != WebSocketState.DISCONNECTED:
+            await websocket.close()
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
index dfdca504..69cb6c10 100644
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ b/continuedev/src/continuedev/server/ide_protocol.py
@@ -15,10 +15,6 @@ class AbstractIdeProtocolServer(ABC):
     def showSuggestion(self, file_edit: FileEdit):
         """Show a suggestion to the user"""
 
-    @abstractmethod
-    async def getWorkspaceDirectory(self):
-        """Get the workspace directory"""
-
     @abstractmethod
     async def setFileOpen(self, filepath: str, open: bool = True):
         """Set whether a file is open"""
@@ -103,9 +99,7 @@ class AbstractIdeProtocolServer(ABC):
     async def showDiff(self, filepath: str, replacement: str, step_index: int):
         """Show a diff"""
 
-    @abstractproperty
-    def workspace_directory(self) -> str:
-        """Get the workspace directory"""
+    workspace_directory: str
 
     @abstractproperty
     def unique_id(self) -> str:
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index 873a379e..7147dcfa 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -9,11 +9,13 @@ from ..core.main import FullState
 from ..core.autopilot import Autopilot
 from .ide_protocol import AbstractIdeProtocolServer
 from ..libs.util.create_async_task import create_async_task
+from ..libs.util.errors import SessionNotFound
 
 
 class Session:
     session_id: str
     autopilot: Autopilot
+    # The GUI websocket for the session
     ws: Union[WebSocket, None]
 
     def __init__(self, session_id: str, autopilot: Autopilot):
@@ -37,7 +39,6 @@ class DemoAutopilot(Autopilot):
 
 class SessionManager:
     sessions: Dict[str, Session] = {}
-    _event_loop: Union[BaseEventLoop, None] = None
 
     def get_session(self, session_id: str) -> Session:
         if session_id not in self.sessions:
@@ -67,6 +68,8 @@ class SessionManager:
         print("Registered websocket for session", session_id)
 
     async def send_ws_data(self, session_id: str, message_type: str, data: Any):
+        if session_id not in self.sessions:
+            raise SessionNotFound(f"Session {session_id} not found")
         if self.sessions[session_id].ws is None:
             print(f"Session {session_id} has no websocket")
             return
diff --git a/continuedev/src/continuedev/server/state_manager.py b/continuedev/src/continuedev/server/state_manager.py
deleted file mode 100644
index c9bd760b..00000000
--- a/continuedev/src/continuedev/server/state_manager.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from typing import Any, List, Tuple, Union
-from fastapi import WebSocket
-from pydantic import BaseModel
-from ..core.main import FullState
-
-# State updates represented as (path, replacement) pairs
-StateUpdate = Tuple[List[Union[str, int]], Any]
-
-
-class StateManager:
-    """
-    A class that acts as the source of truth for state, ingesting changes to the entire object and streaming only the updated portions to client.
-    """
-
-    def __init__(self, ws: WebSocket):
-        self.ws = ws
-
-    def _send_update(self, updates: List[StateUpdate]):
-        self.ws.send_json(
-            [update.dict() for update in updates]
-        )
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index e1e041d0..14a1cd41 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -27,16 +27,21 @@ class SimpleChatStep(Step):
     async def run(self, sdk: ContinueSDK):
         completion = ""
         messages = self.messages or await sdk.get_chat_context()
-        async for chunk in sdk.models.gpt4.stream_chat(messages, temperature=0.5):
-            if sdk.current_step_was_deleted():
-                # So that the message doesn't disappear
-                self.hide = False
-                return
-
-            if "content" in chunk:
-                self.description += chunk["content"]
-                completion += chunk["content"]
-                await sdk.update_ui()
+
+        generator = sdk.models.gpt4.stream_chat(messages, temperature=0.5)
+        try:
+            async for chunk in generator:
+                if sdk.current_step_was_deleted():
+                    # So that the message doesn't disappear
+                    self.hide = False
+                    return
+
+                if "content" in chunk:
+                    self.description += chunk["content"]
+                    completion += chunk["content"]
+                    await sdk.update_ui()
+        finally:
+            await generator.aclose()
 
         self.name = (await sdk.models.gpt35.complete(
             f"Write a short title for the following chat message: {self.description}")).strip()
-- 
cgit v1.2.3-70-g09d2