diff options
| author | Nate Sesti <sestinj@gmail.com> | 2023-08-01 23:16:13 -0700 | 
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-08-01 23:16:13 -0700 | 
| commit | dc2b90d848f5fc53a18ad481ba196ac9708de8ec (patch) | |
| tree | b8da4f4e7d31e362b39c6f9db38dec3a70afb235 /continuedev/src | |
| parent | f02b5c10876e3c0ee40d3c095564f675c9296bdf (diff) | |
| parent | 01ed2c7eb2d3417b2c190eea105008372f49a7c6 (diff) | |
| download | sncontinue-dc2b90d848f5fc53a18ad481ba196ac9708de8ec.tar.gz sncontinue-dc2b90d848f5fc53a18ad481ba196ac9708de8ec.tar.bz2 sncontinue-dc2b90d848f5fc53a18ad481ba196ac9708de8ec.zip | |
Merge branch 'main' into package-python
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 17 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/context.py | 17 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/gui.py | 10 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/ide.py | 38 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/ide_protocol.py | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/main.py | 45 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/session_manager.py | 12 | 
7 files changed, 60 insertions, 83 deletions
| diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index d92c51cd..a3dd854e 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -1,7 +1,7 @@  from functools import cached_property  import traceback  import time -from typing import Any, Callable, Coroutine, Dict, List, Union +from typing import Callable, Coroutine, Dict, List, Union  from aiohttp import ClientPayloadError  from pydantic import root_validator @@ -54,7 +54,7 @@ class Autopilot(ContinueBaseModel):      history: History = History.from_empty()      context: Context = Context()      full_state: Union[FullState, None] = None -    context_manager: Union[ContextManager, None] = None +    context_manager: ContextManager = ContextManager()      continue_sdk: ContinueSDK = None      _on_update_callbacks: List[Callable[[FullState], None]] = [] @@ -66,19 +66,22 @@ class Autopilot(ContinueBaseModel):      _user_input_queue = AsyncSubscriptionQueue()      _retry_queue = AsyncSubscriptionQueue() +    started: bool = False +      async def start(self):          self.continue_sdk = await ContinueSDK.create(self)          if override_policy := self.continue_sdk.config.policy_override:              self.policy = override_policy          # Load documents into the search index -        self.context_manager = await ContextManager.create( +        await self.context_manager.start(              self.continue_sdk.config.context_providers + [                  HighlightedCodeContextProvider(ide=self.ide),                  FileContextProvider(workspace_dir=self.ide.workspace_directory)              ])          await self.context_manager.load_index(self.ide.workspace_directory) +        self.started = True      class Config:          arbitrary_types_allowed = True @@ -98,7 +101,7 @@ class Autopilot(ContinueBaseModel):              user_input_queue=self._main_user_input_queue,              slash_commands=self.get_available_slash_commands(),              adding_highlighted_code=self.context_manager.context_providers[ -                "code"].adding_highlighted_code if self.context_manager is not None else False, +                "code"].adding_highlighted_code if "code" in self.context_manager.context_providers else False,              selected_context_items=await self.context_manager.get_selected_items() if self.context_manager is not None else [],          )          self.full_state = full_state @@ -201,7 +204,7 @@ class Autopilot(ContinueBaseModel):          await self.update_subscribers()      async def set_editing_at_ids(self, ids: List[str]): -        self.context_manager.context_providers["code"].set_editing_at_ids(ids) +        await self.context_manager.context_providers["code"].set_editing_at_ids(ids)          await self.update_subscribers()      async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]: @@ -244,7 +247,7 @@ class Autopilot(ContinueBaseModel):          try:              observation = await step(self.continue_sdk)          except Exception as e: -            if self.history.timeline[index_of_history_node].deleted: +            if index_of_history_node >= len(self.history.timeline) or self.history.timeline[index_of_history_node].deleted:                  # If step was deleted/cancelled, don't show error or allow retry                  return None @@ -301,7 +304,7 @@ class Autopilot(ContinueBaseModel):          self._step_depth -= 1          # Add observation to history, unless already attached error observation -        if not caught_error: +        if not caught_error and index_of_history_node < len(self.history.timeline):              self.history.timeline[index_of_history_node].observation = observation              self.history.timeline[index_of_history_node].active = False              await self.update_subscribers() diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index e968c35c..3f5f6fd3 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -133,14 +133,19 @@ class ContextManager:          """          return sum([await provider.get_chat_messages() for provider in self.context_providers.values()], []) -    def __init__(self, context_providers: List[ContextProvider]): +    def __init__(self): +        self.context_providers = {} +        self.provider_titles = set() + +    async def start(self, context_providers: List[ContextProvider]): +        """ +        Starts the context manager. +        """          self.context_providers = {              prov.title: prov for prov in context_providers}          self.provider_titles = {              provider.title for provider in context_providers} -    @classmethod -    async def create(cls, context_providers: List[ContextProvider]):          async with Client('http://localhost:7700') as search_client:              meilisearch_running = True              try: @@ -154,10 +159,8 @@ class ContextManager:              if not meilisearch_running:                  logger.warning(                      "MeiliSearch not running, avoiding any dependent context providers") -                context_providers = list( -                    filter(lambda cp: cp.title == "code", context_providers)) - -        return cls(context_providers) +                self.context_providers = list( +                    filter(lambda cp: cp.title == "code", self.context_providers))      async def load_index(self, workspace_dir: str):          for _, provider in self.context_providers.items(): diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index cf18c56b..7c89c5c2 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -93,8 +93,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):              self.on_delete_context_with_ids(data["ids"])          elif message_type == "toggle_adding_highlighted_code":              self.on_toggle_adding_highlighted_code() -        elif message_type == "set_editing_at_indices": -            self.on_set_editing_at_indices(data["indices"]) +        elif message_type == "set_editing_at_ids": +            self.on_set_editing_at_ids(data["ids"])          elif message_type == "show_logs_at_index":              self.on_show_logs_at_index(data["index"])          elif message_type == "select_context_item": @@ -138,9 +138,9 @@ class GUIProtocolServer(AbstractGUIProtocolServer):          create_async_task(              self.session.autopilot.toggle_adding_highlighted_code(), self.on_error) -    def on_set_editing_at_indices(self, indices: List[int]): +    def on_set_editing_at_ids(self, ids: List[str]):          create_async_task( -            self.session.autopilot.set_editing_at_indices(indices), self.on_error) +            self.session.autopilot.set_editing_at_ids(ids), self.on_error)      def on_show_logs_at_index(self, index: int):          name = f"continue_logs.txt" @@ -190,7 +190,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we          posthog_logger.capture_event("gui_error", {              "error_title": e.__str__() or e.__repr__(), "error_message": err_msg}) -        await protocol.session.autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) +        await session.autopilot.ide.showMessage(err_msg)          raise e      finally: diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 6124f3bd..9797a8b7 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -1,7 +1,7 @@  # This is a separate server from server/main.py  import json  import os -from typing import Any, List, Type, TypeVar, Union +from typing import Any, Coroutine, List, Type, TypeVar, Union  import uuid  from fastapi import WebSocket, APIRouter  from starlette.websockets import WebSocketState, WebSocketDisconnect @@ -211,8 +211,6 @@ class IdeProtocolServer(AbstractIdeProtocolServer):          else:              raise ValueError("Unknown message type", message_type) -    # ------------------------------- # -    # Request actions in IDE, doesn't matter which Session      async def showSuggestion(self, file_edit: FileEdit):          await self._send_json("showSuggestion", {              "edit": file_edit.dict() @@ -232,6 +230,11 @@ class IdeProtocolServer(AbstractIdeProtocolServer):              "open": open          }) +    async def showMessage(self, message: str): +        await self._send_json("showMessage", { +            "message": message +        }) +      async def showVirtualFile(self, name: str, contents: str):          await self._send_json("showVirtualFile", {              "name": name, @@ -275,13 +278,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer):          # Just need connect the suggestionId to the IDE (and the gui)          return any([r.accepted for r in responses]) -    # ------------------------------- # -    # Here needs to pass message onto the Autopilot OR Autopilot just subscribes. -    # This is where you might have triggers: plugins can subscribe to certian events -    # like file changes, tracebacks, etc... - -    def on_error(self, e: Exception): -        return self.session_manager.sessions[self.session_id].autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) +    def on_error(self, e: Exception) -> Coroutine: +        try: +            return self.session_manager.sessions[self.session_id].autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) +        except: +            err_msg = '\n'.join(traceback.format_exception(e)) +            return self.showMessage(f"Error in Continue server: {err_msg}")      def onAcceptRejectSuggestion(self, accepted: bool):          posthog_logger.capture_event("accept_reject_suggestion", { @@ -307,7 +309,9 @@ class IdeProtocolServer(AbstractIdeProtocolServer):      def __get_autopilot(self):          if self.session_id not in self.session_manager.sessions:              return None -        return self.session_manager.sessions[self.session_id].autopilot + +        autopilot = self.session_manager.sessions[self.session_id].autopilot +        return autopilot if autopilot.started else None      def onFileEdits(self, edits: List[FileEditWithFullContents]):          if autopilot := self.__get_autopilot(): @@ -442,6 +446,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):  @router.websocket("/ws")  async def websocket_endpoint(websocket: WebSocket, session_id: str = None):      try: +        # Accept the websocket connection          await websocket.accept()          logger.debug(f"Accepted websocket connection from {websocket.client}")          await websocket.send_json({"messageType": "connected", "data": {}}) @@ -453,6 +458,7 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):              logger.debug("Failed to start MeiliSearch")              logger.debug(e) +        # Message handler          def handle_msg(msg):              message = json.loads(msg) @@ -465,6 +471,7 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):              create_async_task(                  ideProtocolServer.handle_json(message_type, data), ideProtocolServer.on_error) +        # Initialize the IDE Protocol Server          ideProtocolServer = IdeProtocolServer(session_manager, websocket)          if session_id is not None:              session_manager.registered_ides[session_id] = ideProtocolServer @@ -475,20 +482,23 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):          for other_msg in other_msgs:              handle_msg(other_msg) +        # Handle messages          while AppStatus.should_exit is False:              message = await websocket.receive_text()              handle_msg(message) -        logger.debug("Closing ide websocket")      except WebSocketDisconnect as e: -        logger.debug("IDE wbsocket disconnected") +        logger.debug("IDE websocket disconnected")      except Exception as e:          logger.debug(f"Error in ide websocket: {e}")          err_msg = '\n'.join(traceback.format_exception(e))          posthog_logger.capture_event("gui_error", {              "error_title": e.__str__() or e.__repr__(), "error_message": err_msg}) -        await session_manager.sessions[session_id].autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) +        if session_id is not None and session_id in session_manager.sessions: +            await session_manager.sessions[session_id].autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) +        elif ideProtocolServer is not None: +            await ideProtocolServer.showMessage(f"Error in Continue server: {err_msg}")          raise e      finally: diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index 0ae7e7fa..72b410d4 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -24,6 +24,10 @@ class AbstractIdeProtocolServer(ABC):          """Set whether a file is open"""      @abstractmethod +    async def showMessage(self, message: str): +        """Show a message to the user""" + +    @abstractmethod      async def showVirtualFile(self, name: str, contents: str):          """Show a virtual file""" diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index 468bc855..f8dfb009 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,5 +1,4 @@  import asyncio -import sys  import time  import psutil  import os @@ -8,13 +7,11 @@ from fastapi.middleware.cors import CORSMiddleware  import atexit  import uvicorn  import argparse -import logging.config  from .ide import router as ide_router  from .gui import router as gui_router  from .session_manager import session_manager -from ..libs.util.paths import getLogFilePath  from ..libs.util.logging import logger  app = FastAPI() @@ -38,25 +35,6 @@ def health():      return {"status": "ok"} -class Logger(object): -    def __init__(self, log_file: str): -        self.terminal = sys.stdout -        self.log = open(log_file, "a") - -    def write(self, message): -        self.terminal.write(message) -        self.log.write(message) - -    def flush(self): -        # this flush method is needed for python 3 compatibility. -        # this handles the flush command by doing nothing. -        # you might want to specify some extra behavior here. -        pass - -    def isatty(self): -        return False - -  try:      # add cli arg for server port      parser = argparse.ArgumentParser() @@ -71,7 +49,6 @@ except Exception as e:  def run_server():      config = uvicorn.Config(app, host="127.0.0.1", port=args.port)      server = uvicorn.Server(config) -      server.run() @@ -87,32 +64,10 @@ def cleanup():      loop.close() -def cpu_usage_report(): -    process = psutil.Process(os.getpid()) -    # Call cpu_percent once to start measurement, but ignore the result -    process.cpu_percent(interval=None) -    # Wait for a short period of time -    time.sleep(1) -    # Call cpu_percent again to get the CPU usage over the interval -    cpu_usage = process.cpu_percent(interval=None) -    logger.debug(f"CPU usage: {cpu_usage}%") - -  atexit.register(cleanup)  if __name__ == "__main__":      try: -        # Uncomment to get CPU usage reports -        # import threading - -        # def cpu_usage_loop(): -        #     while True: -        #         cpu_usage_report() -        #         time.sleep(2) - -        # cpu_thread = threading.Thread(target=cpu_usage_loop) -        # cpu_thread.start() -          run_server()      except Exception as e:          logger.debug(f"Error starting Continue server: {e}") diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index b5580fe8..56c92307 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -1,15 +1,14 @@  import os  import traceback  from fastapi import WebSocket -from typing import Any, Dict, List, Union +from typing import Any, Coroutine, Dict, Union  from uuid import uuid4  import json  from fastapi.websockets import WebSocketState -from ..plugins.steps.core.core import DisplayErrorStep, MessageStep +from ..plugins.steps.core.core import MessageStep  from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath -from ..models.filesystem_edit import FileEditWithFullContents  from ..core.main import FullState, HistoryNode  from ..core.autopilot import Autopilot  from .ide_protocol import AbstractIdeProtocolServer @@ -90,8 +89,11 @@ class SessionManager:              ))              logger.warning(f"Error loading context manager: {e}") -        create_async_task(autopilot.run_policy(), lambda e: autopilot.continue_sdk.run_step( -            DisplayErrorStep(e=e))) +        def on_error(e: Exception) -> Coroutine: +            err_msg = '\n'.join(traceback.format_exception(e)) +            return ide.showMessage(f"Error in Continue server: {err_msg}") + +        create_async_task(autopilot.run_policy(), on_error)          return session      async def remove_session(self, session_id: str): | 
