summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-12 00:32:51 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-12 00:32:51 -0700
commita6ac1fdae55ba79e2ffbd00f88fc39971ed47f58 (patch)
tree1284916971492ed4d4ac7c91295f6529b6209eaf /continuedev
parent67d6a3f0ea00e55aea47b4eeff4cdb0d8321ce2f (diff)
downloadsncontinue-a6ac1fdae55ba79e2ffbd00f88fc39971ed47f58.tar.gz
sncontinue-a6ac1fdae55ba79e2ffbd00f88fc39971ed47f58.tar.bz2
sncontinue-a6ac1fdae55ba79e2ffbd00f88fc39971ed47f58.zip
finally found the snag
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/core/abstract_sdk.py4
-rw-r--r--continuedev/src/continuedev/core/autopilot.py5
-rw-r--r--continuedev/src/continuedev/core/sdk.py26
-rw-r--r--continuedev/src/continuedev/libs/util/create_async_task.py3
-rw-r--r--continuedev/src/continuedev/libs/util/errors.py2
-rw-r--r--continuedev/src/continuedev/server/gui.py11
-rw-r--r--continuedev/src/continuedev/server/ide.py67
-rw-r--r--continuedev/src/continuedev/server/ide_protocol.py8
-rw-r--r--continuedev/src/continuedev/server/session_manager.py5
-rw-r--r--continuedev/src/continuedev/server/state_manager.py21
-rw-r--r--continuedev/src/continuedev/steps/chat.py25
11 files changed, 99 insertions, 78 deletions
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py
index 7bd3da6c..94d7be10 100644
--- a/continuedev/src/continuedev/core/abstract_sdk.py
+++ b/continuedev/src/continuedev/core/abstract_sdk.py
@@ -76,9 +76,7 @@ class AbstractContinueSDK(ABC):
async def get_user_secret(self, env_var: str, prompt: str) -> str:
pass
- @abstractproperty
- def config(self) -> ContinueConfig:
- pass
+ config: ContinueConfig
@abstractmethod
def set_loading_message(self, message: str):
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 3f07e270..ac00e4f0 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -251,7 +251,7 @@ class Autopilot(ContinueBaseModel):
# i -= 1
capture_event(self.continue_sdk.ide.unique_id, 'step run', {
- 'step_name': step.name, 'params': step.dict()})
+ 'step_name': step.name, 'params': step.dict()})
if not is_future_step:
# Check manual edits buffer, clear out if needed by creating a ManualEditStep
@@ -290,7 +290,8 @@ class Autopilot(ContinueBaseModel):
e)
# Attach an InternalErrorObservation to the step and unhide it.
- print(f"Error while running step: \n{error_string}\n{error_title}")
+ print(
+ f"Error while running step: \n{error_string}\n{error_title}")
capture_event(self.continue_sdk.ide.unique_id, 'step error', {
'error_message': error_string, 'error_title': error_title, 'step_name': step.name, 'params': step.dict()})
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 8649cd58..a3441ad9 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -89,6 +89,20 @@ class ContinueSDK(AbstractContinueSDK):
self.__autopilot = autopilot
self.models = Models(self)
self.context = autopilot.context
+ self.config = self._load_config()
+
+ config: ContinueConfig
+
+ def _load_config(self) -> ContinueConfig:
+ dir = self.ide.workspace_directory
+ yaml_path = os.path.join(dir, '.continue', 'config.yaml')
+ json_path = os.path.join(dir, '.continue', 'config.json')
+ if os.path.exists(yaml_path):
+ return load_config(yaml_path)
+ elif os.path.exists(json_path):
+ return load_config(json_path)
+ else:
+ return load_global_config()
@property
def history(self) -> History:
@@ -166,18 +180,6 @@ class ContinueSDK(AbstractContinueSDK):
async def get_user_secret(self, env_var: str, prompt: str) -> str:
return await self.ide.getUserSecret(env_var)
- @property
- def config(self) -> ContinueConfig:
- dir = self.ide.workspace_directory
- yaml_path = os.path.join(dir, '.continue', 'config.yaml')
- json_path = os.path.join(dir, '.continue', 'config.json')
- if os.path.exists(yaml_path):
- return load_config(yaml_path)
- elif os.path.exists(json_path):
- return load_config(json_path)
- else:
- return load_global_config()
-
def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]:
context = list(filter(lambda x: x.editing, self.__autopilot._highlighted_ranges)
) if only_editing else self.__autopilot._highlighted_ranges
diff --git a/continuedev/src/continuedev/libs/util/create_async_task.py b/continuedev/src/continuedev/libs/util/create_async_task.py
index 62ff30ec..354cea82 100644
--- a/continuedev/src/continuedev/libs/util/create_async_task.py
+++ b/continuedev/src/continuedev/libs/util/create_async_task.py
@@ -14,7 +14,8 @@ def create_async_task(coro: Coroutine, unique_id: Union[str, None] = None):
try:
future.result()
except Exception as e:
- print("Exception caught from async task: ", e)
+ print("Exception caught from async task: ",
+ '\n'.join(traceback.format_exception(e)))
capture_event(unique_id or "None", "async_task_error", {
"error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))
})
diff --git a/continuedev/src/continuedev/libs/util/errors.py b/continuedev/src/continuedev/libs/util/errors.py
new file mode 100644
index 00000000..46074cfc
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/errors.py
@@ -0,0 +1,2 @@
+class SessionNotFound(Exception):
+ pass
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index dbc063c8..21089f30 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -1,5 +1,6 @@
import json
from fastapi import Depends, Header, WebSocket, APIRouter
+from starlette.websockets import WebSocketState, WebSocketDisconnect
from typing import Any, List, Type, TypeVar, Union
from pydantic import BaseModel
import traceback
@@ -52,6 +53,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.session = session
async def _send_json(self, message_type: str, data: Any):
+ if self.websocket.client_state == WebSocketState.DISCONNECTED:
+ return
await self.websocket.send_json({
"messageType": message_type,
"data": data
@@ -171,7 +174,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
protocol.websocket = websocket
# Update any history that may have happened before connection
- await protocol.send_state_update()
+ # await protocol.send_state_update()
while AppStatus.should_exit is False:
message = await websocket.receive_text()
@@ -185,7 +188,8 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
data = message["data"]
protocol.handle_json(message_type, data)
-
+ except WebSocketDisconnect as e:
+ print("GUI websocket disconnected")
except Exception as e:
print("ERROR in gui websocket: ", e)
capture_event(session.autopilot.continue_sdk.ide.unique_id, "gui_error", {
@@ -193,5 +197,6 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
raise e
finally:
print("Closing gui websocket")
- await websocket.close()
+ if websocket.client_state != WebSocketState.DISCONNECTED:
+ await websocket.close()
session_manager.remove_session(session.session_id)
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index 782c2ba6..400ad740 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -5,6 +5,7 @@ import os
from typing import Any, Dict, List, Type, TypeVar, Union
import uuid
from fastapi import WebSocket, Body, APIRouter
+from starlette.websockets import WebSocketState, WebSocketDisconnect
from uvicorn.main import Server
import traceback
@@ -17,6 +18,8 @@ from .gui import SessionManager, session_manager
from .ide_protocol import AbstractIdeProtocolServer
import asyncio
from ..libs.util.create_async_task import create_async_task
+import nest_asyncio
+nest_asyncio.apply()
router = APIRouter(prefix="/ide", tags=["ide"])
@@ -115,7 +118,30 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
self.websocket = websocket
self.session_manager = session_manager
+ workspace_directory: str
+
+ async def initialize(self) -> List[str]:
+ await self._send_json("workspaceDirectory", {})
+ other_msgs = []
+ while True:
+ msg_string = await self.websocket.receive_text()
+ message = json.loads(msg_string)
+ if "messageType" not in message or "data" not in message:
+ continue
+ message_type = message["messageType"]
+ data = message["data"]
+ # if message_type == "openGUI":
+ # await self.openGUI()
+ if message_type == "workspaceDirectory":
+ self.workspace_directory = data["workspaceDirectory"]
+ break
+ else:
+ other_msgs.append(msg_string)
+ return other_msgs
+
async def _send_json(self, message_type: str, data: Any):
+ if self.websocket.client_state == WebSocketState.DISCONNECTED:
+ return
await self.websocket.send_json({
"messageType": message_type,
"data": data
@@ -154,8 +180,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
self.onMainUserInput(data["input"])
elif message_type == "deleteAtIndex":
self.onDeleteAtIndex(data["index"])
- elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory", "getUserSecret", "runCommand", "uniqueId"]:
+ elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "getUserSecret", "runCommand", "uniqueId"]:
self.sub_queue.post(message_type, data)
+ elif message_type == "workspaceDirectory":
+ self.workspace_directory = data["workspaceDirectory"]
else:
raise ValueError("Unknown message type", message_type)
@@ -275,18 +303,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles")
return resp.openFiles
- async def getWorkspaceDirectory(self) -> str:
- resp = await self._send_and_receive_json({}, WorkspaceDirectoryResponse, "workspaceDirectory")
- return resp.workspaceDirectory
-
async def get_unique_id(self) -> str:
resp = await self._send_and_receive_json({}, UniqueIdResponse, "uniqueId")
return resp.uniqueId
- @property
- def workspace_directory(self) -> str:
- return asyncio.run(self.getWorkspaceDirectory())
-
@cached_property_no_none
def unique_id(self) -> str:
return asyncio.run(self.get_unique_id())
@@ -396,24 +416,35 @@ async def websocket_endpoint(websocket: WebSocket):
print("Accepted websocket connection from, ", websocket.client)
await websocket.send_json({"messageType": "connected", "data": {}})
- ideProtocolServer = IdeProtocolServer(session_manager, websocket)
-
- while AppStatus.should_exit is False:
- message = await websocket.receive_text()
- message = json.loads(message)
+ def handle_msg(msg):
+ message = json.loads(msg)
if "messageType" not in message or "data" not in message:
- continue
+ return
message_type = message["messageType"]
data = message["data"]
- await ideProtocolServer.handle_json(message_type, data)
+ create_async_task(
+ ideProtocolServer.handle_json(message_type, data))
+
+ ideProtocolServer = IdeProtocolServer(session_manager, websocket)
+ other_msgs = await ideProtocolServer.initialize()
+
+ for other_msg in other_msgs:
+ handle_msg(other_msg)
+
+ while AppStatus.should_exit is False:
+ message = await websocket.receive_text()
+ handle_msg(message)
print("Closing ide websocket")
- await websocket.close()
+ except WebSocketDisconnect as e:
+ print("IDE wbsocket disconnected")
except Exception as e:
print("Error in ide websocket: ", e)
capture_event(ideProtocolServer.unique_id, "gui_error", {
"error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))})
- await websocket.close()
raise e
+ finally:
+ if websocket.client_state != WebSocketState.DISCONNECTED:
+ await websocket.close()
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
index dfdca504..69cb6c10 100644
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ b/continuedev/src/continuedev/server/ide_protocol.py
@@ -16,10 +16,6 @@ class AbstractIdeProtocolServer(ABC):
"""Show a suggestion to the user"""
@abstractmethod
- async def getWorkspaceDirectory(self):
- """Get the workspace directory"""
-
- @abstractmethod
async def setFileOpen(self, filepath: str, open: bool = True):
"""Set whether a file is open"""
@@ -103,9 +99,7 @@ class AbstractIdeProtocolServer(ABC):
async def showDiff(self, filepath: str, replacement: str, step_index: int):
"""Show a diff"""
- @abstractproperty
- def workspace_directory(self) -> str:
- """Get the workspace directory"""
+ workspace_directory: str
@abstractproperty
def unique_id(self) -> str:
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index 873a379e..7147dcfa 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -9,11 +9,13 @@ from ..core.main import FullState
from ..core.autopilot import Autopilot
from .ide_protocol import AbstractIdeProtocolServer
from ..libs.util.create_async_task import create_async_task
+from ..libs.util.errors import SessionNotFound
class Session:
session_id: str
autopilot: Autopilot
+ # The GUI websocket for the session
ws: Union[WebSocket, None]
def __init__(self, session_id: str, autopilot: Autopilot):
@@ -37,7 +39,6 @@ class DemoAutopilot(Autopilot):
class SessionManager:
sessions: Dict[str, Session] = {}
- _event_loop: Union[BaseEventLoop, None] = None
def get_session(self, session_id: str) -> Session:
if session_id not in self.sessions:
@@ -67,6 +68,8 @@ class SessionManager:
print("Registered websocket for session", session_id)
async def send_ws_data(self, session_id: str, message_type: str, data: Any):
+ if session_id not in self.sessions:
+ raise SessionNotFound(f"Session {session_id} not found")
if self.sessions[session_id].ws is None:
print(f"Session {session_id} has no websocket")
return
diff --git a/continuedev/src/continuedev/server/state_manager.py b/continuedev/src/continuedev/server/state_manager.py
deleted file mode 100644
index c9bd760b..00000000
--- a/continuedev/src/continuedev/server/state_manager.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from typing import Any, List, Tuple, Union
-from fastapi import WebSocket
-from pydantic import BaseModel
-from ..core.main import FullState
-
-# State updates represented as (path, replacement) pairs
-StateUpdate = Tuple[List[Union[str, int]], Any]
-
-
-class StateManager:
- """
- A class that acts as the source of truth for state, ingesting changes to the entire object and streaming only the updated portions to client.
- """
-
- def __init__(self, ws: WebSocket):
- self.ws = ws
-
- def _send_update(self, updates: List[StateUpdate]):
- self.ws.send_json(
- [update.dict() for update in updates]
- )
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index e1e041d0..14a1cd41 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -27,16 +27,21 @@ class SimpleChatStep(Step):
async def run(self, sdk: ContinueSDK):
completion = ""
messages = self.messages or await sdk.get_chat_context()
- async for chunk in sdk.models.gpt4.stream_chat(messages, temperature=0.5):
- if sdk.current_step_was_deleted():
- # So that the message doesn't disappear
- self.hide = False
- return
-
- if "content" in chunk:
- self.description += chunk["content"]
- completion += chunk["content"]
- await sdk.update_ui()
+
+ generator = sdk.models.gpt4.stream_chat(messages, temperature=0.5)
+ try:
+ async for chunk in generator:
+ if sdk.current_step_was_deleted():
+ # So that the message doesn't disappear
+ self.hide = False
+ return
+
+ if "content" in chunk:
+ self.description += chunk["content"]
+ completion += chunk["content"]
+ await sdk.update_ui()
+ finally:
+ await generator.aclose()
self.name = (await sdk.models.gpt35.complete(
f"Write a short title for the following chat message: {self.description}")).strip()