diff options
author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-10-09 18:37:27 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-09 18:37:27 -0700 |
commit | f09150617ed2454f3074bcf93f53aae5ae637d40 (patch) | |
tree | 5cfe614a64d921dfe58b049f426d67a8b832c71f /continuedev/src/continuedev/server | |
parent | 985304a213f620cdff3f8f65f74ed7e3b79be29d (diff) | |
download | sncontinue-f09150617ed2454f3074bcf93f53aae5ae637d40.tar.gz sncontinue-f09150617ed2454f3074bcf93f53aae5ae637d40.tar.bz2 sncontinue-f09150617ed2454f3074bcf93f53aae5ae637d40.zip |
Preview (#541)
* Strong typing (#533)
* refactor: :recycle: get rid of continuedev.src.continuedev structure
* refactor: :recycle: switching back to server folder
* feat: :sparkles: make config.py imports shorter
* feat: :bookmark: publish as pre-release vscode extension
* refactor: :recycle: refactor and add more completion params to ui
* build: :building_construction: download from preview S3
* fix: :bug: fix paths
* fix: :green_heart: package:pre-release
* ci: :green_heart: more time for tests
* fix: :green_heart: fix build scripts
* fix: :bug: fix import in run.py
* fix: :bookmark: update version to try again
* ci: 💚 Update package.json version [skip ci]
* refactor: :fire: don't check for old extensions version
* fix: :bug: small bug fixes
* fix: :bug: fix config.py import paths
* ci: 💚 Update package.json version [skip ci]
* ci: :green_heart: platform-specific builds test #1
* feat: :green_heart: ship with binary
* fix: :green_heart: fix copy statement to include.exe for windows
* fix: :green_heart: cd extension before packaging
* chore: :loud_sound: count tokens generated
* fix: :green_heart: remove npm_config_arch
* fix: :green_heart: publish as pre-release!
* chore: :bookmark: update version
* perf: :green_heart: hardcode distro paths
* fix: :bug: fix yaml syntax error
* chore: :bookmark: update version
* fix: :green_heart: update permissions and version
* feat: :bug: kill old server if needed
* feat: :lipstick: update marketplace icon for pre-release
* ci: 💚 Update package.json version [skip ci]
* feat: :sparkles: auto-reload for config.py
* feat: :wrench: update default config.py imports
* feat: :sparkles: codelens in config.py
* feat: :sparkles: select model param count from UI
* ci: 💚 Update package.json version [skip ci]
* feat: :sparkles: more model options, ollama error handling
* perf: :zap: don't show server loading immediately
* fix: :bug: fixing small UI details
* ci: 💚 Update package.json version [skip ci]
* feat: :rocket: headers param on LLM class
* fix: :bug: fix headers for openai.;y
* feat: :sparkles: highlight code on cmd+shift+L
* ci: 💚 Update package.json version [skip ci]
* feat: :lipstick: sticky top bar in gui.tsx
* fix: :loud_sound: websocket logging and horizontal scrollbar
* ci: 💚 Update package.json version [skip ci]
* feat: :sparkles: allow AzureOpenAI Service through GGML
* ci: 💚 Update package.json version [skip ci]
* fix: :bug: fix automigration
* ci: 💚 Update package.json version [skip ci]
* ci: :green_heart: upload binaries in ci, download apple silicon
* chore: :fire: remove notes
* fix: :green_heart: use curl to download binary
* fix: :green_heart: set permissions on apple silicon binary
* fix: :green_heart: testing
* fix: :green_heart: cleanup file
* fix: :green_heart: fix preview.yaml
* fix: :green_heart: only upload once per binary
* fix: :green_heart: install rosetta
* ci: :green_heart: download binary after tests
* ci: 💚 Update package.json version [skip ci]
* ci: :green_heart: prepare ci for merge to main
---------
Co-authored-by: GitHub Action <action@github.com>
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 453 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide.py | 673 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide_protocol.py | 170 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/main.py | 109 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/meilisearch_server.py | 195 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/session_manager.py | 192 |
6 files changed, 0 insertions, 1792 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py deleted file mode 100644 index 26fcbd42..00000000 --- a/continuedev/src/continuedev/server/gui.py +++ /dev/null @@ -1,453 +0,0 @@ -import asyncio -import json -import traceback -from typing import Any, List, Optional, Type, TypeVar - -from fastapi import APIRouter, Depends, WebSocket -from pydantic import BaseModel -from starlette.websockets import WebSocketDisconnect, WebSocketState -from uvicorn.main import Server - -from ..core.main import ContextItem -from ..core.models import ALL_MODEL_ROLES, MODEL_CLASSES, MODEL_MODULE_NAMES -from ..libs.llm.prompts.chat import llama2_template_messages, template_alpaca_messages -from ..libs.util.create_async_task import create_async_task -from ..libs.util.edit_config import ( - add_config_import, - create_float_node, - create_obj_node, - create_string_node, - display_llm_class, -) -from ..libs.util.logging import logger -from ..libs.util.queue import AsyncSubscriptionQueue -from ..libs.util.telemetry import posthog_logger -from ..plugins.steps.core.core import DisplayErrorStep -from ..plugins.steps.setup_model import SetupModelStep -from .session_manager import Session, session_manager - -router = APIRouter(prefix="/gui", tags=["gui"]) - -# Graceful shutdown by closing websockets -original_handler = Server.handle_exit - - -class AppStatus: - should_exit = False - - @staticmethod - def handle_exit(*args, **kwargs): - AppStatus.should_exit = True - logger.debug("Shutting down") - original_handler(*args, **kwargs) - - -Server.handle_exit = AppStatus.handle_exit - - -async def websocket_session(session_id: str) -> Session: - return await session_manager.get_session(session_id) - - -T = TypeVar("T", bound=BaseModel) - -# You should probably abstract away the websocket stuff into a separate class - - -class GUIProtocolServer: - websocket: WebSocket - session: Session - sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue() - - def __init__(self, session: Session): - self.session = session - - async def _send_json(self, message_type: str, data: Any): - if self.websocket.application_state == WebSocketState.DISCONNECTED: - return - await self.websocket.send_json({"messageType": message_type, "data": data}) - - async def _receive_json(self, message_type: str, timeout: int = 20) -> Any: - try: - return await asyncio.wait_for( - self.sub_queue.get(message_type), timeout=timeout - ) - except asyncio.TimeoutError: - raise Exception("GUI Protocol _receive_json timed out after 20 seconds") - - async def _send_and_receive_json( - self, data: Any, resp_model: Type[T], message_type: str - ) -> T: - await self._send_json(message_type, data) - resp = await self._receive_json(message_type) - return resp_model.parse_obj(resp) - - def on_error(self, e: Exception): - return self.session.autopilot.continue_sdk.run_step( - DisplayErrorStep.from_exception(e) - ) - - def handle_json(self, message_type: str, data: Any): - if message_type == "main_input": - self.on_main_input(data["input"]) - elif message_type == "step_user_input": - self.on_step_user_input(data["input"], data["index"]) - elif message_type == "refinement_input": - self.on_refinement_input(data["input"], data["index"]) - elif message_type == "reverse_to_index": - self.on_reverse_to_index(data["index"]) - elif message_type == "retry_at_index": - self.on_retry_at_index(data["index"]) - elif message_type == "clear_history": - self.on_clear_history() - elif message_type == "set_current_session_title": - self.set_current_session_title(data["title"]) - elif message_type == "delete_at_index": - self.on_delete_at_index(data["index"]) - elif message_type == "delete_context_with_ids": - self.on_delete_context_with_ids(data["ids"], data.get("index", None)) - elif message_type == "toggle_adding_highlighted_code": - self.on_toggle_adding_highlighted_code() - elif message_type == "set_editing_at_ids": - self.on_set_editing_at_ids(data["ids"]) - elif message_type == "show_logs_at_index": - self.on_show_logs_at_index(data["index"]) - elif message_type == "show_context_virtual_file": - self.show_context_virtual_file(data.get("index", None)) - elif message_type == "select_context_item": - self.select_context_item(data["id"], data["query"]) - elif message_type == "select_context_item_at_index": - self.select_context_item_at_index(data["id"], data["query"], data["index"]) - elif message_type == "load_session": - self.load_session(data.get("session_id", None)) - elif message_type == "edit_step_at_index": - self.edit_step_at_index(data.get("user_input", ""), data["index"]) - elif message_type == "set_system_message": - self.set_system_message(data["message"]) - elif message_type == "set_temperature": - self.set_temperature(float(data["temperature"])) - elif message_type == "add_model_for_role": - self.add_model_for_role(data["role"], data["model_class"], data["model"]) - elif message_type == "set_model_for_role_from_index": - self.set_model_for_role_from_index(data["role"], data["index"]) - elif message_type == "save_context_group": - self.save_context_group( - data["title"], [ContextItem(**item) for item in data["context_items"]] - ) - elif message_type == "select_context_group": - self.select_context_group(data["id"]) - elif message_type == "delete_context_group": - self.delete_context_group(data["id"]) - - def on_main_input(self, input: str): - # Do something with user input - create_async_task( - self.session.autopilot.accept_user_input(input), self.on_error - ) - - def on_reverse_to_index(self, index: int): - # Reverse the history to the given index - create_async_task(self.session.autopilot.reverse_to_index(index), self.on_error) - - def on_step_user_input(self, input: str, index: int): - create_async_task( - self.session.autopilot.give_user_input(input, index), self.on_error - ) - - def on_refinement_input(self, input: str, index: int): - create_async_task( - self.session.autopilot.accept_refinement_input(input, index), self.on_error - ) - - def on_retry_at_index(self, index: int): - create_async_task(self.session.autopilot.retry_at_index(index), self.on_error) - - def on_clear_history(self): - create_async_task(self.session.autopilot.clear_history(), self.on_error) - - def on_delete_at_index(self, index: int): - create_async_task(self.session.autopilot.delete_at_index(index), self.on_error) - - def edit_step_at_index(self, user_input: str, index: int): - create_async_task( - self.session.autopilot.edit_step_at_index(user_input, index), - self.on_error, - ) - - def on_delete_context_with_ids(self, ids: List[str], index: Optional[int] = None): - create_async_task( - self.session.autopilot.delete_context_with_ids(ids, index), self.on_error - ) - - def on_toggle_adding_highlighted_code(self): - create_async_task( - self.session.autopilot.toggle_adding_highlighted_code(), self.on_error - ) - posthog_logger.capture_event("toggle_adding_highlighted_code", {}) - - def on_set_editing_at_ids(self, ids: List[str]): - create_async_task(self.session.autopilot.set_editing_at_ids(ids), self.on_error) - - def on_show_logs_at_index(self, index: int): - name = "Continue Context" - logs = "\n\n############################################\n\n".join( - ["This is the prompt that was sent to the LLM during this step"] - + self.session.autopilot.continue_sdk.history.timeline[index].logs - ) - create_async_task( - self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error - ) - posthog_logger.capture_event("show_logs_at_index", {}) - - def show_context_virtual_file(self, index: Optional[int] = None): - async def async_stuff(): - if index is None: - context_items = ( - await self.session.autopilot.context_manager.get_selected_items() - ) - elif index < len(self.session.autopilot.continue_sdk.history.timeline): - context_items = self.session.autopilot.continue_sdk.history.timeline[ - index - ].context_used - - ctx = "\n\n-----------------------------------\n\n".join( - ["These are the context items that will be passed to the LLM"] - + list(map(lambda x: x.content, context_items)) - ) - await self.session.autopilot.ide.showVirtualFile( - "Continue - Selected Context", ctx - ) - - create_async_task( - async_stuff(), - self.on_error, - ) - - def select_context_item(self, id: str, query: str): - """Called when user selects an item from the dropdown""" - create_async_task( - self.session.autopilot.select_context_item(id, query), self.on_error - ) - - def select_context_item_at_index(self, id: str, query: str, index: int): - """Called when user selects an item from the dropdown for prev UserInputStep""" - create_async_task( - self.session.autopilot.select_context_item_at_index(id, query, index), - self.on_error, - ) - - def load_session(self, session_id: Optional[str] = None): - async def load_and_tell_to_reconnect(): - new_session_id = await session_manager.load_session( - self.session.session_id, session_id - ) - await self._send_json( - "reconnect_at_session", {"session_id": new_session_id} - ) - - create_async_task(load_and_tell_to_reconnect(), self.on_error) - - posthog_logger.capture_event("load_session", {"session_id": session_id}) - - def set_current_session_title(self, title: str): - self.session.autopilot.set_current_session_title(title) - - def set_system_message(self, message: str): - self.session.autopilot.continue_sdk.config.system_message = message - self.session.autopilot.continue_sdk.models.set_system_message(message) - - create_async_task( - self.session.autopilot.set_config_attr( - ["system_message"], create_string_node(message) - ), - self.on_error, - ) - posthog_logger.capture_event("set_system_message", {"system_message": message}) - - def set_temperature(self, temperature: float): - self.session.autopilot.continue_sdk.config.temperature = temperature - create_async_task( - self.session.autopilot.set_config_attr( - ["temperature"], create_float_node(temperature) - ), - self.on_error, - ) - posthog_logger.capture_event("set_temperature", {"temperature": temperature}) - - def set_model_for_role_from_index(self, role: str, index: int): - async def async_stuff(): - models = self.session.autopilot.continue_sdk.config.models - - # Set models in SDK - temp = models.default - models.default = models.saved[index] - models.saved[index] = temp - await self.session.autopilot.continue_sdk.start_model(models.default) - - # Set models in config.py - JOINER = ",\n\t\t" - models_args = { - "saved": f"[{JOINER.join([display_llm_class(llm) for llm in models.saved])}]", - ("default" if role == "*" else role): display_llm_class(models.default), - } - - await self.session.autopilot.set_config_attr( - ["models"], - create_obj_node("Models", models_args), - ) - - for other_role in ALL_MODEL_ROLES: - if other_role != "default": - models.__setattr__(other_role, models.default) - - await self.session.autopilot.continue_sdk.update_ui() - - create_async_task(async_stuff(), self.on_error) - - def add_model_for_role(self, role: str, model_class: str, model: Any): - models = self.session.autopilot.continue_sdk.config.models - - model_copy = model.copy() - if "api_key" in model_copy: - del model_copy["api_key"] - if "hf_token" in model_copy: - del model_copy["hf_token"] - posthog_logger.capture_event( - "select_model_for_role", - {"role": role, "model_class": model_class, "model": model_copy}, - ) - - if role == "*": - - async def async_stuff(): - # Remove all previous models in roles and place in saved - saved_models = models.saved - existing_saved_models = set( - [display_llm_class(llm) for llm in saved_models] - ) - for role in ALL_MODEL_ROLES: - val = models.__getattribute__(role) - if ( - val is not None - and display_llm_class(val) not in existing_saved_models - ): - saved_models.append(val) - existing_saved_models.add(display_llm_class(val)) - models.__setattr__(role, None) - - # Add the requisite import to config.py - add_config_import( - f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}" - ) - if "template_messages" in model: - add_config_import( - f"from continuedev.src.continuedev.libs.llm.prompts.chat import {model['template_messages']}" - ) - - # Set and start the new default model - - if "template_messages" in model: - model["template_messages"] = { - "llama2_template_messages": llama2_template_messages, - "template_alpaca_messages": template_alpaca_messages, - }[model["template_messages"]] - new_model = MODEL_CLASSES[model_class](**model) - models.default = new_model - await self.session.autopilot.continue_sdk.start_model(models.default) - - # Construct and set the new models object - JOINER = ",\n\t\t" - saved_model_strings = set( - [display_llm_class(llm) for llm in saved_models] - ) - models_args = { - "default": display_llm_class(models.default, True), - "saved": f"[{JOINER.join(saved_model_strings)}]", - } - - await self.session.autopilot.set_config_attr( - ["models"], - create_obj_node("Models", models_args), - ) - - # Set all roles (in-memory) to the new default model - for role in ALL_MODEL_ROLES: - if role != "default": - models.__setattr__(role, models.default) - - # Display setup help - await self.session.autopilot.continue_sdk.run_step( - SetupModelStep(model_class=model_class) - ) - - create_async_task(async_stuff(), self.on_error) - else: - # TODO - pass - - def save_context_group(self, title: str, context_items: List[ContextItem]): - create_async_task( - self.session.autopilot.save_context_group(title, context_items), - self.on_error, - ) - - def select_context_group(self, id: str): - create_async_task( - self.session.autopilot.select_context_group(id), self.on_error - ) - - def delete_context_group(self, id: str): - create_async_task( - self.session.autopilot.delete_context_group(id), self.on_error - ) - - -@router.websocket("/ws") -async def websocket_endpoint( - websocket: WebSocket, session: Session = Depends(websocket_session) -): - try: - logger.debug(f"Received websocket connection at url: {websocket.url}") - await websocket.accept() - - logger.debug("Session started") - session_manager.register_websocket(session.session_id, websocket) - protocol = GUIProtocolServer(session) - protocol.websocket = websocket - - # Update any history that may have happened before connection - await protocol.session.autopilot.update_subscribers() - - while AppStatus.should_exit is False: - message = await websocket.receive_text() - logger.debug(f"Received GUI message {message}") - if isinstance(message, str): - message = json.loads(message) - - if "messageType" not in message or "data" not in message: - continue # :o - message_type = message["messageType"] - data = message["data"] - - protocol.handle_json(message_type, data) - except WebSocketDisconnect: - logger.debug("GUI websocket disconnected") - except Exception as e: - # Log, send to PostHog, and send to GUI - logger.debug(f"ERROR in gui websocket: {e}") - err_msg = "\n".join(traceback.format_exception(e)) - posthog_logger.capture_event( - "gui_error", - {"error_title": e.__str__() or e.__repr__(), "error_message": err_msg}, - ) - - await session.autopilot.ide.showMessage(err_msg) - - raise e - finally: - logger.debug("Closing gui websocket") - if websocket.client_state != WebSocketState.DISCONNECTED: - await websocket.close() - - await session_manager.persist_session(session.session_id) - await session_manager.remove_session(session.session_id) diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py deleted file mode 100644 index 6a4dc738..00000000 --- a/continuedev/src/continuedev/server/ide.py +++ /dev/null @@ -1,673 +0,0 @@ -# This is a separate server from server/main.py -import asyncio -import json -import os -import traceback -import uuid -from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, TypeVar, Union - -import nest_asyncio -from fastapi import APIRouter, WebSocket -from pydantic import BaseModel -from starlette.websockets import WebSocketDisconnect, WebSocketState -from uvicorn.main import Server - -from ..core.main import ContinueCustomException -from ..libs.util.create_async_task import create_async_task -from ..libs.util.devdata import dev_data_logger -from ..libs.util.logging import logger -from ..libs.util.queue import AsyncSubscriptionQueue -from ..libs.util.telemetry import posthog_logger -from ..models.filesystem import ( - EditDiff, - FileSystem, - RangeInFile, - RangeInFileWithContents, - RealFileSystem, -) -from ..models.filesystem_edit import ( - AddDirectory, - AddFile, - DeleteDirectory, - DeleteFile, - FileEdit, - FileEditWithFullContents, - FileSystemEdit, - RenameDirectory, - RenameFile, - SequentialFileSystemEdit, -) -from ..plugins.steps.core.core import DisplayErrorStep -from .gui import session_manager -from .ide_protocol import AbstractIdeProtocolServer -from .session_manager import SessionManager - -nest_asyncio.apply() - - -router = APIRouter(prefix="/ide", tags=["ide"]) - - -# Graceful shutdown by closing websockets -original_handler = Server.handle_exit - - -class AppStatus: - should_exit = False - - @staticmethod - def handle_exit(*args, **kwargs): - AppStatus.should_exit = True - logger.debug("Shutting down") - original_handler(*args, **kwargs) - - -Server.handle_exit = AppStatus.handle_exit - - -# TYPES # - - -class FileEditsUpdate(BaseModel): - fileEdits: List[FileEditWithFullContents] - - -class OpenFilesResponse(BaseModel): - openFiles: List[str] - - -class VisibleFilesResponse(BaseModel): - visibleFiles: List[str] - - -class HighlightedCodeResponse(BaseModel): - highlightedCode: List[RangeInFile] - - -class ShowSuggestionRequest(BaseModel): - suggestion: FileEdit - - -class ShowSuggestionResponse(BaseModel): - suggestion: FileEdit - accepted: bool - - -class ReadFileResponse(BaseModel): - contents: str - - -class EditFileResponse(BaseModel): - fileEdit: FileEditWithFullContents - - -class WorkspaceDirectoryResponse(BaseModel): - workspaceDirectory: str - - -class GetUserSecretResponse(BaseModel): - value: str - - -class RunCommandResponse(BaseModel): - output: str = "" - - -class UniqueIdResponse(BaseModel): - uniqueId: str - - -class TerminalContentsResponse(BaseModel): - contents: str - - -class ListDirectoryContentsResponse(BaseModel): - contents: List[str] - - -class FileExistsResponse(BaseModel): - exists: bool - - -T = TypeVar("T", bound=BaseModel) - - -class cached_property_no_none: - def __init__(self, func): - self.func = func - - def __get__(self, instance, owner): - if instance is None: - return self - value = self.func(instance) - if value is not None: - setattr(instance, self.func.__name__, value) - return value - - def __repr__(self): - return f"<cached_property_no_none '{self.func.__name__}'>" - - -class IdeProtocolServer(AbstractIdeProtocolServer): - websocket: WebSocket - session_manager: SessionManager - sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue() - session_id: Union[str, None] = None - - ide_info: Optional[Dict] = None - - def __init__(self, session_manager: SessionManager, websocket: WebSocket): - self.websocket = websocket - self.session_manager = session_manager - - workspace_directory: str = None - unique_id: str = None - - async def initialize(self, session_id: str) -> List[str]: - self.session_id = session_id - await self._send_json("workspaceDirectory", {}) - await self._send_json("uniqueId", {}) - await self._send_json("ide", {}) - other_msgs = [] - while True: - msg_string = await self.websocket.receive_text() - message = json.loads(msg_string) - if "messageType" not in message or "data" not in message: - continue # <-- hey that's the name of this repo! - message_type = message["messageType"] - data = message["data"] - logger.debug(f"Received message while initializing {message_type}") - if message_type == "workspaceDirectory": - self.workspace_directory = data["workspaceDirectory"] - elif message_type == "uniqueId": - self.unique_id = data["uniqueId"] - elif message_type == "ide": - self.ide_info = data - else: - other_msgs.append(msg_string) - - if self.workspace_directory is not None and self.unique_id is not None: - break - return other_msgs - - async def _send_json(self, message_type: str, data: Any): - # TODO: You breakpointed here, set it to disconnected, and then saw - # that even after reloading, it couldn't connect the server. - # Is this because there is an IDE registered without a websocket? - # This shouldn't count as registered in that case. - try: - if self.websocket.application_state == WebSocketState.DISCONNECTED: - logger.debug( - f"Tried to send message, but websocket is disconnected: {message_type}" - ) - return - # logger.debug(f"Sending IDE message: {message_type}") - await self.websocket.send_json({"messageType": message_type, "data": data}) - except RuntimeError as e: - logger.warning(f"Error sending IDE message, websocket probably closed: {e}") - - async def _receive_json( - self, message_type: str, timeout: int = 20, message=None - ) -> Any: - try: - return await asyncio.wait_for( - self.sub_queue.get(message_type), timeout=timeout - ) - except asyncio.TimeoutError: - raise ContinueCustomException( - title=f"IDE Protocol _receive_json timed out after 20 seconds: {message_type}", - message=f"IDE Protocol _receive_json timed out after 20 seconds. The message sent was: {message or ''}", - ) - - async def _send_and_receive_json( - self, data: Any, resp_model: Type[T], message_type: str - ) -> T: - await self._send_json(message_type, data) - resp = await self._receive_json(message_type, message=data) - return resp_model.parse_obj(resp) - - async def handle_json(self, message_type: str, data: Any): - if message_type == "getSessionId": - await self.getSessionId() - elif message_type == "setFileOpen": - await self.setFileOpen(data["filepath"], data["open"]) - elif message_type == "setSuggestionsLocked": - await self.setSuggestionsLocked(data["filepath"], data["locked"]) - elif message_type == "fileEdits": - fileEdits = list( - map(lambda d: FileEditWithFullContents.parse_obj(d), data["fileEdits"]) - ) - self.onFileEdits(fileEdits) - elif message_type == "highlightedCodePush": - self.onHighlightedCodeUpdate( - [RangeInFileWithContents(**rif) for rif in data["highlightedCode"]], - edit=data.get("edit", False), - ) - elif message_type == "commandOutput": - output = data["output"] - self.onCommandOutput(output) - elif message_type == "debugTerminal": - content = data["contents"] - self.onDebugTerminal(content) - elif message_type == "acceptRejectSuggestion": - self.onAcceptRejectSuggestion(data["accepted"]) - elif message_type == "acceptRejectDiff": - self.onAcceptRejectDiff(data["accepted"], data["stepIndex"]) - elif message_type == "mainUserInput": - self.onMainUserInput(data["input"]) - elif message_type == "deleteAtIndex": - self.onDeleteAtIndex(data["index"]) - elif message_type in [ - "highlightedCode", - "openFiles", - "visibleFiles", - "readFile", - "editFile", - "getUserSecret", - "runCommand", - "getTerminalContents", - "listDirectoryContents", - "fileExists", - ]: - self.sub_queue.post(message_type, data) - elif message_type == "workspaceDirectory": - self.workspace_directory = data["workspaceDirectory"] - elif message_type == "uniqueId": - self.unique_id = data["uniqueId"] - elif message_type == "ide": - self.ide_info = data - elif message_type == "filesCreated": - self.onFilesCreated(data["filepaths"]) - elif message_type == "filesDeleted": - self.onFilesDeleted(data["filepaths"]) - elif message_type == "filesRenamed": - self.onFilesRenamed(data["old_filepaths"], data["new_filepaths"]) - elif message_type == "fileSaved": - self.onFileSaved(data["filepath"], data["contents"]) - else: - raise ValueError("Unknown message type", message_type) - - async def showSuggestion(self, file_edit: FileEdit): - await self._send_json("showSuggestion", {"edit": file_edit.dict()}) - - async def showDiff(self, filepath: str, replacement: str, step_index: int): - await self._send_json( - "showDiff", - { - "filepath": filepath, - "replacement": replacement, - "step_index": step_index, - }, - ) - - async def setFileOpen(self, filepath: str, open: bool = True): - # Autopilot needs access to this. - await self._send_json("setFileOpen", {"filepath": filepath, "open": open}) - - async def showMessage(self, message: str): - await self._send_json("showMessage", {"message": message}) - - async def showVirtualFile(self, name: str, contents: str): - await self._send_json("showVirtualFile", {"name": name, "contents": contents}) - - async def setSuggestionsLocked(self, filepath: str, locked: bool = True): - # Lock suggestions in the file so they don't ruin the offset before others are inserted - await self._send_json( - "setSuggestionsLocked", {"filepath": filepath, "locked": locked} - ) - - async def getSessionId(self): - new_session = await asyncio.wait_for( - self.session_manager.new_session(self, self.session_id), timeout=5 - ) - session_id = new_session.session_id - logger.debug(f"Sending session id: {session_id}") - await self._send_json("getSessionId", {"sessionId": session_id}) - - async def highlightCode(self, range_in_file: RangeInFile, color: str = "#00ff0022"): - await self._send_json( - "highlightCode", {"rangeInFile": range_in_file.dict(), "color": color} - ) - - async def runCommand(self, command: str) -> str: - return ( - await self._send_and_receive_json( - {"command": command}, RunCommandResponse, "runCommand" - ) - ).output - - async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool: - ids = [str(uuid.uuid4()) for _ in suggestions] - for i in range(len(suggestions)): - self._send_json( - "showSuggestion", {"suggestion": suggestions[i], "suggestionId": ids[i]} - ) - responses = await asyncio.gather( - *[ - self._receive_json(ShowSuggestionResponse) - for i in range(len(suggestions)) - ] - ) # WORKING ON THIS FLOW HERE. Fine now to just await for response, instead of doing something fancy with a "waiting" state on the autopilot. - # Just need connect the suggestionId to the IDE (and the gui) - return any([r.accepted for r in responses]) - - def on_error(self, e: Exception) -> Coroutine: - err_msg = "\n".join(traceback.format_exception(e)) - e_title = e.__str__() or e.__repr__() - return self.showMessage(f"Error in Continue server: {e_title}\n {err_msg}") - - def onAcceptRejectSuggestion(self, accepted: bool): - posthog_logger.capture_event("accept_reject_suggestion", {"accepted": accepted}) - dev_data_logger.capture("accept_reject_suggestion", {"accepted": accepted}) - - def onAcceptRejectDiff(self, accepted: bool, step_index: int): - posthog_logger.capture_event("accept_reject_diff", {"accepted": accepted}) - dev_data_logger.capture("accept_reject_diff", {"accepted": accepted}) - - if not accepted: - if autopilot := self.__get_autopilot(): - create_async_task( - autopilot.reject_diff(step_index), - self.on_error, - ) - - def onFileSystemUpdate(self, update: FileSystemEdit): - # Access to Autopilot (so SessionManager) - pass - - def onCloseGUI(self, session_id: str): - # Access to SessionManager - pass - - def onOpenGUIRequest(self): - pass - - def __get_autopilot(self): - if self.session_id not in self.session_manager.sessions: - return None - - autopilot = self.session_manager.sessions[self.session_id].autopilot - return autopilot if autopilot.started else None - - def onFileEdits(self, edits: List[FileEditWithFullContents]): - if autopilot := self.__get_autopilot(): - pass - - def onDeleteAtIndex(self, index: int): - if autopilot := self.__get_autopilot(): - create_async_task(autopilot.delete_at_index(index), self.on_error) - - def onCommandOutput(self, output: str): - if autopilot := self.__get_autopilot(): - create_async_task(autopilot.handle_command_output(output), self.on_error) - - def onDebugTerminal(self, content: str): - if autopilot := self.__get_autopilot(): - create_async_task(autopilot.handle_debug_terminal(content), self.on_error) - - def onHighlightedCodeUpdate( - self, - range_in_files: List[RangeInFileWithContents], - edit: Optional[bool] = False, - ): - if autopilot := self.__get_autopilot(): - create_async_task( - autopilot.handle_highlighted_code(range_in_files, edit), self.on_error - ) - - ## Subscriptions ## - - _files_created_callbacks = [] - _files_deleted_callbacks = [] - _files_renamed_callbacks = [] - _file_saved_callbacks = [] - - def call_callback(self, callback, *args, **kwargs): - if asyncio.iscoroutinefunction(callback): - create_async_task(callback(*args, **kwargs), self.on_error) - else: - callback(*args, **kwargs) - - def subscribeToFilesCreated(self, callback: Callable[[List[str]], None]): - self._files_created_callbacks.append(callback) - - def subscribeToFilesDeleted(self, callback: Callable[[List[str]], None]): - self._files_deleted_callbacks.append(callback) - - def subscribeToFilesRenamed(self, callback: Callable[[List[str], List[str]], None]): - self._files_renamed_callbacks.append(callback) - - def subscribeToFileSaved(self, callback: Callable[[str, str], None]): - self._file_saved_callbacks.append(callback) - - def onFilesCreated(self, filepaths: List[str]): - for callback in self._files_created_callbacks: - self.call_callback(callback, filepaths) - - def onFilesDeleted(self, filepaths: List[str]): - for callback in self._files_deleted_callbacks: - self.call_callback(callback, filepaths) - - def onFilesRenamed(self, old_filepaths: List[str], new_filepaths: List[str]): - for callback in self._files_renamed_callbacks: - self.call_callback(callback, old_filepaths, new_filepaths) - - def onFileSaved(self, filepath: str, contents: str): - for callback in self._file_saved_callbacks: - self.call_callback(callback, filepath, contents) - - ## END Subscriptions ## - - def onMainUserInput(self, input: str): - if autopilot := self.__get_autopilot(): - create_async_task(autopilot.accept_user_input(input), self.on_error) - - # Request information. Session doesn't matter. - async def getOpenFiles(self) -> List[str]: - resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles") - return resp.openFiles - - async def getVisibleFiles(self) -> List[str]: - resp = await self._send_and_receive_json( - {}, VisibleFilesResponse, "visibleFiles" - ) - return resp.visibleFiles - - async def getTerminalContents(self, commands: int = -1) -> str: - """Get the contents of the terminal, up to the last 'commands' commands, or all if commands is -1""" - resp = await self._send_and_receive_json( - {"commands": commands}, TerminalContentsResponse, "getTerminalContents" - ) - return resp.contents.strip() - - async def getHighlightedCode(self) -> List[RangeInFile]: - resp = await self._send_and_receive_json( - {}, HighlightedCodeResponse, "highlightedCode" - ) - return resp.highlightedCode - - async def readFile(self, filepath: str) -> str: - """Read a file""" - resp = await self._send_and_receive_json( - {"filepath": filepath}, ReadFileResponse, "readFile" - ) - return resp.contents - - async def fileExists(self, filepath: str) -> str: - """Check whether file exists""" - resp = await self._send_and_receive_json( - {"filepath": filepath}, FileExistsResponse, "fileExists" - ) - return resp.exists - - async def getUserSecret(self, key: str) -> str: - """Get a user secret""" - try: - resp = await self._send_and_receive_json( - {"key": key}, GetUserSecretResponse, "getUserSecret" - ) - return resp.value - except Exception as e: - logger.debug(f"Error getting user secret: {e}") - return "" - - async def saveFile(self, filepath: str): - """Save a file""" - await self._send_json("saveFile", {"filepath": filepath}) - - async def readRangeInFile(self, range_in_file: RangeInFile) -> str: - """Read a range in a file""" - full_contents = await self.readFile(range_in_file.filepath) - return FileSystem.read_range_in_str(full_contents, range_in_file.range) - - async def editFile(self, edit: FileEdit) -> FileEditWithFullContents: - """Edit a file""" - resp = await self._send_and_receive_json( - {"edit": edit.dict()}, EditFileResponse, "editFile" - ) - return resp.fileEdit - - async def listDirectoryContents( - self, directory: str, recursive: bool = False - ) -> List[str]: - """List the contents of a directory""" - resp = await self._send_and_receive_json( - {"directory": directory, "recursive": recursive}, - ListDirectoryContentsResponse, - "listDirectoryContents", - ) - return resp.contents - - async def applyFileSystemEdit(self, edit: FileSystemEdit) -> EditDiff: - """Apply a file edit""" - backward = None - fs = RealFileSystem() - if isinstance(edit, FileEdit): - file_edit = await self.editFile(edit) - _, diff = FileSystem.apply_edit_to_str( - file_edit.fileContents, file_edit.fileEdit - ) - backward = diff.backward - elif isinstance(edit, AddFile): - fs.write(edit.filepath, edit.content) - backward = DeleteFile(filepath=edit.filepath) - elif isinstance(edit, DeleteFile): - contents = await self.readFile(edit.filepath) - backward = AddFile(filepath=edit.filepath, content=contents) - fs.delete_file(edit.filepath) - elif isinstance(edit, RenameFile): - fs.rename_file(edit.filepath, edit.new_filepath) - backward = RenameFile( - filepath=edit.new_filepath, new_filepath=edit.filepath - ) - elif isinstance(edit, AddDirectory): - fs.add_directory(edit.path) - backward = DeleteDirectory(path=edit.path) - elif isinstance(edit, DeleteDirectory): - # This isn't atomic! - backward_edits = [] - for root, dirs, files in os.walk(edit.path, topdown=False): - for f in files: - path = os.path.join(root, f) - edit_diff = await self.applyFileSystemEdit( - DeleteFile(filepath=path) - ) - backward_edits.append(edit_diff) - for d in dirs: - path = os.path.join(root, d) - edit_diff = await self.applyFileSystemEdit( - DeleteDirectory(path=path) - ) - backward_edits.append(edit_diff) - - edit_diff = await self.applyFileSystemEdit(DeleteDirectory(path=edit.path)) - backward_edits.append(edit_diff) - backward_edits.reverse() - backward = SequentialFileSystemEdit(edits=backward_edits) - elif isinstance(edit, RenameDirectory): - fs.rename_directory(edit.path, edit.new_path) - backward = RenameDirectory(path=edit.new_path, new_path=edit.path) - elif isinstance(edit, FileSystemEdit): - diffs = [] - for edit in edit.next_edit(): - edit_diff = await self.applyFileSystemEdit(edit) - diffs.append(edit_diff) - backward = EditDiff.from_sequence(diffs=diffs).backward - else: - raise TypeError("Unknown FileSystemEdit type: " + str(type(edit))) - - return EditDiff(forward=edit, backward=backward) - - -@router.websocket("/ws") -async def websocket_endpoint(websocket: WebSocket, session_id: str = None): - try: - # Accept the websocket connection - await websocket.accept() - logger.debug(f"Accepted websocket connection from {websocket.client}") - await websocket.send_json({"messageType": "connected", "data": {}}) - - # Message handler - def handle_msg(msg): - try: - message = json.loads(msg) - except json.JSONDecodeError: - logger.critical(f"Error decoding json: {msg}") - return - - if "messageType" not in message or "data" not in message: - return - message_type = message["messageType"] - data = message["data"] - - # logger.debug(f"Received IDE message: {message_type}") - create_async_task( - ideProtocolServer.handle_json(message_type, data), - ideProtocolServer.on_error, - ) - - # Initialize the IDE Protocol Server - ideProtocolServer = IdeProtocolServer(session_manager, websocket) - if session_id is not None: - session_manager.registered_ides[session_id] = ideProtocolServer - other_msgs = await ideProtocolServer.initialize(session_id) - posthog_logger.capture_event( - "session_started", {"session_id": ideProtocolServer.session_id} - ) - - for other_msg in other_msgs: - handle_msg(other_msg) - - # Handle messages - while AppStatus.should_exit is False: - message = await websocket.receive_text() - handle_msg(message) - - except WebSocketDisconnect: - logger.debug("IDE websocket disconnected") - except Exception as e: - logger.debug(f"Error in ide websocket: {e}") - err_msg = "\n".join(traceback.format_exception(e)) - posthog_logger.capture_event( - "gui_error", - {"error_title": e.__str__() or e.__repr__(), "error_message": err_msg}, - ) - - if session_id is not None and session_id in session_manager.sessions: - await session_manager.sessions[session_id].autopilot.continue_sdk.run_step( - DisplayErrorStep.from_exception(e) - ) - elif ideProtocolServer is not None: - await ideProtocolServer.showMessage(f"Error in Continue server: {err_msg}") - - raise e - finally: - logger.debug("Closing ide websocket") - if websocket.client_state != WebSocketState.DISCONNECTED: - await websocket.close() - - posthog_logger.capture_event( - "session_ended", {"session_id": ideProtocolServer.session_id} - ) - if ideProtocolServer.session_id in session_manager.registered_ides: - session_manager.registered_ides.pop(ideProtocolServer.session_id) diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py deleted file mode 100644 index 832dd338..00000000 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ /dev/null @@ -1,170 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Union - -from fastapi import WebSocket - -from ..models.filesystem import RangeInFile, RangeInFileWithContents -from ..models.filesystem_edit import EditDiff, FileEdit, FileSystemEdit - - -class AbstractIdeProtocolServer(ABC): - websocket: WebSocket - session_id: Union[str, None] - ide_info: Optional[Dict] = None - - @abstractmethod - async def handle_json(self, data: Any): - """Handle a json message""" - - @abstractmethod - def showSuggestion(self, file_edit: FileEdit): - """Show a suggestion to the user""" - - @abstractmethod - async def setFileOpen(self, filepath: str, open: bool = True): - """Set whether a file is open""" - - @abstractmethod - async def showMessage(self, message: str): - """Show a message to the user""" - - @abstractmethod - async def showVirtualFile(self, name: str, contents: str): - """Show a virtual file""" - - @abstractmethod - async def setSuggestionsLocked(self, filepath: str, locked: bool = True): - """Set whether suggestions are locked""" - - @abstractmethod - async def getSessionId(self): - """Get a new session ID""" - - @abstractmethod - async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool: - """Show suggestions to the user and wait for a response""" - - @abstractmethod - def onAcceptRejectSuggestion(self, accepted: bool): - """Called when the user accepts or rejects a suggestion""" - - @abstractmethod - def onFileSystemUpdate(self, update: FileSystemEdit): - """Called when a file system update is received""" - - @abstractmethod - def onCloseGUI(self, session_id: str): - """Called when a GUI is closed""" - - @abstractmethod - def onOpenGUIRequest(self): - """Called when a GUI is requested to be opened""" - - @abstractmethod - async def getOpenFiles(self) -> List[str]: - """Get a list of open files""" - - @abstractmethod - async def getVisibleFiles(self) -> List[str]: - """Get a list of visible files""" - - @abstractmethod - async def getHighlightedCode(self) -> List[RangeInFile]: - """Get a list of highlighted code""" - - @abstractmethod - async def readFile(self, filepath: str) -> str: - """Read a file""" - - @abstractmethod - async def readRangeInFile(self, range_in_file: RangeInFile) -> str: - """Read a range in a file""" - - @abstractmethod - async def editFile(self, edit: FileEdit): - """Edit a file""" - - @abstractmethod - async def applyFileSystemEdit(self, edit: FileSystemEdit) -> EditDiff: - """Apply a file edit""" - - @abstractmethod - async def saveFile(self, filepath: str): - """Save a file""" - - @abstractmethod - async def getUserSecret(self, key: str): - """Get a user secret""" - - @abstractmethod - async def highlightCode(self, range_in_file: RangeInFile, color: str): - """Highlight code""" - - @abstractmethod - async def runCommand(self, command: str) -> str: - """Run a command""" - - @abstractmethod - def onHighlightedCodeUpdate( - self, - range_in_files: List[RangeInFileWithContents], - edit: Optional[bool] = False, - ): - """Called when highlighted code is updated""" - - @abstractmethod - def onDeleteAtIndex(self, index: int): - """Called when a step is deleted at a given index""" - - @abstractmethod - async def showDiff(self, filepath: str, replacement: str, step_index: int): - """Show a diff""" - - @abstractmethod - def subscribeToFilesCreated(self, callback: Callable[[List[str]], None]): - """Subscribe to files created event""" - - @abstractmethod - def subscribeToFilesDeleted(self, callback: Callable[[List[str]], None]): - """Subscribe to files deleted event""" - - @abstractmethod - def subscribeToFilesRenamed(self, callback: Callable[[List[str], List[str]], None]): - """Subscribe to files renamed event""" - - @abstractmethod - def subscribeToFileSaved(self, callback: Callable[[str, str], None]): - """Subscribe to file saved event""" - - @abstractmethod - def onFilesCreated(self, filepaths: List[str]): - """Called when files are created""" - - @abstractmethod - def onFilesDeleted(self, filepaths: List[str]): - """Called when files are deleted""" - - @abstractmethod - def onFilesRenamed(self, old_filepaths: List[str], new_filepaths: List[str]): - """Called when files are renamed""" - - @abstractmethod - def onFileSaved(self, filepath: str, contents: str): - """Called when a file is saved""" - - @abstractmethod - async def listDirectoryContents( - self, directory: str, recursive: bool = False - ) -> List[str]: - """List directory contents""" - - @abstractmethod - async def fileExists(self, filepath: str) -> str: - """Check if a file exists""" - - @abstractmethod - async def getTerminalContents(self, commands: int = -1) -> str: - """Get the terminal contents""" - - workspace_directory: str - unique_id: str diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py deleted file mode 100644 index c5540d7d..00000000 --- a/continuedev/src/continuedev/server/main.py +++ /dev/null @@ -1,109 +0,0 @@ -import argparse -import asyncio -import atexit -from contextlib import asynccontextmanager -from typing import Optional - -import uvicorn -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware - -from ..libs.util.create_async_task import create_async_task -from ..libs.util.logging import logger -from .gui import router as gui_router -from .ide import router as ide_router -from .meilisearch_server import start_meilisearch, stop_meilisearch -from .session_manager import router as sessions_router -from .session_manager import session_manager - -meilisearch_url_global = None - - -@asynccontextmanager -async def lifespan(app: FastAPI): - async def on_err(e): - logger.warning(f"Error starting MeiliSearch: {e}") - - try: - # start meilisearch without blocking server startup - create_async_task(start_meilisearch(url=meilisearch_url_global), on_err) - except Exception as e: - logger.warning(f"Error starting MeiliSearch: {e}") - - yield - stop_meilisearch() - - -app = FastAPI(lifespan=lifespan) - -app.include_router(ide_router) -app.include_router(gui_router) -app.include_router(sessions_router) - -# Add CORS support -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -@app.get("/health") -def health(): - logger.debug("Health check") - return {"status": "ok"} - - -def run_server( - port: int = 65432, host: str = "127.0.0.1", meilisearch_url: Optional[str] = None -): - try: - global meilisearch_url_global - - meilisearch_url_global = meilisearch_url - - config = uvicorn.Config(app, host=host, port=port) - server = uvicorn.Server(config) - server.run() - except PermissionError as e: - logger.critical( - f"Error starting Continue server: {e}. " - f"This means that port {port} is already in use, and is usually caused by another instance of the Continue server already running." - ) - cleanup() - raise e - - except Exception as e: - logger.critical(f"Error starting Continue server: {e}") - cleanup() - raise e - - -async def cleanup_coroutine(): - logger.debug("------ Cleaning Up ------") - for session_id in session_manager.sessions: - await session_manager.persist_session(session_id) - - -def cleanup(): - loop = asyncio.new_event_loop() - loop.run_until_complete(cleanup_coroutine()) - loop.close() - - -atexit.register(cleanup) - -if __name__ == "__main__": - try: - # add cli arg for server port - parser = argparse.ArgumentParser() - parser.add_argument("-p", "--port", help="server port", type=int, default=65432) - parser.add_argument("--host", help="server host", type=str, default="127.0.0.1") - args = parser.parse_args() - except Exception as e: - logger.critical(f"Error parsing command line arguments: {e}") - raise e - - run_server(args.port, args.host) diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py deleted file mode 100644 index 6ce4d61c..00000000 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ /dev/null @@ -1,195 +0,0 @@ -import asyncio -import os -import shutil -import subprocess -from typing import Optional - -import aiofiles -import aiohttp -import psutil -from meilisearch_python_async import Client - -from ..libs.util.logging import logger -from ..libs.util.paths import getMeilisearchExePath, getServerFolderPath - - -async def download_file(url: str, filename: str): - async with aiohttp.ClientSession() as session: - async with session.get(url) as resp: - if resp.status == 200: - f = await aiofiles.open(filename, mode="wb") - await f.write(await resp.read()) - await f.close() - - -async def download_meilisearch(): - """ - Downloads MeiliSearch. - """ - - serverPath = getServerFolderPath() - logger.debug("Downloading MeiliSearch...") - - if os.name == "nt": - download_url = "https://github.com/meilisearch/meilisearch/releases/download/v1.3.2/meilisearch-windows-amd64.exe" - download_path = getMeilisearchExePath() - if not os.path.exists(download_path): - await download_file(download_url, download_path) - # subprocess.run( - # f"curl -L {download_url} -o {download_path}", - # shell=True, - # check=True, - # cwd=serverPath, - # ) - else: - subprocess.run( - "curl -L https://install.meilisearch.com | sh", - shell=True, - check=True, - cwd=serverPath, - ) - - -async def ensure_meilisearch_installed() -> bool: - """ - Checks if MeiliSearch is installed. - - Returns a bool indicating whether it was installed to begin with. - """ - serverPath = getServerFolderPath() - meilisearchPath = getMeilisearchExePath() - dumpsPath = os.path.join(serverPath, "dumps") - dataMsPath = os.path.join(serverPath, "data.ms") - - paths = [meilisearchPath, dumpsPath, dataMsPath] - - existing_paths = set() - non_existing_paths = set() - for path in paths: - if os.path.exists(path): - existing_paths.add(path) - else: - non_existing_paths.add(path) - - if len(non_existing_paths) > 0: - # Clear the meilisearch binary - if meilisearchPath in existing_paths: - try: - os.remove(meilisearchPath) - except: - pass - existing_paths.remove(meilisearchPath) - - await download_meilisearch() - - # Clear the existing directories - for p in existing_paths: - shutil.rmtree(p, ignore_errors=True) - - return False - - return True - - -meilisearch_process = None -DEFAULT_MEILISEARCH_URL = "http://localhost:7700" -meilisearch_url = DEFAULT_MEILISEARCH_URL - - -def get_meilisearch_url(): - return meilisearch_url - - -async def check_meilisearch_running() -> bool: - """ - Checks if MeiliSearch is running. - """ - - try: - async with Client(meilisearch_url) as client: - try: - resp = await client.health() - if resp.status != "available": - return False - return True - except Exception: - return False - except Exception: - return False - - -async def poll_meilisearch_running(frequency: int = 0.1) -> bool: - """ - Polls MeiliSearch to see if it is running. - """ - while True: - if await check_meilisearch_running(): - return True - await asyncio.sleep(frequency) - - -async def start_meilisearch(url: Optional[str] = None): - """ - Starts the MeiliSearch server, wait for it. - """ - global meilisearch_process, meilisearch_url - - if url is not None: - logger.debug("Using MeiliSearch at URL: " + url) - meilisearch_url = url - return - - serverPath = getServerFolderPath() - - # Check if MeiliSearch is installed, if not download - was_already_installed = await ensure_meilisearch_installed() - - # Check if MeiliSearch is running - if not await check_meilisearch_running() or not was_already_installed: - logger.debug("Starting MeiliSearch...") - binary_name = "meilisearch" if os.name == "nt" else "./meilisearch" - meilisearch_process = subprocess.Popen( - [binary_name, "--no-analytics"], - cwd=serverPath, - stdout=subprocess.DEVNULL, - stderr=subprocess.STDOUT, - close_fds=True, - start_new_session=True, - shell=True, - ) - - logger.debug("Meilisearch started") - - -def stop_meilisearch(): - """ - Stops the MeiliSearch server. - """ - global meilisearch_process - if meilisearch_process is not None: - meilisearch_process.terminate() - meilisearch_process.wait() - meilisearch_process = None - - -def kill_proc(port): - for proc in psutil.process_iter(): - try: - for conns in proc.connections(kind="inet"): - if conns.laddr.port == port: - proc.send_signal(psutil.signal.SIGTERM) # or SIGKILL - except psutil.AccessDenied: - logger.warning(f"Failed to kill process on port {port} (access denied)") - return - except psutil.ZombieProcess: - logger.warning(f"Failed to kill process on port {port} (zombie process)") - return - except psutil.NoSuchProcess: - logger.warning(f"Failed to kill process on port {port} (no such process)") - return - - -async def restart_meilisearch(): - stop_meilisearch() - kill_proc(7700) - await start_meilisearch(url=meilisearch_url) diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py deleted file mode 100644 index f0080104..00000000 --- a/continuedev/src/continuedev/server/session_manager.py +++ /dev/null @@ -1,192 +0,0 @@ -import json -import os -import traceback -from typing import Any, Coroutine, Dict, Optional, Union -from uuid import uuid4 - -from fastapi import APIRouter, WebSocket -from fastapi.websockets import WebSocketState - -from ..core.autopilot import Autopilot -from ..core.config import ContinueConfig -from ..core.main import FullState -from ..libs.util.create_async_task import create_async_task -from ..libs.util.logging import logger -from ..libs.util.paths import ( - getSessionFilePath, - getSessionsFolderPath, - getSessionsListFilePath, -) -from .ide_protocol import AbstractIdeProtocolServer - -router = APIRouter(prefix="/sessions", tags=["sessions"]) - - -class Session: - session_id: str - autopilot: Autopilot - # The GUI websocket for the session - ws: Union[WebSocket, None] - - def __init__(self, session_id: str, autopilot: Autopilot): - self.session_id = session_id - self.autopilot = autopilot - self.ws = None - - -class SessionManager: - sessions: Dict[str, Session] = {} - # Mapping of session_id to IDE, where the IDE is still alive - registered_ides: Dict[str, AbstractIdeProtocolServer] = {} - - async def get_session(self, session_id: str) -> Session: - if session_id not in self.sessions: - # Check then whether it is persisted by listing all files in the sessions folder - # And only if the IDE is still alive - sessions_folder = getSessionsFolderPath() - session_files = os.listdir(sessions_folder) - if ( - f"{session_id}.json" in session_files - and session_id in self.registered_ides - ): - if self.registered_ides[session_id].session_id is not None: - return await self.new_session( - self.registered_ides[session_id], session_id=session_id - ) - - raise KeyError("Session ID not recognized", session_id) - return self.sessions[session_id] - - async def new_session( - self, - ide: AbstractIdeProtocolServer, - session_id: Optional[str] = None, - config: Optional[ContinueConfig] = None, - ) -> Session: - logger.debug(f"New session: {session_id}") - - # Load the persisted state (not being used right now) - full_state = None - if session_id is not None and os.path.exists(getSessionFilePath(session_id)): - with open(getSessionFilePath(session_id), "r") as f: - full_state = FullState(**json.load(f)) - - # Register the session and ide (do this first so that the autopilot can access the session) - autopilot = Autopilot(ide=ide) - session_id = session_id or str(uuid4()) - ide.session_id = session_id - session = Session(session_id=session_id, autopilot=autopilot) - self.sessions[session_id] = session - self.registered_ides[session_id] = ide - - # Set up the autopilot to update the GUI - async def on_update(state: FullState): - await session_manager.send_ws_data( - session_id, "state_update", {"state": state.dict()} - ) - - autopilot.on_update(on_update) - - # Start the autopilot (must be after session is added to sessions) and the policy - try: - await autopilot.start(full_state=full_state, config=config) - except Exception as e: - await ide.on_error(e) - - def on_error(e: Exception) -> Coroutine: - err_msg = "\n".join(traceback.format_exception(e)) - return ide.showMessage(f"Error in Continue server: {err_msg}") - - create_async_task(autopilot.run_policy(), on_error) - return session - - async def remove_session(self, session_id: str): - logger.debug(f"Removing session: {session_id}") - if session_id in self.sessions: - if ( - session_id in self.registered_ides - and self.registered_ides[session_id] is not None - ): - ws_to_close = self.registered_ides[session_id].websocket - if ( - ws_to_close is not None - and ws_to_close.client_state != WebSocketState.DISCONNECTED - ): - await self.sessions[session_id].autopilot.ide.websocket.close() - - del self.sessions[session_id] - - async def persist_session(self, session_id: str): - """Save the session's FullState as a json file""" - full_state = await self.sessions[session_id].autopilot.get_full_state() - if full_state.session_info is None: - return - - with open(getSessionFilePath(session_id), "w") as f: - json.dump(full_state.dict(), f) - - # Read and update the sessions list - with open(getSessionsListFilePath(), "r") as f: - try: - sessions_list = json.load(f) - except json.JSONDecodeError: - raise Exception( - f"It looks like there is a JSON formatting error in your sessions.json file ({getSessionsListFilePath()}). Please fix this before creating a new session." - ) - - session_ids = [s["session_id"] for s in sessions_list] - if session_id not in session_ids: - sessions_list.append(full_state.session_info.dict()) - - for session_info in sessions_list: - if "workspace_directory" not in session_info: - session_info["workspace_directory"] = "" - - with open(getSessionsListFilePath(), "w") as f: - json.dump(sessions_list, f) - - async def load_session( - self, old_session_id: str, new_session_id: Optional[str] = None - ) -> str: - """Load the session's FullState from a json file""" - - # First persist the current state - await self.persist_session(old_session_id) - - # Delete the old session, but keep the IDE - ide = self.registered_ides[old_session_id] - del self.registered_ides[old_session_id] - - # Start the new session - new_session = await self.new_session(ide, session_id=new_session_id) - return new_session.session_id - - def register_websocket(self, session_id: str, ws: WebSocket): - self.sessions[session_id].ws = ws - logger.debug(f"Registered websocket for session {session_id}") - - async def send_ws_data(self, session_id: str, message_type: str, data: Any): - if session_id not in self.sessions: - logger.warning(f"Session {session_id} not found") - return - if self.sessions[session_id].ws is None: - return - - await self.sessions[session_id].ws.send_json( - {"messageType": message_type, "data": data} - ) - - -session_manager = SessionManager() - - -@router.get("/list") -async def list_sessions(): - """List all sessions""" - sessions_list_file = getSessionsListFilePath() - if not os.path.exists(sessions_list_file): - print("Returning empty sessions list") - return [] - sessions = json.load(open(sessions_list_file, "r")) - print("Returning sessions list: ", sessions) - return sessions |