diff options
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/core/abstract_sdk.py | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 5 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 26 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/util/create_async_task.py | 3 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/util/errors.py | 2 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/gui.py | 11 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/ide.py | 67 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/ide_protocol.py | 8 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/session_manager.py | 5 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/state_manager.py | 21 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/chat.py | 25 | 
11 files changed, 99 insertions, 78 deletions
| diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py index 7bd3da6c..94d7be10 100644 --- a/continuedev/src/continuedev/core/abstract_sdk.py +++ b/continuedev/src/continuedev/core/abstract_sdk.py @@ -76,9 +76,7 @@ class AbstractContinueSDK(ABC):      async def get_user_secret(self, env_var: str, prompt: str) -> str:          pass -    @abstractproperty -    def config(self) -> ContinueConfig: -        pass +    config: ContinueConfig      @abstractmethod      def set_loading_message(self, message: str): diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 3f07e270..ac00e4f0 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -251,7 +251,7 @@ class Autopilot(ContinueBaseModel):          #     i -= 1          capture_event(self.continue_sdk.ide.unique_id, 'step run', { -                      'step_name': step.name, 'params': step.dict()}) +            'step_name': step.name, 'params': step.dict()})          if not is_future_step:              # Check manual edits buffer, clear out if needed by creating a ManualEditStep @@ -290,7 +290,8 @@ class Autopilot(ContinueBaseModel):                  e)              # Attach an InternalErrorObservation to the step and unhide it. -            print(f"Error while running step: \n{error_string}\n{error_title}") +            print( +                f"Error while running step: \n{error_string}\n{error_title}")              capture_event(self.continue_sdk.ide.unique_id, 'step error', {                  'error_message': error_string, 'error_title': error_title, 'step_name': step.name, 'params': step.dict()}) diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 8649cd58..a3441ad9 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -89,6 +89,20 @@ class ContinueSDK(AbstractContinueSDK):          self.__autopilot = autopilot          self.models = Models(self)          self.context = autopilot.context +        self.config = self._load_config() + +    config: ContinueConfig + +    def _load_config(self) -> ContinueConfig: +        dir = self.ide.workspace_directory +        yaml_path = os.path.join(dir, '.continue', 'config.yaml') +        json_path = os.path.join(dir, '.continue', 'config.json') +        if os.path.exists(yaml_path): +            return load_config(yaml_path) +        elif os.path.exists(json_path): +            return load_config(json_path) +        else: +            return load_global_config()      @property      def history(self) -> History: @@ -166,18 +180,6 @@ class ContinueSDK(AbstractContinueSDK):      async def get_user_secret(self, env_var: str, prompt: str) -> str:          return await self.ide.getUserSecret(env_var) -    @property -    def config(self) -> ContinueConfig: -        dir = self.ide.workspace_directory -        yaml_path = os.path.join(dir, '.continue', 'config.yaml') -        json_path = os.path.join(dir, '.continue', 'config.json') -        if os.path.exists(yaml_path): -            return load_config(yaml_path) -        elif os.path.exists(json_path): -            return load_config(json_path) -        else: -            return load_global_config() -      def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]:          context = list(filter(lambda x: x.editing, self.__autopilot._highlighted_ranges)                         ) if only_editing else self.__autopilot._highlighted_ranges diff --git a/continuedev/src/continuedev/libs/util/create_async_task.py b/continuedev/src/continuedev/libs/util/create_async_task.py index 62ff30ec..354cea82 100644 --- a/continuedev/src/continuedev/libs/util/create_async_task.py +++ b/continuedev/src/continuedev/libs/util/create_async_task.py @@ -14,7 +14,8 @@ def create_async_task(coro: Coroutine, unique_id: Union[str, None] = None):          try:              future.result()          except Exception as e: -            print("Exception caught from async task: ", e) +            print("Exception caught from async task: ", +                  '\n'.join(traceback.format_exception(e)))              capture_event(unique_id or "None", "async_task_error", {                  "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))              }) diff --git a/continuedev/src/continuedev/libs/util/errors.py b/continuedev/src/continuedev/libs/util/errors.py new file mode 100644 index 00000000..46074cfc --- /dev/null +++ b/continuedev/src/continuedev/libs/util/errors.py @@ -0,0 +1,2 @@ +class SessionNotFound(Exception): +    pass diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index dbc063c8..21089f30 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -1,5 +1,6 @@  import json  from fastapi import Depends, Header, WebSocket, APIRouter +from starlette.websockets import WebSocketState, WebSocketDisconnect  from typing import Any, List, Type, TypeVar, Union  from pydantic import BaseModel  import traceback @@ -52,6 +53,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):          self.session = session      async def _send_json(self, message_type: str, data: Any): +        if self.websocket.client_state == WebSocketState.DISCONNECTED: +            return          await self.websocket.send_json({              "messageType": message_type,              "data": data @@ -171,7 +174,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we          protocol.websocket = websocket          # Update any history that may have happened before connection -        await protocol.send_state_update() +        # await protocol.send_state_update()          while AppStatus.should_exit is False:              message = await websocket.receive_text() @@ -185,7 +188,8 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we              data = message["data"]              protocol.handle_json(message_type, data) - +    except WebSocketDisconnect as e: +        print("GUI websocket disconnected")      except Exception as e:          print("ERROR in gui websocket: ", e)          capture_event(session.autopilot.continue_sdk.ide.unique_id, "gui_error", { @@ -193,5 +197,6 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we          raise e      finally:          print("Closing gui websocket") -        await websocket.close() +        if websocket.client_state != WebSocketState.DISCONNECTED: +            await websocket.close()          session_manager.remove_session(session.session_id) diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 782c2ba6..400ad740 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -5,6 +5,7 @@ import os  from typing import Any, Dict, List, Type, TypeVar, Union  import uuid  from fastapi import WebSocket, Body, APIRouter +from starlette.websockets import WebSocketState, WebSocketDisconnect  from uvicorn.main import Server  import traceback @@ -17,6 +18,8 @@ from .gui import SessionManager, session_manager  from .ide_protocol import AbstractIdeProtocolServer  import asyncio  from ..libs.util.create_async_task import create_async_task +import nest_asyncio +nest_asyncio.apply()  router = APIRouter(prefix="/ide", tags=["ide"]) @@ -115,7 +118,30 @@ class IdeProtocolServer(AbstractIdeProtocolServer):          self.websocket = websocket          self.session_manager = session_manager +    workspace_directory: str + +    async def initialize(self) -> List[str]: +        await self._send_json("workspaceDirectory", {}) +        other_msgs = [] +        while True: +            msg_string = await self.websocket.receive_text() +            message = json.loads(msg_string) +            if "messageType" not in message or "data" not in message: +                continue +            message_type = message["messageType"] +            data = message["data"] +            # if message_type == "openGUI": +            #     await self.openGUI() +            if message_type == "workspaceDirectory": +                self.workspace_directory = data["workspaceDirectory"] +                break +            else: +                other_msgs.append(msg_string) +        return other_msgs +      async def _send_json(self, message_type: str, data: Any): +        if self.websocket.client_state == WebSocketState.DISCONNECTED: +            return          await self.websocket.send_json({              "messageType": message_type,              "data": data @@ -154,8 +180,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer):              self.onMainUserInput(data["input"])          elif message_type == "deleteAtIndex":              self.onDeleteAtIndex(data["index"]) -        elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory", "getUserSecret", "runCommand", "uniqueId"]: +        elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "getUserSecret", "runCommand", "uniqueId"]:              self.sub_queue.post(message_type, data) +        elif message_type == "workspaceDirectory": +            self.workspace_directory = data["workspaceDirectory"]          else:              raise ValueError("Unknown message type", message_type) @@ -275,18 +303,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer):          resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles")          return resp.openFiles -    async def getWorkspaceDirectory(self) -> str: -        resp = await self._send_and_receive_json({}, WorkspaceDirectoryResponse, "workspaceDirectory") -        return resp.workspaceDirectory -      async def get_unique_id(self) -> str:          resp = await self._send_and_receive_json({}, UniqueIdResponse, "uniqueId")          return resp.uniqueId -    @property -    def workspace_directory(self) -> str: -        return asyncio.run(self.getWorkspaceDirectory()) -      @cached_property_no_none      def unique_id(self) -> str:          return asyncio.run(self.get_unique_id()) @@ -396,24 +416,35 @@ async def websocket_endpoint(websocket: WebSocket):          print("Accepted websocket connection from, ", websocket.client)          await websocket.send_json({"messageType": "connected", "data": {}}) -        ideProtocolServer = IdeProtocolServer(session_manager, websocket) - -        while AppStatus.should_exit is False: -            message = await websocket.receive_text() -            message = json.loads(message) +        def handle_msg(msg): +            message = json.loads(msg)              if "messageType" not in message or "data" not in message: -                continue +                return              message_type = message["messageType"]              data = message["data"] -            await ideProtocolServer.handle_json(message_type, data) +            create_async_task( +                ideProtocolServer.handle_json(message_type, data)) + +        ideProtocolServer = IdeProtocolServer(session_manager, websocket) +        other_msgs = await ideProtocolServer.initialize() + +        for other_msg in other_msgs: +            handle_msg(other_msg) + +        while AppStatus.should_exit is False: +            message = await websocket.receive_text() +            handle_msg(message)          print("Closing ide websocket") -        await websocket.close() +    except WebSocketDisconnect as e: +        print("IDE wbsocket disconnected")      except Exception as e:          print("Error in ide websocket: ", e)          capture_event(ideProtocolServer.unique_id, "gui_error", {                        "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))}) -        await websocket.close()          raise e +    finally: +        if websocket.client_state != WebSocketState.DISCONNECTED: +            await websocket.close() diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index dfdca504..69cb6c10 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -16,10 +16,6 @@ class AbstractIdeProtocolServer(ABC):          """Show a suggestion to the user"""      @abstractmethod -    async def getWorkspaceDirectory(self): -        """Get the workspace directory""" - -    @abstractmethod      async def setFileOpen(self, filepath: str, open: bool = True):          """Set whether a file is open""" @@ -103,9 +99,7 @@ class AbstractIdeProtocolServer(ABC):      async def showDiff(self, filepath: str, replacement: str, step_index: int):          """Show a diff""" -    @abstractproperty -    def workspace_directory(self) -> str: -        """Get the workspace directory""" +    workspace_directory: str      @abstractproperty      def unique_id(self) -> str: diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 873a379e..7147dcfa 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -9,11 +9,13 @@ from ..core.main import FullState  from ..core.autopilot import Autopilot  from .ide_protocol import AbstractIdeProtocolServer  from ..libs.util.create_async_task import create_async_task +from ..libs.util.errors import SessionNotFound  class Session:      session_id: str      autopilot: Autopilot +    # The GUI websocket for the session      ws: Union[WebSocket, None]      def __init__(self, session_id: str, autopilot: Autopilot): @@ -37,7 +39,6 @@ class DemoAutopilot(Autopilot):  class SessionManager:      sessions: Dict[str, Session] = {} -    _event_loop: Union[BaseEventLoop, None] = None      def get_session(self, session_id: str) -> Session:          if session_id not in self.sessions: @@ -67,6 +68,8 @@ class SessionManager:          print("Registered websocket for session", session_id)      async def send_ws_data(self, session_id: str, message_type: str, data: Any): +        if session_id not in self.sessions: +            raise SessionNotFound(f"Session {session_id} not found")          if self.sessions[session_id].ws is None:              print(f"Session {session_id} has no websocket")              return diff --git a/continuedev/src/continuedev/server/state_manager.py b/continuedev/src/continuedev/server/state_manager.py deleted file mode 100644 index c9bd760b..00000000 --- a/continuedev/src/continuedev/server/state_manager.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Any, List, Tuple, Union -from fastapi import WebSocket -from pydantic import BaseModel -from ..core.main import FullState - -# State updates represented as (path, replacement) pairs -StateUpdate = Tuple[List[Union[str, int]], Any] - - -class StateManager: -    """ -    A class that acts as the source of truth for state, ingesting changes to the entire object and streaming only the updated portions to client. -    """ - -    def __init__(self, ws: WebSocket): -        self.ws = ws - -    def _send_update(self, updates: List[StateUpdate]): -        self.ws.send_json( -            [update.dict() for update in updates] -        ) diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py index e1e041d0..14a1cd41 100644 --- a/continuedev/src/continuedev/steps/chat.py +++ b/continuedev/src/continuedev/steps/chat.py @@ -27,16 +27,21 @@ class SimpleChatStep(Step):      async def run(self, sdk: ContinueSDK):          completion = ""          messages = self.messages or await sdk.get_chat_context() -        async for chunk in sdk.models.gpt4.stream_chat(messages, temperature=0.5): -            if sdk.current_step_was_deleted(): -                # So that the message doesn't disappear -                self.hide = False -                return - -            if "content" in chunk: -                self.description += chunk["content"] -                completion += chunk["content"] -                await sdk.update_ui() + +        generator = sdk.models.gpt4.stream_chat(messages, temperature=0.5) +        try: +            async for chunk in generator: +                if sdk.current_step_was_deleted(): +                    # So that the message doesn't disappear +                    self.hide = False +                    return + +                if "content" in chunk: +                    self.description += chunk["content"] +                    completion += chunk["content"] +                    await sdk.update_ui() +        finally: +            await generator.aclose()          self.name = (await sdk.models.gpt35.complete(              f"Write a short title for the following chat message: {self.description}")).strip() | 
