From 55ff206a9c24e67cd1f32ec86d705206530e668f Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Mon, 10 Jul 2023 19:31:48 -0700 Subject: catch errors during asyncio tasks --- .../src/continuedev/libs/util/create_async_task.py | 23 ++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 continuedev/src/continuedev/libs/util/create_async_task.py (limited to 'continuedev/src/continuedev/libs') diff --git a/continuedev/src/continuedev/libs/util/create_async_task.py b/continuedev/src/continuedev/libs/util/create_async_task.py new file mode 100644 index 00000000..89b84868 --- /dev/null +++ b/continuedev/src/continuedev/libs/util/create_async_task.py @@ -0,0 +1,23 @@ +from typing import Coroutine, Union +import traceback +from .telemetry import capture_event +import asyncio +import nest_asyncio +nest_asyncio.apply() + + +def create_async_task(coro: Coroutine, unique_id: Union[str, None] = None): + """asyncio.create_task and log errors by adding a callback""" + task = asyncio.create_task(coro) + + def callback(future: asyncio.Future): + try: + future.result() + except Exception as e: + print("Exception caught from async task: ", e) + capture_event(unique_id or "None", "async_task_error", { + "error_title": e.__str__() or e.__repr__(), "error_message": traceback.format(e.__traceback__) + }) + + task.add_done_callback(callback) + return task -- cgit v1.2.3-70-g09d2 From 20797dbdcdbcfc75d4eaa7cdfccfd384ecbdecd8 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Mon, 10 Jul 2023 23:12:21 -0700 Subject: traceback.format( -> format_tb --- continuedev/src/continuedev/libs/util/create_async_task.py | 2 +- continuedev/src/continuedev/server/gui.py | 3 ++- continuedev/src/continuedev/server/ide.py | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) (limited to 'continuedev/src/continuedev/libs') diff --git a/continuedev/src/continuedev/libs/util/create_async_task.py b/continuedev/src/continuedev/libs/util/create_async_task.py index 89b84868..608d4977 100644 --- a/continuedev/src/continuedev/libs/util/create_async_task.py +++ b/continuedev/src/continuedev/libs/util/create_async_task.py @@ -16,7 +16,7 @@ def create_async_task(coro: Coroutine, unique_id: Union[str, None] = None): except Exception as e: print("Exception caught from async task: ", e) capture_event(unique_id or "None", "async_task_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": traceback.format(e.__traceback__) + "error_title": e.__str__() or e.__repr__(), "error_message": traceback.format_tb(e.__traceback__) }) task.add_done_callback(callback) diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 8daad73a..ae53be00 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -2,6 +2,7 @@ import json from fastapi import Depends, Header, WebSocket, APIRouter from typing import Any, List, Type, TypeVar, Union from pydantic import BaseModel +import traceback from uvicorn.main import Server from .session_manager import SessionManager, session_manager, Session @@ -188,7 +189,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we except Exception as e: print("ERROR in gui websocket: ", e) capture_event(session.autopilot.continue_sdk.ide.unique_id, "gui_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": e.__traceback__}) + "error_title": e.__str__() or e.__repr__(), "error_message": traceback.format_tb(e.__traceback__)}) raise e finally: print("Closing gui websocket") diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index f84f3de2..93996edd 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Type, TypeVar, Union import uuid from fastapi import WebSocket, Body, APIRouter from uvicorn.main import Server +import traceback from ..libs.util.telemetry import capture_event from ..libs.util.queue import AsyncSubscriptionQueue @@ -413,6 +414,6 @@ async def websocket_endpoint(websocket: WebSocket): except Exception as e: print("Error in ide websocket: ", e) capture_event(ideProtocolServer.unique_id, "gui_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": e.__traceback__}) + "error_title": e.__str__() or e.__repr__(), "error_message": traceback.format_tb(e.__traceback__)}) await websocket.close() raise e -- cgit v1.2.3-70-g09d2 From 7780bbed2ab9dfa6d7152d9b63f140984dfaee0c Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Tue, 11 Jul 2023 15:02:45 -0700 Subject: correctly format traceback when logging errors --- continuedev/src/continuedev/core/autopilot.py | 2 +- continuedev/src/continuedev/libs/util/create_async_task.py | 2 +- continuedev/src/continuedev/server/gui.py | 2 +- continuedev/src/continuedev/server/ide.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) (limited to 'continuedev/src/continuedev/libs') diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 615e7657..11fc9cb1 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -285,7 +285,7 @@ class Autopilot(ContinueBaseModel): e.__class__, ContinueCustomException) error_string = e.message if is_continue_custom_exception else '\n'.join( - traceback.format_tb(e.__traceback__)) + f"\n\n{e.__repr__()}" + traceback.format_exception(e)) error_title = e.title if is_continue_custom_exception else get_error_title( e) diff --git a/continuedev/src/continuedev/libs/util/create_async_task.py b/continuedev/src/continuedev/libs/util/create_async_task.py index 608d4977..f41f642e 100644 --- a/continuedev/src/continuedev/libs/util/create_async_task.py +++ b/continuedev/src/continuedev/libs/util/create_async_task.py @@ -16,7 +16,7 @@ def create_async_task(coro: Coroutine, unique_id: Union[str, None] = None): except Exception as e: print("Exception caught from async task: ", e) capture_event(unique_id or "None", "async_task_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": traceback.format_tb(e.__traceback__) + "error_title": e.__str__() or e.__repr__(), "error_message": traceback.format_exception(e) }) task.add_done_callback(callback) diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index ae53be00..1321c27f 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -189,7 +189,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we except Exception as e: print("ERROR in gui websocket: ", e) capture_event(session.autopilot.continue_sdk.ide.unique_id, "gui_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": traceback.format_tb(e.__traceback__)}) + "error_title": e.__str__() or e.__repr__(), "error_message": traceback.format_exception(e)}) raise e finally: print("Closing gui websocket") diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 93996edd..fd9faec1 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -414,6 +414,6 @@ async def websocket_endpoint(websocket: WebSocket): except Exception as e: print("Error in ide websocket: ", e) capture_event(ideProtocolServer.unique_id, "gui_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": traceback.format_tb(e.__traceback__)}) + "error_title": e.__str__() or e.__repr__(), "error_message": traceback.format_exception(e)}) await websocket.close() raise e -- cgit v1.2.3-70-g09d2 From f775ed5678109c6ed8b07722ddf0599d5df0dd25 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Tue, 11 Jul 2023 15:04:20 -0700 Subject: logging as string, not array of frames --- continuedev/src/continuedev/libs/util/create_async_task.py | 2 +- continuedev/src/continuedev/server/gui.py | 2 +- continuedev/src/continuedev/server/ide.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) (limited to 'continuedev/src/continuedev/libs') diff --git a/continuedev/src/continuedev/libs/util/create_async_task.py b/continuedev/src/continuedev/libs/util/create_async_task.py index f41f642e..62ff30ec 100644 --- a/continuedev/src/continuedev/libs/util/create_async_task.py +++ b/continuedev/src/continuedev/libs/util/create_async_task.py @@ -16,7 +16,7 @@ def create_async_task(coro: Coroutine, unique_id: Union[str, None] = None): except Exception as e: print("Exception caught from async task: ", e) capture_event(unique_id or "None", "async_task_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": traceback.format_exception(e) + "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e)) }) task.add_done_callback(callback) diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 1321c27f..dbc063c8 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -189,7 +189,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we except Exception as e: print("ERROR in gui websocket: ", e) capture_event(session.autopilot.continue_sdk.ide.unique_id, "gui_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": traceback.format_exception(e)}) + "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))}) raise e finally: print("Closing gui websocket") diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index fd9faec1..782c2ba6 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -414,6 +414,6 @@ async def websocket_endpoint(websocket: WebSocket): except Exception as e: print("Error in ide websocket: ", e) capture_event(ideProtocolServer.unique_id, "gui_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": traceback.format_exception(e)}) + "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))}) await websocket.close() raise e -- cgit v1.2.3-70-g09d2 From 5593c72720dec57f8b734d442e7792fb30632d90 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Wed, 12 Jul 2023 00:32:51 -0700 Subject: finally found the snag --- continuedev/src/continuedev/core/abstract_sdk.py | 4 +- continuedev/src/continuedev/core/autopilot.py | 5 +- continuedev/src/continuedev/core/sdk.py | 26 +++++---- .../src/continuedev/libs/util/create_async_task.py | 3 +- continuedev/src/continuedev/libs/util/errors.py | 2 + continuedev/src/continuedev/server/gui.py | 11 +++- continuedev/src/continuedev/server/ide.py | 67 ++++++++++++++++------ continuedev/src/continuedev/server/ide_protocol.py | 8 +-- .../src/continuedev/server/session_manager.py | 5 +- .../src/continuedev/server/state_manager.py | 21 ------- continuedev/src/continuedev/steps/chat.py | 25 ++++---- 11 files changed, 99 insertions(+), 78 deletions(-) create mode 100644 continuedev/src/continuedev/libs/util/errors.py delete mode 100644 continuedev/src/continuedev/server/state_manager.py (limited to 'continuedev/src/continuedev/libs') diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py index 7bd3da6c..94d7be10 100644 --- a/continuedev/src/continuedev/core/abstract_sdk.py +++ b/continuedev/src/continuedev/core/abstract_sdk.py @@ -76,9 +76,7 @@ class AbstractContinueSDK(ABC): async def get_user_secret(self, env_var: str, prompt: str) -> str: pass - @abstractproperty - def config(self) -> ContinueConfig: - pass + config: ContinueConfig @abstractmethod def set_loading_message(self, message: str): diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 3f07e270..ac00e4f0 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -251,7 +251,7 @@ class Autopilot(ContinueBaseModel): # i -= 1 capture_event(self.continue_sdk.ide.unique_id, 'step run', { - 'step_name': step.name, 'params': step.dict()}) + 'step_name': step.name, 'params': step.dict()}) if not is_future_step: # Check manual edits buffer, clear out if needed by creating a ManualEditStep @@ -290,7 +290,8 @@ class Autopilot(ContinueBaseModel): e) # Attach an InternalErrorObservation to the step and unhide it. - print(f"Error while running step: \n{error_string}\n{error_title}") + print( + f"Error while running step: \n{error_string}\n{error_title}") capture_event(self.continue_sdk.ide.unique_id, 'step error', { 'error_message': error_string, 'error_title': error_title, 'step_name': step.name, 'params': step.dict()}) diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 8649cd58..a3441ad9 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -89,6 +89,20 @@ class ContinueSDK(AbstractContinueSDK): self.__autopilot = autopilot self.models = Models(self) self.context = autopilot.context + self.config = self._load_config() + + config: ContinueConfig + + def _load_config(self) -> ContinueConfig: + dir = self.ide.workspace_directory + yaml_path = os.path.join(dir, '.continue', 'config.yaml') + json_path = os.path.join(dir, '.continue', 'config.json') + if os.path.exists(yaml_path): + return load_config(yaml_path) + elif os.path.exists(json_path): + return load_config(json_path) + else: + return load_global_config() @property def history(self) -> History: @@ -166,18 +180,6 @@ class ContinueSDK(AbstractContinueSDK): async def get_user_secret(self, env_var: str, prompt: str) -> str: return await self.ide.getUserSecret(env_var) - @property - def config(self) -> ContinueConfig: - dir = self.ide.workspace_directory - yaml_path = os.path.join(dir, '.continue', 'config.yaml') - json_path = os.path.join(dir, '.continue', 'config.json') - if os.path.exists(yaml_path): - return load_config(yaml_path) - elif os.path.exists(json_path): - return load_config(json_path) - else: - return load_global_config() - def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]: context = list(filter(lambda x: x.editing, self.__autopilot._highlighted_ranges) ) if only_editing else self.__autopilot._highlighted_ranges diff --git a/continuedev/src/continuedev/libs/util/create_async_task.py b/continuedev/src/continuedev/libs/util/create_async_task.py index 62ff30ec..354cea82 100644 --- a/continuedev/src/continuedev/libs/util/create_async_task.py +++ b/continuedev/src/continuedev/libs/util/create_async_task.py @@ -14,7 +14,8 @@ def create_async_task(coro: Coroutine, unique_id: Union[str, None] = None): try: future.result() except Exception as e: - print("Exception caught from async task: ", e) + print("Exception caught from async task: ", + '\n'.join(traceback.format_exception(e))) capture_event(unique_id or "None", "async_task_error", { "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e)) }) diff --git a/continuedev/src/continuedev/libs/util/errors.py b/continuedev/src/continuedev/libs/util/errors.py new file mode 100644 index 00000000..46074cfc --- /dev/null +++ b/continuedev/src/continuedev/libs/util/errors.py @@ -0,0 +1,2 @@ +class SessionNotFound(Exception): + pass diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index dbc063c8..21089f30 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -1,5 +1,6 @@ import json from fastapi import Depends, Header, WebSocket, APIRouter +from starlette.websockets import WebSocketState, WebSocketDisconnect from typing import Any, List, Type, TypeVar, Union from pydantic import BaseModel import traceback @@ -52,6 +53,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.session = session async def _send_json(self, message_type: str, data: Any): + if self.websocket.client_state == WebSocketState.DISCONNECTED: + return await self.websocket.send_json({ "messageType": message_type, "data": data @@ -171,7 +174,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we protocol.websocket = websocket # Update any history that may have happened before connection - await protocol.send_state_update() + # await protocol.send_state_update() while AppStatus.should_exit is False: message = await websocket.receive_text() @@ -185,7 +188,8 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we data = message["data"] protocol.handle_json(message_type, data) - + except WebSocketDisconnect as e: + print("GUI websocket disconnected") except Exception as e: print("ERROR in gui websocket: ", e) capture_event(session.autopilot.continue_sdk.ide.unique_id, "gui_error", { @@ -193,5 +197,6 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we raise e finally: print("Closing gui websocket") - await websocket.close() + if websocket.client_state != WebSocketState.DISCONNECTED: + await websocket.close() session_manager.remove_session(session.session_id) diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 782c2ba6..400ad740 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -5,6 +5,7 @@ import os from typing import Any, Dict, List, Type, TypeVar, Union import uuid from fastapi import WebSocket, Body, APIRouter +from starlette.websockets import WebSocketState, WebSocketDisconnect from uvicorn.main import Server import traceback @@ -17,6 +18,8 @@ from .gui import SessionManager, session_manager from .ide_protocol import AbstractIdeProtocolServer import asyncio from ..libs.util.create_async_task import create_async_task +import nest_asyncio +nest_asyncio.apply() router = APIRouter(prefix="/ide", tags=["ide"]) @@ -115,7 +118,30 @@ class IdeProtocolServer(AbstractIdeProtocolServer): self.websocket = websocket self.session_manager = session_manager + workspace_directory: str + + async def initialize(self) -> List[str]: + await self._send_json("workspaceDirectory", {}) + other_msgs = [] + while True: + msg_string = await self.websocket.receive_text() + message = json.loads(msg_string) + if "messageType" not in message or "data" not in message: + continue + message_type = message["messageType"] + data = message["data"] + # if message_type == "openGUI": + # await self.openGUI() + if message_type == "workspaceDirectory": + self.workspace_directory = data["workspaceDirectory"] + break + else: + other_msgs.append(msg_string) + return other_msgs + async def _send_json(self, message_type: str, data: Any): + if self.websocket.client_state == WebSocketState.DISCONNECTED: + return await self.websocket.send_json({ "messageType": message_type, "data": data @@ -154,8 +180,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer): self.onMainUserInput(data["input"]) elif message_type == "deleteAtIndex": self.onDeleteAtIndex(data["index"]) - elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory", "getUserSecret", "runCommand", "uniqueId"]: + elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "getUserSecret", "runCommand", "uniqueId"]: self.sub_queue.post(message_type, data) + elif message_type == "workspaceDirectory": + self.workspace_directory = data["workspaceDirectory"] else: raise ValueError("Unknown message type", message_type) @@ -275,18 +303,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer): resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles") return resp.openFiles - async def getWorkspaceDirectory(self) -> str: - resp = await self._send_and_receive_json({}, WorkspaceDirectoryResponse, "workspaceDirectory") - return resp.workspaceDirectory - async def get_unique_id(self) -> str: resp = await self._send_and_receive_json({}, UniqueIdResponse, "uniqueId") return resp.uniqueId - @property - def workspace_directory(self) -> str: - return asyncio.run(self.getWorkspaceDirectory()) - @cached_property_no_none def unique_id(self) -> str: return asyncio.run(self.get_unique_id()) @@ -396,24 +416,35 @@ async def websocket_endpoint(websocket: WebSocket): print("Accepted websocket connection from, ", websocket.client) await websocket.send_json({"messageType": "connected", "data": {}}) - ideProtocolServer = IdeProtocolServer(session_manager, websocket) - - while AppStatus.should_exit is False: - message = await websocket.receive_text() - message = json.loads(message) + def handle_msg(msg): + message = json.loads(msg) if "messageType" not in message or "data" not in message: - continue + return message_type = message["messageType"] data = message["data"] - await ideProtocolServer.handle_json(message_type, data) + create_async_task( + ideProtocolServer.handle_json(message_type, data)) + + ideProtocolServer = IdeProtocolServer(session_manager, websocket) + other_msgs = await ideProtocolServer.initialize() + + for other_msg in other_msgs: + handle_msg(other_msg) + + while AppStatus.should_exit is False: + message = await websocket.receive_text() + handle_msg(message) print("Closing ide websocket") - await websocket.close() + except WebSocketDisconnect as e: + print("IDE wbsocket disconnected") except Exception as e: print("Error in ide websocket: ", e) capture_event(ideProtocolServer.unique_id, "gui_error", { "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))}) - await websocket.close() raise e + finally: + if websocket.client_state != WebSocketState.DISCONNECTED: + await websocket.close() diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index dfdca504..69cb6c10 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -15,10 +15,6 @@ class AbstractIdeProtocolServer(ABC): def showSuggestion(self, file_edit: FileEdit): """Show a suggestion to the user""" - @abstractmethod - async def getWorkspaceDirectory(self): - """Get the workspace directory""" - @abstractmethod async def setFileOpen(self, filepath: str, open: bool = True): """Set whether a file is open""" @@ -103,9 +99,7 @@ class AbstractIdeProtocolServer(ABC): async def showDiff(self, filepath: str, replacement: str, step_index: int): """Show a diff""" - @abstractproperty - def workspace_directory(self) -> str: - """Get the workspace directory""" + workspace_directory: str @abstractproperty def unique_id(self) -> str: diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 873a379e..7147dcfa 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -9,11 +9,13 @@ from ..core.main import FullState from ..core.autopilot import Autopilot from .ide_protocol import AbstractIdeProtocolServer from ..libs.util.create_async_task import create_async_task +from ..libs.util.errors import SessionNotFound class Session: session_id: str autopilot: Autopilot + # The GUI websocket for the session ws: Union[WebSocket, None] def __init__(self, session_id: str, autopilot: Autopilot): @@ -37,7 +39,6 @@ class DemoAutopilot(Autopilot): class SessionManager: sessions: Dict[str, Session] = {} - _event_loop: Union[BaseEventLoop, None] = None def get_session(self, session_id: str) -> Session: if session_id not in self.sessions: @@ -67,6 +68,8 @@ class SessionManager: print("Registered websocket for session", session_id) async def send_ws_data(self, session_id: str, message_type: str, data: Any): + if session_id not in self.sessions: + raise SessionNotFound(f"Session {session_id} not found") if self.sessions[session_id].ws is None: print(f"Session {session_id} has no websocket") return diff --git a/continuedev/src/continuedev/server/state_manager.py b/continuedev/src/continuedev/server/state_manager.py deleted file mode 100644 index c9bd760b..00000000 --- a/continuedev/src/continuedev/server/state_manager.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Any, List, Tuple, Union -from fastapi import WebSocket -from pydantic import BaseModel -from ..core.main import FullState - -# State updates represented as (path, replacement) pairs -StateUpdate = Tuple[List[Union[str, int]], Any] - - -class StateManager: - """ - A class that acts as the source of truth for state, ingesting changes to the entire object and streaming only the updated portions to client. - """ - - def __init__(self, ws: WebSocket): - self.ws = ws - - def _send_update(self, updates: List[StateUpdate]): - self.ws.send_json( - [update.dict() for update in updates] - ) diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py index e1e041d0..14a1cd41 100644 --- a/continuedev/src/continuedev/steps/chat.py +++ b/continuedev/src/continuedev/steps/chat.py @@ -27,16 +27,21 @@ class SimpleChatStep(Step): async def run(self, sdk: ContinueSDK): completion = "" messages = self.messages or await sdk.get_chat_context() - async for chunk in sdk.models.gpt4.stream_chat(messages, temperature=0.5): - if sdk.current_step_was_deleted(): - # So that the message doesn't disappear - self.hide = False - return - - if "content" in chunk: - self.description += chunk["content"] - completion += chunk["content"] - await sdk.update_ui() + + generator = sdk.models.gpt4.stream_chat(messages, temperature=0.5) + try: + async for chunk in generator: + if sdk.current_step_was_deleted(): + # So that the message doesn't disappear + self.hide = False + return + + if "content" in chunk: + self.description += chunk["content"] + completion += chunk["content"] + await sdk.update_ui() + finally: + await generator.aclose() self.name = (await sdk.models.gpt35.complete( f"Write a short title for the following chat message: {self.description}")).strip() -- cgit v1.2.3-70-g09d2 From 391764f1371dab06af30a29e10a826a516b69bb3 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Wed, 12 Jul 2023 21:53:06 -0700 Subject: persist state and reconnect automatically --- continuedev/src/continuedev/core/autopilot.py | 17 +++- continuedev/src/continuedev/libs/constants/main.py | 6 ++ continuedev/src/continuedev/libs/util/paths.py | 17 ++++ continuedev/src/continuedev/server/gui.py | 10 +- continuedev/src/continuedev/server/ide.py | 25 +++-- continuedev/src/continuedev/server/ide_protocol.py | 10 +- continuedev/src/continuedev/server/main.py | 16 ++- .../src/continuedev/server/session_manager.py | 41 ++++++-- docs/docs/concepts/ide.md | 4 +- extension/react-app/src/components/ComboBox.tsx | 1 - extension/react-app/src/hooks/messenger.ts | 10 -- extension/react-app/src/pages/gui.tsx | 1 - extension/src/activation/activate.ts | 2 +- extension/src/continueIdeClient.ts | 62 ++++++++--- extension/src/debugPanel.ts | 113 ++++----------------- extension/src/util/messenger.ts | 10 ++ 16 files changed, 192 insertions(+), 153 deletions(-) create mode 100644 continuedev/src/continuedev/libs/constants/main.py create mode 100644 continuedev/src/continuedev/libs/util/paths.py (limited to 'continuedev/src/continuedev/libs') diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 1b074435..e1c8a076 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -1,13 +1,13 @@ from functools import cached_property import traceback import time -from typing import Any, Callable, Coroutine, Dict, List +from typing import Any, Callable, Coroutine, Dict, List, Union import os from aiohttp import ClientPayloadError +from pydantic import root_validator from ..models.filesystem import RangeInFileWithContents from ..models.filesystem_edit import FileEditWithFullContents -from ..libs.llm import LLM from .observation import Observation, InternalErrorObservation from ..server.ide_protocol import AbstractIdeProtocolServer from ..libs.util.queue import AsyncSubscriptionQueue @@ -16,7 +16,6 @@ from .main import Context, ContinueCustomException, HighlightedRangeContext, Pol from ..steps.core.core import ReversibleStep, ManualEditStep, UserInputStep from ..libs.util.telemetry import capture_event from .sdk import ContinueSDK -import asyncio from ..libs.util.step_name_to_steps import get_step_from_name from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback from openai import error as openai_errors @@ -46,6 +45,7 @@ class Autopilot(ContinueBaseModel): ide: AbstractIdeProtocolServer history: History = History.from_empty() context: Context = Context() + full_state: Union[FullState, None] = None _on_update_callbacks: List[Callable[[FullState], None]] = [] _active: bool = False @@ -63,8 +63,15 @@ class Autopilot(ContinueBaseModel): arbitrary_types_allowed = True keep_untouched = (cached_property,) + @root_validator(pre=True) + def fill_in_values(cls, values): + full_state: FullState = values.get('full_state') + if full_state is not None: + values['history'] = full_state.history + return values + def get_full_state(self) -> FullState: - return FullState( + full_state = FullState( history=self.history, active=self._active, user_input_queue=self._main_user_input_queue, @@ -73,6 +80,8 @@ class Autopilot(ContinueBaseModel): slash_commands=self.get_available_slash_commands(), adding_highlighted_code=self._adding_highlighted_code, ) + self.full_state = full_state + return full_state def get_available_slash_commands(self) -> List[Dict]: custom_commands = list(map(lambda x: { diff --git a/continuedev/src/continuedev/libs/constants/main.py b/continuedev/src/continuedev/libs/constants/main.py new file mode 100644 index 00000000..96eb6e69 --- /dev/null +++ b/continuedev/src/continuedev/libs/constants/main.py @@ -0,0 +1,6 @@ +## PATHS ## + +CONTINUE_GLOBAL_FOLDER = ".continue" +CONTINUE_SESSIONS_FOLDER = "sessions" +CONTINUE_SERVER_FOLDER = "server" + diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py new file mode 100644 index 00000000..fddef887 --- /dev/null +++ b/continuedev/src/continuedev/libs/util/paths.py @@ -0,0 +1,17 @@ +import os + +from ..constants.main import CONTINUE_SESSIONS_FOLDER, CONTINUE_GLOBAL_FOLDER, CONTINUE_SERVER_FOLDER + +def getGlobalFolderPath(): + return os.path.join(os.path.expanduser("~"), CONTINUE_GLOBAL_FOLDER) + + + +def getSessionsFolderPath(): + return os.path.join(getGlobalFolderPath(), CONTINUE_SESSIONS_FOLDER) + +def getServerFolderPath(): + return os.path.join(getGlobalFolderPath(), CONTINUE_SERVER_FOLDER) + +def getSessionFilePath(session_id: str): + return os.path.join(getSessionsFolderPath(), f"{session_id}.json") \ No newline at end of file diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 21089f30..8f6f68f6 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -31,12 +31,12 @@ class AppStatus: Server.handle_exit = AppStatus.handle_exit -def session(x_continue_session_id: str = Header("anonymous")) -> Session: - return session_manager.get_session(x_continue_session_id) +async def session(x_continue_session_id: str = Header("anonymous")) -> Session: + return await session_manager.get_session(x_continue_session_id) -def websocket_session(session_id: str) -> Session: - return session_manager.get_session(session_id) +async def websocket_session(session_id: str) -> Session: + return await session_manager.get_session(session_id) T = TypeVar("T", bound=BaseModel) @@ -199,4 +199,6 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we print("Closing gui websocket") if websocket.client_state != WebSocketState.DISCONNECTED: await websocket.close() + + session_manager.persist_session(session.session_id) session_manager.remove_session(session.session_id) diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 4645b49e..12a21f19 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -52,9 +52,11 @@ class FileEditsUpdate(BaseModel): class OpenFilesResponse(BaseModel): openFiles: List[str] + class VisibleFilesResponse(BaseModel): visibleFiles: List[str] + class HighlightedCodeResponse(BaseModel): highlightedCode: List[RangeInFile] @@ -115,6 +117,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): websocket: WebSocket session_manager: SessionManager sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue() + session_id: Union[str, None] = None def __init__(self, session_manager: SessionManager, websocket: WebSocket): self.websocket = websocket @@ -132,8 +135,6 @@ class IdeProtocolServer(AbstractIdeProtocolServer): continue message_type = message["messageType"] data = message["data"] - # if message_type == "openGUI": - # await self.openGUI() if message_type == "workspaceDirectory": self.workspace_directory = data["workspaceDirectory"] break @@ -158,8 +159,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer): return resp_model.parse_obj(resp) async def handle_json(self, message_type: str, data: Any): - if message_type == "openGUI": - await self.openGUI() + if message_type == "getSessionId": + await self.getSessionId() elif message_type == "setFileOpen": await self.setFileOpen(data["filepath"], data["open"]) elif message_type == "setSuggestionsLocked": @@ -217,9 +218,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer): "locked": locked }) - async def openGUI(self): - session_id = self.session_manager.new_session(self) - await self._send_json("openGUI", { + async def getSessionId(self): + session_id = self.session_manager.new_session( + self, self.session_id).session_id + await self._send_json("getSessionId", { "sessionId": session_id }) @@ -304,7 +306,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): async def getOpenFiles(self) -> List[str]: resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles") return resp.openFiles - + async def getVisibleFiles(self) -> List[str]: resp = await self._send_and_receive_json({}, VisibleFilesResponse, "visibleFiles") return resp.visibleFiles @@ -416,7 +418,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): @router.websocket("/ws") -async def websocket_endpoint(websocket: WebSocket): +async def websocket_endpoint(websocket: WebSocket, session_id: str = None): try: await websocket.accept() print("Accepted websocket connection from, ", websocket.client) @@ -434,6 +436,9 @@ async def websocket_endpoint(websocket: WebSocket): ideProtocolServer.handle_json(message_type, data)) ideProtocolServer = IdeProtocolServer(session_manager, websocket) + ideProtocolServer.session_id = session_id + if session_id is not None: + session_manager.registered_ides[session_id] = ideProtocolServer other_msgs = await ideProtocolServer.initialize() for other_msg in other_msgs: @@ -454,3 +459,5 @@ async def websocket_endpoint(websocket: WebSocket): finally: if websocket.client_state != WebSocketState.DISCONNECTED: await websocket.close() + + session_manager.registered_ides.pop(ideProtocolServer.session_id) diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index 2783dc61..2f78cf0e 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -1,5 +1,6 @@ -from typing import Any, List +from typing import Any, List, Union from abc import ABC, abstractmethod, abstractproperty +from fastapi import WebSocket from ..models.main import Traceback from ..models.filesystem_edit import FileEdit, FileSystemEdit, EditDiff @@ -7,6 +8,9 @@ from ..models.filesystem import RangeInFile, RangeInFileWithContents class AbstractIdeProtocolServer(ABC): + websocket: WebSocket + session_id: Union[str, None] + @abstractmethod async def handle_json(self, data: Any): """Handle a json message""" @@ -24,8 +28,8 @@ class AbstractIdeProtocolServer(ABC): """Set whether suggestions are locked""" @abstractmethod - async def openGUI(self): - """Open a GUI""" + async def getSessionId(self): + """Get a new session ID""" @abstractmethod async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool: diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index f4d82903..aa093853 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -4,7 +4,8 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from .ide import router as ide_router from .gui import router as gui_router -import logging +from .session_manager import session_manager +import atexit import uvicorn import argparse @@ -44,5 +45,16 @@ def run_server(): uvicorn.run(app, host="0.0.0.0", port=args.port) +def cleanup(): + print("Cleaning up sessions") + for session_id in session_manager.sessions: + session_manager.persist_session(session_id) + + +atexit.register(cleanup) if __name__ == "__main__": - run_server() + try: + run_server() + except Exception as e: + cleanup() + raise e diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 7147dcfa..fb8ac386 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -1,9 +1,12 @@ -from asyncio import BaseEventLoop +import os from fastapi import WebSocket from typing import Any, Dict, List, Union from uuid import uuid4 +import json +from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath from ..models.filesystem_edit import FileEditWithFullContents +from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER from ..core.policy import DemoPolicy from ..core.main import FullState from ..core.autopilot import Autopilot @@ -39,17 +42,35 @@ class DemoAutopilot(Autopilot): class SessionManager: sessions: Dict[str, Session] = {} + # Mapping of session_id to IDE, where the IDE is still alive + registered_ides: Dict[str, AbstractIdeProtocolServer] = {} - def get_session(self, session_id: str) -> Session: + async def get_session(self, session_id: str) -> Session: if session_id not in self.sessions: + # Check then whether it is persisted by listing all files in the sessions folder + # And only if the IDE is still alive + sessions_folder = getSessionsFolderPath() + session_files = os.listdir(sessions_folder) + if f"{session_id}.json" in session_files and session_id in self.registered_ides: + if self.registered_ides[session_id].session_id is not None: + return self.new_session(self.registered_ides[session_id], session_id=session_id) + raise KeyError("Session ID not recognized", session_id) return self.sessions[session_id] - def new_session(self, ide: AbstractIdeProtocolServer) -> str: - autopilot = DemoAutopilot(policy=DemoPolicy(), ide=ide) - session_id = str(uuid4()) + def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: + full_state = None + if session_id is not None and os.path.exists(getSessionFilePath(session_id)): + with open(getSessionFilePath(session_id), "r") as f: + full_state = FullState(**json.load(f)) + + autopilot = DemoAutopilot( + policy=DemoPolicy(), ide=ide, full_state=full_state) + session_id = session_id or str(uuid4()) + ide.session_id = session_id session = Session(session_id=session_id, autopilot=autopilot) self.sessions[session_id] = session + self.registered_ides[session_id] = ide async def on_update(state: FullState): await session_manager.send_ws_data(session_id, "state_update", { @@ -58,11 +79,19 @@ class SessionManager: autopilot.on_update(on_update) create_async_task(autopilot.run_policy()) - return session_id + return session def remove_session(self, session_id: str): del self.sessions[session_id] + def persist_session(self, session_id: str): + """Save the session's FullState as a json file""" + full_state = self.sessions[session_id].autopilot.get_full_state() + if not os.path.exists(getSessionsFolderPath()): + os.mkdir(getSessionsFolderPath()) + with open(getSessionFilePath(session_id), "w") as f: + json.dump(full_state.dict(), f) + def register_websocket(self, session_id: str, ws: WebSocket): self.sessions[session_id].ws = ws print("Registered websocket for session", session_id) diff --git a/docs/docs/concepts/ide.md b/docs/docs/concepts/ide.md index dc7b9e23..bd31481b 100644 --- a/docs/docs/concepts/ide.md +++ b/docs/docs/concepts/ide.md @@ -41,9 +41,9 @@ Get the workspace directory Set whether a file is open -### openGUI +### getSessionId -Open a gui +Get a new session ID ### showSuggestionsAndWait diff --git a/extension/react-app/src/components/ComboBox.tsx b/extension/react-app/src/components/ComboBox.tsx index ac994b0a..7d6541c7 100644 --- a/extension/react-app/src/components/ComboBox.tsx +++ b/extension/react-app/src/components/ComboBox.tsx @@ -331,7 +331,6 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => { ) { (event.nativeEvent as any).preventDownshiftDefault = true; } else if (event.key === "ArrowUp") { - console.log("OWJFOIJO"); if (positionInHistory == 0) return; else if ( positionInHistory == history.length && diff --git a/extension/react-app/src/hooks/messenger.ts b/extension/react-app/src/hooks/messenger.ts index e2a0bab8..00ce1fbb 100644 --- a/extension/react-app/src/hooks/messenger.ts +++ b/extension/react-app/src/hooks/messenger.ts @@ -1,6 +1,3 @@ -// console.log("Websocket import"); -// const WebSocket = require("ws"); - export abstract class Messenger { abstract send(messageType: string, data: object): void; @@ -28,13 +25,6 @@ export class WebsocketMessenger extends Messenger { private serverUrl: string; _newWebsocket(): WebSocket { - // // Dynamic import, because WebSocket is builtin with browser, but not with node. And can't use require in browser. - // if (typeof process === "object") { - // console.log("Using node"); - // // process is only available in Node - // var WebSocket = require("ws"); - // } - const newWebsocket = new WebSocket(this.serverUrl); for (const listener of this.onOpenListeners) { this.onOpen(listener); diff --git a/extension/react-app/src/pages/gui.tsx b/extension/react-app/src/pages/gui.tsx index b9382bd1..4ff260fa 100644 --- a/extension/react-app/src/pages/gui.tsx +++ b/extension/react-app/src/pages/gui.tsx @@ -262,7 +262,6 @@ function GUI(props: GUIProps) { const onStepUserInput = (input: string, index: number) => { if (!client) return; - console.log("Sending step user input", input, index); client.sendStepUserInput(input, index); }; diff --git a/extension/src/activation/activate.ts b/extension/src/activation/activate.ts index 559caf44..cd885b12 100644 --- a/extension/src/activation/activate.ts +++ b/extension/src/activation/activate.ts @@ -56,7 +56,7 @@ export async function activateExtension(context: vscode.ExtensionContext) { registerAllCodeLensProviders(context); registerAllCommands(context); - // Initialize IDE Protocol Client, then call "openGUI" + // Initialize IDE Protocol Client const serverUrl = getContinueServerUrl(); ideProtocolClient = new IdeProtocolClient( `${serverUrl.replace("http", "ws")}/ide/ws`, diff --git a/extension/src/continueIdeClient.ts b/extension/src/continueIdeClient.ts index b728833f..4c1fdf1e 100644 --- a/extension/src/continueIdeClient.ts +++ b/extension/src/continueIdeClient.ts @@ -1,10 +1,9 @@ -// import { ShowSuggestionRequest } from "../schema/ShowSuggestionRequest"; import { editorSuggestionsLocked, showSuggestion as showSuggestionInEditor, SuggestionRanges, } from "./suggestions"; -import { openEditorAndRevealRange, getRightViewColumn } from "./util/vscode"; +import { openEditorAndRevealRange } from "./util/vscode"; import { FileEdit } from "../schema/FileEdit"; import { RangeInFile } from "../schema/RangeInFile"; import * as vscode from "vscode"; @@ -15,8 +14,6 @@ import { import { FileEditWithFullContents } from "../schema/FileEditWithFullContents"; import fs = require("fs"); import { WebsocketMessenger } from "./util/messenger"; -import * as path from "path"; -import * as os from "os"; import { diffManager } from "./diffs"; class IdeProtocolClient { @@ -27,17 +24,54 @@ class IdeProtocolClient { private _highlightDebounce: NodeJS.Timeout | null = null; - constructor(serverUrl: string, context: vscode.ExtensionContext) { - this.context = context; + private _lastReloadTime: number = 16; + private _reconnectionTimeouts: NodeJS.Timeout[] = []; + + private _sessionId: string | null = null; + private _serverUrl: string; - let messenger = new WebsocketMessenger(serverUrl); + private _newWebsocketMessenger() { + const requestUrl = + this._serverUrl + + (this._sessionId ? `?session_id=${this._sessionId}` : ""); + const messenger = new WebsocketMessenger(requestUrl); this.messenger = messenger; - messenger.onClose(() => { + + const reconnect = () => { + console.log("Trying to reconnect IDE protocol websocket..."); this.messenger = null; + + // Exponential backoff to reconnect + this._reconnectionTimeouts.forEach((to) => clearTimeout(to)); + + const timeout = setTimeout(() => { + if (this.messenger?.websocket?.readyState === 1) { + return; + } + this._newWebsocketMessenger(); + }, this._lastReloadTime); + + this._reconnectionTimeouts.push(timeout); + this._lastReloadTime = Math.min(2 * this._lastReloadTime, 5000); + }; + messenger.onOpen(() => { + this._reconnectionTimeouts.forEach((to) => clearTimeout(to)); + }); + messenger.onClose(() => { + reconnect(); + }); + messenger.onError(() => { + reconnect(); }); messenger.onMessage((messageType, data, messenger) => { this.handleMessage(messageType, data, messenger); }); + } + + constructor(serverUrl: string, context: vscode.ExtensionContext) { + this.context = context; + this._serverUrl = serverUrl; + this._newWebsocketMessenger(); // Setup listeners for any file changes in open editors // vscode.workspace.onDidChangeTextDocument((event) => { @@ -171,7 +205,7 @@ class IdeProtocolClient { case "showDiff": this.showDiff(data.filepath, data.replacement, data.step_index); break; - case "openGUI": + case "getSessionId": case "connected": break; default: @@ -284,10 +318,6 @@ class IdeProtocolClient { // ------------------------------------ // // Initiate Request - async openGUI(asRightWebviewPanel: boolean = false) { - // Open the webview panel - } - async getSessionId(): Promise { await new Promise((resolve, reject) => { // Repeatedly try to connect to the server @@ -303,10 +333,10 @@ class IdeProtocolClient { } }, 1000); }); - const resp = await this.messenger?.sendAndReceive("openGUI", {}); - const sessionId = resp.sessionId; + const resp = await this.messenger?.sendAndReceive("getSessionId", {}); // console.log("New Continue session with ID: ", sessionId); - return sessionId; + this._sessionId = resp.sessionId; + return resp.sessionId; } acceptRejectSuggestion(accept: boolean, key: SuggestionRanges) { diff --git a/extension/src/debugPanel.ts b/extension/src/debugPanel.ts index 487bbedf..5e1689d1 100644 --- a/extension/src/debugPanel.ts +++ b/extension/src/debugPanel.ts @@ -8,76 +8,6 @@ import { import { RangeInFile } from "./client"; const WebSocket = require("ws"); -class StreamManager { - private _fullText: string = ""; - private _insertionPoint: vscode.Position | undefined; - - private _addToEditor(update: string) { - let editor = - vscode.window.activeTextEditor || vscode.window.visibleTextEditors[0]; - - if (typeof this._insertionPoint === "undefined") { - if (editor?.selection.isEmpty) { - this._insertionPoint = editor?.selection.active; - } else { - this._insertionPoint = editor?.selection.end; - } - } - editor?.edit((editBuilder) => { - if (this._insertionPoint) { - editBuilder.insert(this._insertionPoint, update); - this._insertionPoint = this._insertionPoint.translate( - Array.from(update.matchAll(/\n/g)).length, - update.length - ); - } - }); - } - - public closeStream() { - this._fullText = ""; - this._insertionPoint = undefined; - this._codeBlockStatus = "closed"; - this._pendingBackticks = 0; - } - - private _codeBlockStatus: "open" | "closed" | "language-descriptor" = - "closed"; - private _pendingBackticks: number = 0; - public onStreamUpdate(update: string) { - let textToInsert = ""; - for (let i = 0; i < update.length; i++) { - switch (this._codeBlockStatus) { - case "closed": - if (update[i] === "`" && this._fullText.endsWith("``")) { - this._codeBlockStatus = "language-descriptor"; - } - break; - case "language-descriptor": - if (update[i] === " " || update[i] === "\n") { - this._codeBlockStatus = "open"; - } - break; - case "open": - if (update[i] === "`") { - if (this._fullText.endsWith("``")) { - this._codeBlockStatus = "closed"; - this._pendingBackticks = 0; - } else { - this._pendingBackticks += 1; - } - } else { - textToInsert += "`".repeat(this._pendingBackticks) + update[i]; - this._pendingBackticks = 0; - } - break; - } - this._fullText += update[i]; - } - this._addToEditor(textToInsert); - } -} - let websocketConnections: { [url: string]: WebsocketConnection | undefined } = {}; @@ -127,8 +57,6 @@ class WebsocketConnection { } } -let streamManager = new StreamManager(); - export let debugPanelWebview: vscode.Webview | undefined; export function setupDebugPanel( panel: vscode.WebviewPanel | vscode.WebviewView, @@ -147,10 +75,7 @@ export function setupDebugPanel( .toString(); const isProduction = true; // context?.extensionMode === vscode.ExtensionMode.Development; - if (!isProduction) { - scriptUri = "http://localhost:5173/src/main.tsx"; - styleMainUri = "http://localhost:5173/src/main.css"; - } else { + if (isProduction) { scriptUri = debugPanelWebview .asWebviewUri( vscode.Uri.joinPath(extensionUri, "react-app/dist/assets/index.js") @@ -161,6 +86,9 @@ export function setupDebugPanel( vscode.Uri.joinPath(extensionUri, "react-app/dist/assets/index.css") ) .toString(); + } else { + scriptUri = "http://localhost:5173/src/main.tsx"; + styleMainUri = "http://localhost:5173/src/main.css"; } panel.webview.options = { @@ -175,11 +103,11 @@ export function setupDebugPanel( return; } - let rangeInFile: RangeInFile = { + const rangeInFile: RangeInFile = { range: e.selections[0], filepath: e.textEditor.document.fileName, }; - let filesystem = { + const filesystem = { [rangeInFile.filepath]: e.textEditor.document.getText(), }; panel.webview.postMessage({ @@ -217,13 +145,19 @@ export function setupDebugPanel( url, }); }; - const connection = new WebsocketConnection( - url, - onMessage, - onOpen, - onClose - ); - websocketConnections[url] = connection; + try { + const connection = new WebsocketConnection( + url, + onMessage, + onOpen, + onClose + ); + websocketConnections[url] = connection; + resolve(null); + } catch (e) { + console.log("Caught it!: ", e); + reject(e); + } }); } @@ -292,15 +226,6 @@ export function setupDebugPanel( openEditorAndRevealRange(data.path, undefined, vscode.ViewColumn.One); break; } - case "streamUpdate": { - // Write code at the position of the cursor - streamManager.onStreamUpdate(data.update); - break; - } - case "closeStream": { - streamManager.closeStream(); - break; - } case "withProgress": { // This message allows withProgress to be used in the webview if (data.done) { diff --git a/extension/src/util/messenger.ts b/extension/src/util/messenger.ts index b1df161b..7fd71ddd 100644 --- a/extension/src/util/messenger.ts +++ b/extension/src/util/messenger.ts @@ -15,6 +15,8 @@ export abstract class Messenger { abstract onOpen(callback: () => void): void; abstract onClose(callback: () => void): void; + + abstract onError(callback: () => void): void; abstract sendAndReceive(messageType: string, data: any): Promise; } @@ -26,6 +28,7 @@ export class WebsocketMessenger extends Messenger { } = {}; private onOpenListeners: (() => void)[] = []; private onCloseListeners: (() => void)[] = []; + private onErrorListeners: (() => void)[] = []; private serverUrl: string; _newWebsocket(): WebSocket { @@ -43,6 +46,9 @@ export class WebsocketMessenger extends Messenger { for (const listener of this.onCloseListeners) { this.onClose(listener); } + for (const listener of this.onErrorListeners) { + this.onError(listener); + } for (const messageType in this.onMessageListeners) { for (const listener of this.onMessageListeners[messageType]) { this.onMessageType(messageType, listener); @@ -151,4 +157,8 @@ export class WebsocketMessenger extends Messenger { onClose(callback: () => void): void { this.websocket.addEventListener("close", callback); } + + onError(callback: () => void): void { + this.websocket.addEventListener("error", callback); + } } -- cgit v1.2.3-70-g09d2 From 3430de2b46b5c65c953b24dab5c31678a7bacc20 Mon Sep 17 00:00:00 2001 From: Ty Dunn Date: Thu, 13 Jul 2023 15:36:09 -0700 Subject: adding /help command --- continuedev/src/continuedev/core/config.py | 5 ++ continuedev/src/continuedev/core/policy.py | 5 +- .../continuedev/libs/util/step_name_to_steps.py | 4 +- continuedev/src/continuedev/steps/help.py | 57 ++++++++++++++++++++++ 4 files changed, 66 insertions(+), 5 deletions(-) create mode 100644 continuedev/src/continuedev/steps/help.py (limited to 'continuedev/src/continuedev/libs') diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index f6167638..6e430c04 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -44,6 +44,11 @@ DEFAULT_SLASH_COMMANDS = [ description="Open the config file to create new and edit existing slash commands", step_name="OpenConfigStep", ), + SlashCommand( + name="help", + description="Ask a question like '/help what is given to the llm as context?'", + step_name="HelpStep", + ), SlashCommand( name="comment", description="Write comments for the current file or highlighted code", diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index b8363df2..59ea78b1 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -60,10 +60,7 @@ class DemoPolicy(Policy): MessageStep(name="Welcome to Continue", message=dedent("""\ - Highlight code and ask a question or give instructions - Use `cmd+k` (Mac) / `ctrl+k` (Windows) to open Continue - - Use `cmd+shift+e` / `ctrl+shift+e` to open file Explorer - - Add your own OpenAI API key to VS Code Settings with `cmd+,` - - Use slash commands when you want fine-grained control - - Past steps are included as part of the context by default""")) >> + - Use `/help` to ask questions about how to use Continue""")) >> WelcomeStep() >> # SetupContinueWorkspaceStep() >> # CreateCodebaseIndexChroma() >> diff --git a/continuedev/src/continuedev/libs/util/step_name_to_steps.py b/continuedev/src/continuedev/libs/util/step_name_to_steps.py index d329e110..49056c81 100644 --- a/continuedev/src/continuedev/libs/util/step_name_to_steps.py +++ b/continuedev/src/continuedev/libs/util/step_name_to_steps.py @@ -13,6 +13,7 @@ from ...recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRec from ...steps.on_traceback import DefaultOnTracebackStep from ...steps.clear_history import ClearHistoryStep from ...steps.open_config import OpenConfigStep +from ...steps.help import HelpStep # This mapping is used to convert from string in ContinueConfig json to corresponding Step class. # Used for example in slash_commands and steps_on_startup @@ -28,7 +29,8 @@ step_name_to_step_class = { "DeployPipelineAirflowRecipe": DeployPipelineAirflowRecipe, "DefaultOnTracebackStep": DefaultOnTracebackStep, "ClearHistoryStep": ClearHistoryStep, - "OpenConfigStep": OpenConfigStep + "OpenConfigStep": OpenConfigStep, + "HelpStep": HelpStep, } diff --git a/continuedev/src/continuedev/steps/help.py b/continuedev/src/continuedev/steps/help.py new file mode 100644 index 00000000..fdfb986f --- /dev/null +++ b/continuedev/src/continuedev/steps/help.py @@ -0,0 +1,57 @@ +from textwrap import dedent +from ..core.main import ChatMessage, Step +from ..core.sdk import ContinueSDK +from ..libs.util.telemetry import capture_event + +help = dedent("""\ + Continue is an open-source coding autopilot. It is a VS Code extension that brings the power of ChatGPT to your IDE. + + It gathers context for you and stores your interactions automatically, so that you can avoid copy/paste now and benefit from a customized LLM later. + + Continue can be used to... + 1. Edit chunks of code with specific instructions (e.g. "/edit migrate this digital ocean terraform file into one that works for GCP") + 2. Get answers to questions without switching windows (e.g. "how do I find running process on port 8000?") + 3. Generate files from scratch (e.g. "/edit Create a Python CLI tool that uses the posthog api to get events from DAUs") + + You tell Continue to edit a specific section of code by highlighting it. If you highlight multiple code sections, then it will only edit the one with the purple glow around it. You can switch which one has the purple glow by clicking the paint brush. + + If you don't highlight any code, then Continue will insert at the location of your cursor. + + Continue passes all of the sections of code you highlight, the code above and below the to-be edited highlighted code section, and all previous steps above input box as context to the LLM. + + You can use cmd+k (Mac) / ctrl+k (Windows) to open Continue. You can use cmd+shift+e / ctrl+shift+e to open file Explorer. You can add your own OpenAI API key to VS Code Settings with `cmd+,` + + If Continue is stuck loading, try using `cmd+shift+p` to open the command palette, search "Reload Window", and then select it. This will reload VS Code and Continue and often fixes issues. + + If you have feedback, please use /feedback to let us know how you would like to use Continue. We are excited to hear from you!""") + +class HelpStep(Step): + + name: str = "Help" + user_input: str + manage_own_chat_context: bool = True + description: str = "" + + async def run(self, sdk: ContinueSDK): + + question = self.user_input + + prompt = dedent(f"""Please us the information below to provide a succinct answer to the following quesiton: {question} + + Information: + + {help}""") + + self.chat_context.append(ChatMessage( + role="user", + content=prompt, + summary="Help" + )) + messages = await sdk.get_chat_context() + generator = sdk.models.gpt4.stream_chat(messages) + async for chunk in generator: + if "content" in chunk: + self.description += chunk["content"] + await sdk.update_ui() + + capture_event(sdk.ide.unique_id, "help", {"question": question, "answer": self.description}) \ No newline at end of file -- cgit v1.2.3-70-g09d2 From 48e5c8001e897eb37493357087410ee8f98217fa Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sat, 15 Jul 2023 14:30:11 -0700 Subject: ctrl shortcuts on windows, load models immediately --- continuedev/src/continuedev/core/autopilot.py | 10 ++-- continuedev/src/continuedev/core/sdk.py | 59 +++++++++++++++------- .../src/continuedev/libs/llm/hf_inference_api.py | 6 ++- continuedev/src/continuedev/server/ide.py | 4 +- .../src/continuedev/server/session_manager.py | 6 +-- .../react-app/src/components/StepContainer.tsx | 2 +- extension/react-app/src/components/TextDialog.tsx | 6 ++- extension/react-app/src/pages/gui.tsx | 6 +-- extension/react-app/src/util/index.ts | 30 +++++++++++ 9 files changed, 98 insertions(+), 31 deletions(-) create mode 100644 extension/react-app/src/util/index.ts (limited to 'continuedev/src/continuedev/libs') diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 82439f49..0696c360 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -50,6 +50,8 @@ class Autopilot(ContinueBaseModel): full_state: Union[FullState, None] = None _on_update_callbacks: List[Callable[[FullState], None]] = [] + continue_sdk: ContinueSDK = None + _active: bool = False _should_halt: bool = False _main_user_input_queue: List[str] = [] @@ -57,9 +59,11 @@ class Autopilot(ContinueBaseModel): _user_input_queue = AsyncSubscriptionQueue() _retry_queue = AsyncSubscriptionQueue() - @cached_property - def continue_sdk(self) -> ContinueSDK: - return ContinueSDK(self) + @classmethod + async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot": + autopilot = cls(ide=ide, policy=policy) + autopilot.continue_sdk = await ContinueSDK.create(autopilot) + return autopilot class Config: arbitrary_types_allowed = True diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index aa2d8892..d73561d2 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -1,6 +1,6 @@ import asyncio from functools import cached_property -from typing import Coroutine, Union +from typing import Coroutine, Dict, Union import os from ..steps.core.core import DefaultModelEditCodeStep @@ -13,7 +13,7 @@ from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI from ..libs.llm.openai import OpenAI from .observation import Observation from ..server.ide_protocol import AbstractIdeProtocolServer -from .main import Context, ContinueCustomException, HighlightedRangeContext, History, Step, ChatMessage, ChatMessageRole +from .main import Context, ContinueCustomException, History, Step, ChatMessage from ..steps.core.core import * from ..libs.llm.proxy_server import ProxyServer @@ -22,26 +22,46 @@ class Autopilot: pass +ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"] +MODEL_PROVIDER_TO_ENV_VAR = { + "openai": "OPENAI_API_KEY", + "hf_inference_api": "HUGGING_FACE_TOKEN", + "anthropic": "ANTHROPIC_API_KEY" +} + + class Models: - def __init__(self, sdk: "ContinueSDK"): + provider_keys: Dict[ModelProvider, str] = {} + model_providers: List[ModelProvider] + + def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]): self.sdk = sdk + self.model_providers = model_providers + + @classmethod + async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models": + models = Models(sdk, with_providers) + for provider in with_providers: + if provider in MODEL_PROVIDER_TO_ENV_VAR: + env_var = MODEL_PROVIDER_TO_ENV_VAR[provider] + models.provider_keys[provider] = await sdk.get_user_secret( + env_var, f'Please add your {env_var} to the .env file') + + return models def __load_openai_model(self, model: str) -> OpenAI: - async def load_openai_model(): - api_key = await self.sdk.get_user_secret( - 'OPENAI_API_KEY', 'Enter your OpenAI API key or press enter to try for free') - if api_key == "": - return ProxyServer(self.sdk.ide.unique_id, model) - return OpenAI(api_key=api_key, default_model=model) - return asyncio.get_event_loop().run_until_complete(load_openai_model()) + api_key = self.provider_keys["openai"] + if api_key == "": + return ProxyServer(self.sdk.ide.unique_id, model) + return OpenAI(api_key=api_key, default_model=model) + + def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI: + api_key = self.provider_keys["hf_inference_api"] + return HuggingFaceInferenceAPI(api_key=api_key, model=model) @cached_property def starcoder(self): - async def load_starcoder(): - api_key = await self.sdk.get_user_secret( - 'HUGGING_FACE_TOKEN', 'Please add your Hugging Face token to the .env file') - return HuggingFaceInferenceAPI(api_key=api_key) - return asyncio.get_event_loop().run_until_complete(load_starcoder()) + return self.__load_hf_inference_api_model("bigcode/starcoder") @cached_property def gpt35(self): @@ -74,7 +94,7 @@ class Models: @property def default(self): default_model = self.sdk.config.default_model - return self.__model_from_name(default_model) if default_model is not None else self.gpt35 + return self.__model_from_name(default_model) if default_model is not None else self.gpt4 class ContinueSDK(AbstractContinueSDK): @@ -87,10 +107,15 @@ class ContinueSDK(AbstractContinueSDK): def __init__(self, autopilot: Autopilot): self.ide = autopilot.ide self.__autopilot = autopilot - self.models = Models(self) self.context = autopilot.context self.config = self._load_config() + @classmethod + async def create(cls, autopilot: Autopilot) -> "ContinueSDK": + sdk = ContinueSDK(autopilot) + sdk.models = await Models.create(sdk) + return sdk + config: ContinueConfig def _load_config(self) -> ContinueConfig: diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 1586c620..803ba122 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -9,7 +9,11 @@ DEFAULT_MAX_TIME = 120. class HuggingFaceInferenceAPI(LLM): api_key: str - model: str = "bigcode/starcoder" + model: str + + def __init__(self, api_key: str, model: str): + self.api_key = api_key + self.model = model def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs): """Return the completion of the text with the given temperature.""" diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 7875c94d..77b13483 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -227,8 +227,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer): }) async def getSessionId(self): - session_id = self.session_manager.new_session( - self, self.session_id).session_id + session_id = (await self.session_manager.new_session( + self, self.session_id)).session_id await self._send_json("getSessionId", { "sessionId": session_id }) diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index fb8ac386..6d109ca6 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -53,18 +53,18 @@ class SessionManager: session_files = os.listdir(sessions_folder) if f"{session_id}.json" in session_files and session_id in self.registered_ides: if self.registered_ides[session_id].session_id is not None: - return self.new_session(self.registered_ides[session_id], session_id=session_id) + return await self.new_session(self.registered_ides[session_id], session_id=session_id) raise KeyError("Session ID not recognized", session_id) return self.sessions[session_id] - def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: + async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: full_state = None if session_id is not None and os.path.exists(getSessionFilePath(session_id)): with open(getSessionFilePath(session_id), "r") as f: full_state = FullState(**json.load(f)) - autopilot = DemoAutopilot( + autopilot = await DemoAutopilot.create( policy=DemoPolicy(), ide=ide, full_state=full_state) session_id = session_id or str(uuid4()) ide.session_id = session_id diff --git a/extension/react-app/src/components/StepContainer.tsx b/extension/react-app/src/components/StepContainer.tsx index 14e9b854..7f23e333 100644 --- a/extension/react-app/src/components/StepContainer.tsx +++ b/extension/react-app/src/components/StepContainer.tsx @@ -181,7 +181,7 @@ function StepContainer(props: StepContainerProps) { } className="overflow-hidden cursor-pointer" onClick={(e) => { - if (e.metaKey) { + if (isMetaEquivalentKeyPressed(e)) { props.onToggleAll(); } else { props.onToggle(); diff --git a/extension/react-app/src/components/TextDialog.tsx b/extension/react-app/src/components/TextDialog.tsx index ea5727f0..c724697d 100644 --- a/extension/react-app/src/components/TextDialog.tsx +++ b/extension/react-app/src/components/TextDialog.tsx @@ -81,7 +81,11 @@ const TextDialog = (props: { rows={10} ref={textAreaRef} onKeyDown={(e) => { - if (e.key === "Enter" && e.metaKey && textAreaRef.current) { + if ( + e.key === "Enter" && + isMetaEquivalentKeyPressed(e) && + textAreaRef.current + ) { props.onEnter(textAreaRef.current.value); setText(""); } else if (e.key === "Escape") { diff --git a/extension/react-app/src/pages/gui.tsx b/extension/react-app/src/pages/gui.tsx index 57cebac3..cb0404ab 100644 --- a/extension/react-app/src/pages/gui.tsx +++ b/extension/react-app/src/pages/gui.tsx @@ -137,12 +137,12 @@ function GUI(props: GUIProps) { useEffect(() => { const listener = (e: any) => { // Cmd + i to toggle fast model - if (e.key === "i" && e.metaKey && e.shiftKey) { + if (e.key === "i" && isMetaEquivalentKeyPressed(e) && e.shiftKey) { setUsingFastModel((prev) => !prev); // Cmd + backspace to stop currently running step } else if ( e.key === "Backspace" && - e.metaKey && + isMetaEquivalentKeyPressed(e) && typeof history?.current_index !== "undefined" && history.timeline[history.current_index]?.active ) { @@ -220,7 +220,7 @@ function GUI(props: GUIProps) { if (mainTextInputRef.current) { let input = (mainTextInputRef.current as any).inputValue; // cmd+enter to /edit - if (event?.metaKey) { + if (isMetaEquivalentKeyPressed(event)) { input = `/edit ${input}`; } (mainTextInputRef.current as any).setInputValue(""); diff --git a/extension/react-app/src/util/index.ts b/extension/react-app/src/util/index.ts new file mode 100644 index 00000000..ad711321 --- /dev/null +++ b/extension/react-app/src/util/index.ts @@ -0,0 +1,30 @@ +type Platform = "mac" | "linux" | "windows" | "unknown"; + +function getPlatform(): Platform { + const platform = window.navigator.platform.toUpperCase(); + if (platform.indexOf("MAC") >= 0) { + return "mac"; + } else if (platform.indexOf("LINUX") >= 0) { + return "linux"; + } else if (platform.indexOf("WIN") >= 0) { + return "windows"; + } else { + return "unknown"; + } +} + +function isMetaEquivalentKeyPressed(event: { + metaKey: boolean; + ctrlKey: boolean; +}): boolean { + const platform = getPlatform(); + switch (platform) { + case "mac": + return event.metaKey; + case "linux": + case "windows": + return event.ctrlKey; + default: + return event.metaKey; + } +} -- cgit v1.2.3-70-g09d2