summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/server
diff options
context:
space:
mode:
authorNate Sesti <33237525+sestinj@users.noreply.github.com>2023-10-09 18:37:27 -0700
committerGitHub <noreply@github.com>2023-10-09 18:37:27 -0700
commitf09150617ed2454f3074bcf93f53aae5ae637d40 (patch)
tree5cfe614a64d921dfe58b049f426d67a8b832c71f /continuedev/src/continuedev/server
parent985304a213f620cdff3f8f65f74ed7e3b79be29d (diff)
downloadsncontinue-f09150617ed2454f3074bcf93f53aae5ae637d40.tar.gz
sncontinue-f09150617ed2454f3074bcf93f53aae5ae637d40.tar.bz2
sncontinue-f09150617ed2454f3074bcf93f53aae5ae637d40.zip
Preview (#541)
* Strong typing (#533) * refactor: :recycle: get rid of continuedev.src.continuedev structure * refactor: :recycle: switching back to server folder * feat: :sparkles: make config.py imports shorter * feat: :bookmark: publish as pre-release vscode extension * refactor: :recycle: refactor and add more completion params to ui * build: :building_construction: download from preview S3 * fix: :bug: fix paths * fix: :green_heart: package:pre-release * ci: :green_heart: more time for tests * fix: :green_heart: fix build scripts * fix: :bug: fix import in run.py * fix: :bookmark: update version to try again * ci: 💚 Update package.json version [skip ci] * refactor: :fire: don't check for old extensions version * fix: :bug: small bug fixes * fix: :bug: fix config.py import paths * ci: 💚 Update package.json version [skip ci] * ci: :green_heart: platform-specific builds test #1 * feat: :green_heart: ship with binary * fix: :green_heart: fix copy statement to include.exe for windows * fix: :green_heart: cd extension before packaging * chore: :loud_sound: count tokens generated * fix: :green_heart: remove npm_config_arch * fix: :green_heart: publish as pre-release! * chore: :bookmark: update version * perf: :green_heart: hardcode distro paths * fix: :bug: fix yaml syntax error * chore: :bookmark: update version * fix: :green_heart: update permissions and version * feat: :bug: kill old server if needed * feat: :lipstick: update marketplace icon for pre-release * ci: 💚 Update package.json version [skip ci] * feat: :sparkles: auto-reload for config.py * feat: :wrench: update default config.py imports * feat: :sparkles: codelens in config.py * feat: :sparkles: select model param count from UI * ci: 💚 Update package.json version [skip ci] * feat: :sparkles: more model options, ollama error handling * perf: :zap: don't show server loading immediately * fix: :bug: fixing small UI details * ci: 💚 Update package.json version [skip ci] * feat: :rocket: headers param on LLM class * fix: :bug: fix headers for openai.;y * feat: :sparkles: highlight code on cmd+shift+L * ci: 💚 Update package.json version [skip ci] * feat: :lipstick: sticky top bar in gui.tsx * fix: :loud_sound: websocket logging and horizontal scrollbar * ci: 💚 Update package.json version [skip ci] * feat: :sparkles: allow AzureOpenAI Service through GGML * ci: 💚 Update package.json version [skip ci] * fix: :bug: fix automigration * ci: 💚 Update package.json version [skip ci] * ci: :green_heart: upload binaries in ci, download apple silicon * chore: :fire: remove notes * fix: :green_heart: use curl to download binary * fix: :green_heart: set permissions on apple silicon binary * fix: :green_heart: testing * fix: :green_heart: cleanup file * fix: :green_heart: fix preview.yaml * fix: :green_heart: only upload once per binary * fix: :green_heart: install rosetta * ci: :green_heart: download binary after tests * ci: 💚 Update package.json version [skip ci] * ci: :green_heart: prepare ci for merge to main --------- Co-authored-by: GitHub Action <action@github.com>
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r--continuedev/src/continuedev/server/gui.py453
-rw-r--r--continuedev/src/continuedev/server/ide.py673
-rw-r--r--continuedev/src/continuedev/server/ide_protocol.py170
-rw-r--r--continuedev/src/continuedev/server/main.py109
-rw-r--r--continuedev/src/continuedev/server/meilisearch_server.py195
-rw-r--r--continuedev/src/continuedev/server/session_manager.py192
6 files changed, 0 insertions, 1792 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
deleted file mode 100644
index 26fcbd42..00000000
--- a/continuedev/src/continuedev/server/gui.py
+++ /dev/null
@@ -1,453 +0,0 @@
-import asyncio
-import json
-import traceback
-from typing import Any, List, Optional, Type, TypeVar
-
-from fastapi import APIRouter, Depends, WebSocket
-from pydantic import BaseModel
-from starlette.websockets import WebSocketDisconnect, WebSocketState
-from uvicorn.main import Server
-
-from ..core.main import ContextItem
-from ..core.models import ALL_MODEL_ROLES, MODEL_CLASSES, MODEL_MODULE_NAMES
-from ..libs.llm.prompts.chat import llama2_template_messages, template_alpaca_messages
-from ..libs.util.create_async_task import create_async_task
-from ..libs.util.edit_config import (
- add_config_import,
- create_float_node,
- create_obj_node,
- create_string_node,
- display_llm_class,
-)
-from ..libs.util.logging import logger
-from ..libs.util.queue import AsyncSubscriptionQueue
-from ..libs.util.telemetry import posthog_logger
-from ..plugins.steps.core.core import DisplayErrorStep
-from ..plugins.steps.setup_model import SetupModelStep
-from .session_manager import Session, session_manager
-
-router = APIRouter(prefix="/gui", tags=["gui"])
-
-# Graceful shutdown by closing websockets
-original_handler = Server.handle_exit
-
-
-class AppStatus:
- should_exit = False
-
- @staticmethod
- def handle_exit(*args, **kwargs):
- AppStatus.should_exit = True
- logger.debug("Shutting down")
- original_handler(*args, **kwargs)
-
-
-Server.handle_exit = AppStatus.handle_exit
-
-
-async def websocket_session(session_id: str) -> Session:
- return await session_manager.get_session(session_id)
-
-
-T = TypeVar("T", bound=BaseModel)
-
-# You should probably abstract away the websocket stuff into a separate class
-
-
-class GUIProtocolServer:
- websocket: WebSocket
- session: Session
- sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue()
-
- def __init__(self, session: Session):
- self.session = session
-
- async def _send_json(self, message_type: str, data: Any):
- if self.websocket.application_state == WebSocketState.DISCONNECTED:
- return
- await self.websocket.send_json({"messageType": message_type, "data": data})
-
- async def _receive_json(self, message_type: str, timeout: int = 20) -> Any:
- try:
- return await asyncio.wait_for(
- self.sub_queue.get(message_type), timeout=timeout
- )
- except asyncio.TimeoutError:
- raise Exception("GUI Protocol _receive_json timed out after 20 seconds")
-
- async def _send_and_receive_json(
- self, data: Any, resp_model: Type[T], message_type: str
- ) -> T:
- await self._send_json(message_type, data)
- resp = await self._receive_json(message_type)
- return resp_model.parse_obj(resp)
-
- def on_error(self, e: Exception):
- return self.session.autopilot.continue_sdk.run_step(
- DisplayErrorStep.from_exception(e)
- )
-
- def handle_json(self, message_type: str, data: Any):
- if message_type == "main_input":
- self.on_main_input(data["input"])
- elif message_type == "step_user_input":
- self.on_step_user_input(data["input"], data["index"])
- elif message_type == "refinement_input":
- self.on_refinement_input(data["input"], data["index"])
- elif message_type == "reverse_to_index":
- self.on_reverse_to_index(data["index"])
- elif message_type == "retry_at_index":
- self.on_retry_at_index(data["index"])
- elif message_type == "clear_history":
- self.on_clear_history()
- elif message_type == "set_current_session_title":
- self.set_current_session_title(data["title"])
- elif message_type == "delete_at_index":
- self.on_delete_at_index(data["index"])
- elif message_type == "delete_context_with_ids":
- self.on_delete_context_with_ids(data["ids"], data.get("index", None))
- elif message_type == "toggle_adding_highlighted_code":
- self.on_toggle_adding_highlighted_code()
- elif message_type == "set_editing_at_ids":
- self.on_set_editing_at_ids(data["ids"])
- elif message_type == "show_logs_at_index":
- self.on_show_logs_at_index(data["index"])
- elif message_type == "show_context_virtual_file":
- self.show_context_virtual_file(data.get("index", None))
- elif message_type == "select_context_item":
- self.select_context_item(data["id"], data["query"])
- elif message_type == "select_context_item_at_index":
- self.select_context_item_at_index(data["id"], data["query"], data["index"])
- elif message_type == "load_session":
- self.load_session(data.get("session_id", None))
- elif message_type == "edit_step_at_index":
- self.edit_step_at_index(data.get("user_input", ""), data["index"])
- elif message_type == "set_system_message":
- self.set_system_message(data["message"])
- elif message_type == "set_temperature":
- self.set_temperature(float(data["temperature"]))
- elif message_type == "add_model_for_role":
- self.add_model_for_role(data["role"], data["model_class"], data["model"])
- elif message_type == "set_model_for_role_from_index":
- self.set_model_for_role_from_index(data["role"], data["index"])
- elif message_type == "save_context_group":
- self.save_context_group(
- data["title"], [ContextItem(**item) for item in data["context_items"]]
- )
- elif message_type == "select_context_group":
- self.select_context_group(data["id"])
- elif message_type == "delete_context_group":
- self.delete_context_group(data["id"])
-
- def on_main_input(self, input: str):
- # Do something with user input
- create_async_task(
- self.session.autopilot.accept_user_input(input), self.on_error
- )
-
- def on_reverse_to_index(self, index: int):
- # Reverse the history to the given index
- create_async_task(self.session.autopilot.reverse_to_index(index), self.on_error)
-
- def on_step_user_input(self, input: str, index: int):
- create_async_task(
- self.session.autopilot.give_user_input(input, index), self.on_error
- )
-
- def on_refinement_input(self, input: str, index: int):
- create_async_task(
- self.session.autopilot.accept_refinement_input(input, index), self.on_error
- )
-
- def on_retry_at_index(self, index: int):
- create_async_task(self.session.autopilot.retry_at_index(index), self.on_error)
-
- def on_clear_history(self):
- create_async_task(self.session.autopilot.clear_history(), self.on_error)
-
- def on_delete_at_index(self, index: int):
- create_async_task(self.session.autopilot.delete_at_index(index), self.on_error)
-
- def edit_step_at_index(self, user_input: str, index: int):
- create_async_task(
- self.session.autopilot.edit_step_at_index(user_input, index),
- self.on_error,
- )
-
- def on_delete_context_with_ids(self, ids: List[str], index: Optional[int] = None):
- create_async_task(
- self.session.autopilot.delete_context_with_ids(ids, index), self.on_error
- )
-
- def on_toggle_adding_highlighted_code(self):
- create_async_task(
- self.session.autopilot.toggle_adding_highlighted_code(), self.on_error
- )
- posthog_logger.capture_event("toggle_adding_highlighted_code", {})
-
- def on_set_editing_at_ids(self, ids: List[str]):
- create_async_task(self.session.autopilot.set_editing_at_ids(ids), self.on_error)
-
- def on_show_logs_at_index(self, index: int):
- name = "Continue Context"
- logs = "\n\n############################################\n\n".join(
- ["This is the prompt that was sent to the LLM during this step"]
- + self.session.autopilot.continue_sdk.history.timeline[index].logs
- )
- create_async_task(
- self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error
- )
- posthog_logger.capture_event("show_logs_at_index", {})
-
- def show_context_virtual_file(self, index: Optional[int] = None):
- async def async_stuff():
- if index is None:
- context_items = (
- await self.session.autopilot.context_manager.get_selected_items()
- )
- elif index < len(self.session.autopilot.continue_sdk.history.timeline):
- context_items = self.session.autopilot.continue_sdk.history.timeline[
- index
- ].context_used
-
- ctx = "\n\n-----------------------------------\n\n".join(
- ["These are the context items that will be passed to the LLM"]
- + list(map(lambda x: x.content, context_items))
- )
- await self.session.autopilot.ide.showVirtualFile(
- "Continue - Selected Context", ctx
- )
-
- create_async_task(
- async_stuff(),
- self.on_error,
- )
-
- def select_context_item(self, id: str, query: str):
- """Called when user selects an item from the dropdown"""
- create_async_task(
- self.session.autopilot.select_context_item(id, query), self.on_error
- )
-
- def select_context_item_at_index(self, id: str, query: str, index: int):
- """Called when user selects an item from the dropdown for prev UserInputStep"""
- create_async_task(
- self.session.autopilot.select_context_item_at_index(id, query, index),
- self.on_error,
- )
-
- def load_session(self, session_id: Optional[str] = None):
- async def load_and_tell_to_reconnect():
- new_session_id = await session_manager.load_session(
- self.session.session_id, session_id
- )
- await self._send_json(
- "reconnect_at_session", {"session_id": new_session_id}
- )
-
- create_async_task(load_and_tell_to_reconnect(), self.on_error)
-
- posthog_logger.capture_event("load_session", {"session_id": session_id})
-
- def set_current_session_title(self, title: str):
- self.session.autopilot.set_current_session_title(title)
-
- def set_system_message(self, message: str):
- self.session.autopilot.continue_sdk.config.system_message = message
- self.session.autopilot.continue_sdk.models.set_system_message(message)
-
- create_async_task(
- self.session.autopilot.set_config_attr(
- ["system_message"], create_string_node(message)
- ),
- self.on_error,
- )
- posthog_logger.capture_event("set_system_message", {"system_message": message})
-
- def set_temperature(self, temperature: float):
- self.session.autopilot.continue_sdk.config.temperature = temperature
- create_async_task(
- self.session.autopilot.set_config_attr(
- ["temperature"], create_float_node(temperature)
- ),
- self.on_error,
- )
- posthog_logger.capture_event("set_temperature", {"temperature": temperature})
-
- def set_model_for_role_from_index(self, role: str, index: int):
- async def async_stuff():
- models = self.session.autopilot.continue_sdk.config.models
-
- # Set models in SDK
- temp = models.default
- models.default = models.saved[index]
- models.saved[index] = temp
- await self.session.autopilot.continue_sdk.start_model(models.default)
-
- # Set models in config.py
- JOINER = ",\n\t\t"
- models_args = {
- "saved": f"[{JOINER.join([display_llm_class(llm) for llm in models.saved])}]",
- ("default" if role == "*" else role): display_llm_class(models.default),
- }
-
- await self.session.autopilot.set_config_attr(
- ["models"],
- create_obj_node("Models", models_args),
- )
-
- for other_role in ALL_MODEL_ROLES:
- if other_role != "default":
- models.__setattr__(other_role, models.default)
-
- await self.session.autopilot.continue_sdk.update_ui()
-
- create_async_task(async_stuff(), self.on_error)
-
- def add_model_for_role(self, role: str, model_class: str, model: Any):
- models = self.session.autopilot.continue_sdk.config.models
-
- model_copy = model.copy()
- if "api_key" in model_copy:
- del model_copy["api_key"]
- if "hf_token" in model_copy:
- del model_copy["hf_token"]
- posthog_logger.capture_event(
- "select_model_for_role",
- {"role": role, "model_class": model_class, "model": model_copy},
- )
-
- if role == "*":
-
- async def async_stuff():
- # Remove all previous models in roles and place in saved
- saved_models = models.saved
- existing_saved_models = set(
- [display_llm_class(llm) for llm in saved_models]
- )
- for role in ALL_MODEL_ROLES:
- val = models.__getattribute__(role)
- if (
- val is not None
- and display_llm_class(val) not in existing_saved_models
- ):
- saved_models.append(val)
- existing_saved_models.add(display_llm_class(val))
- models.__setattr__(role, None)
-
- # Add the requisite import to config.py
- add_config_import(
- f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}"
- )
- if "template_messages" in model:
- add_config_import(
- f"from continuedev.src.continuedev.libs.llm.prompts.chat import {model['template_messages']}"
- )
-
- # Set and start the new default model
-
- if "template_messages" in model:
- model["template_messages"] = {
- "llama2_template_messages": llama2_template_messages,
- "template_alpaca_messages": template_alpaca_messages,
- }[model["template_messages"]]
- new_model = MODEL_CLASSES[model_class](**model)
- models.default = new_model
- await self.session.autopilot.continue_sdk.start_model(models.default)
-
- # Construct and set the new models object
- JOINER = ",\n\t\t"
- saved_model_strings = set(
- [display_llm_class(llm) for llm in saved_models]
- )
- models_args = {
- "default": display_llm_class(models.default, True),
- "saved": f"[{JOINER.join(saved_model_strings)}]",
- }
-
- await self.session.autopilot.set_config_attr(
- ["models"],
- create_obj_node("Models", models_args),
- )
-
- # Set all roles (in-memory) to the new default model
- for role in ALL_MODEL_ROLES:
- if role != "default":
- models.__setattr__(role, models.default)
-
- # Display setup help
- await self.session.autopilot.continue_sdk.run_step(
- SetupModelStep(model_class=model_class)
- )
-
- create_async_task(async_stuff(), self.on_error)
- else:
- # TODO
- pass
-
- def save_context_group(self, title: str, context_items: List[ContextItem]):
- create_async_task(
- self.session.autopilot.save_context_group(title, context_items),
- self.on_error,
- )
-
- def select_context_group(self, id: str):
- create_async_task(
- self.session.autopilot.select_context_group(id), self.on_error
- )
-
- def delete_context_group(self, id: str):
- create_async_task(
- self.session.autopilot.delete_context_group(id), self.on_error
- )
-
-
-@router.websocket("/ws")
-async def websocket_endpoint(
- websocket: WebSocket, session: Session = Depends(websocket_session)
-):
- try:
- logger.debug(f"Received websocket connection at url: {websocket.url}")
- await websocket.accept()
-
- logger.debug("Session started")
- session_manager.register_websocket(session.session_id, websocket)
- protocol = GUIProtocolServer(session)
- protocol.websocket = websocket
-
- # Update any history that may have happened before connection
- await protocol.session.autopilot.update_subscribers()
-
- while AppStatus.should_exit is False:
- message = await websocket.receive_text()
- logger.debug(f"Received GUI message {message}")
- if isinstance(message, str):
- message = json.loads(message)
-
- if "messageType" not in message or "data" not in message:
- continue # :o
- message_type = message["messageType"]
- data = message["data"]
-
- protocol.handle_json(message_type, data)
- except WebSocketDisconnect:
- logger.debug("GUI websocket disconnected")
- except Exception as e:
- # Log, send to PostHog, and send to GUI
- logger.debug(f"ERROR in gui websocket: {e}")
- err_msg = "\n".join(traceback.format_exception(e))
- posthog_logger.capture_event(
- "gui_error",
- {"error_title": e.__str__() or e.__repr__(), "error_message": err_msg},
- )
-
- await session.autopilot.ide.showMessage(err_msg)
-
- raise e
- finally:
- logger.debug("Closing gui websocket")
- if websocket.client_state != WebSocketState.DISCONNECTED:
- await websocket.close()
-
- await session_manager.persist_session(session.session_id)
- await session_manager.remove_session(session.session_id)
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
deleted file mode 100644
index 6a4dc738..00000000
--- a/continuedev/src/continuedev/server/ide.py
+++ /dev/null
@@ -1,673 +0,0 @@
-# This is a separate server from server/main.py
-import asyncio
-import json
-import os
-import traceback
-import uuid
-from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, TypeVar, Union
-
-import nest_asyncio
-from fastapi import APIRouter, WebSocket
-from pydantic import BaseModel
-from starlette.websockets import WebSocketDisconnect, WebSocketState
-from uvicorn.main import Server
-
-from ..core.main import ContinueCustomException
-from ..libs.util.create_async_task import create_async_task
-from ..libs.util.devdata import dev_data_logger
-from ..libs.util.logging import logger
-from ..libs.util.queue import AsyncSubscriptionQueue
-from ..libs.util.telemetry import posthog_logger
-from ..models.filesystem import (
- EditDiff,
- FileSystem,
- RangeInFile,
- RangeInFileWithContents,
- RealFileSystem,
-)
-from ..models.filesystem_edit import (
- AddDirectory,
- AddFile,
- DeleteDirectory,
- DeleteFile,
- FileEdit,
- FileEditWithFullContents,
- FileSystemEdit,
- RenameDirectory,
- RenameFile,
- SequentialFileSystemEdit,
-)
-from ..plugins.steps.core.core import DisplayErrorStep
-from .gui import session_manager
-from .ide_protocol import AbstractIdeProtocolServer
-from .session_manager import SessionManager
-
-nest_asyncio.apply()
-
-
-router = APIRouter(prefix="/ide", tags=["ide"])
-
-
-# Graceful shutdown by closing websockets
-original_handler = Server.handle_exit
-
-
-class AppStatus:
- should_exit = False
-
- @staticmethod
- def handle_exit(*args, **kwargs):
- AppStatus.should_exit = True
- logger.debug("Shutting down")
- original_handler(*args, **kwargs)
-
-
-Server.handle_exit = AppStatus.handle_exit
-
-
-# TYPES #
-
-
-class FileEditsUpdate(BaseModel):
- fileEdits: List[FileEditWithFullContents]
-
-
-class OpenFilesResponse(BaseModel):
- openFiles: List[str]
-
-
-class VisibleFilesResponse(BaseModel):
- visibleFiles: List[str]
-
-
-class HighlightedCodeResponse(BaseModel):
- highlightedCode: List[RangeInFile]
-
-
-class ShowSuggestionRequest(BaseModel):
- suggestion: FileEdit
-
-
-class ShowSuggestionResponse(BaseModel):
- suggestion: FileEdit
- accepted: bool
-
-
-class ReadFileResponse(BaseModel):
- contents: str
-
-
-class EditFileResponse(BaseModel):
- fileEdit: FileEditWithFullContents
-
-
-class WorkspaceDirectoryResponse(BaseModel):
- workspaceDirectory: str
-
-
-class GetUserSecretResponse(BaseModel):
- value: str
-
-
-class RunCommandResponse(BaseModel):
- output: str = ""
-
-
-class UniqueIdResponse(BaseModel):
- uniqueId: str
-
-
-class TerminalContentsResponse(BaseModel):
- contents: str
-
-
-class ListDirectoryContentsResponse(BaseModel):
- contents: List[str]
-
-
-class FileExistsResponse(BaseModel):
- exists: bool
-
-
-T = TypeVar("T", bound=BaseModel)
-
-
-class cached_property_no_none:
- def __init__(self, func):
- self.func = func
-
- def __get__(self, instance, owner):
- if instance is None:
- return self
- value = self.func(instance)
- if value is not None:
- setattr(instance, self.func.__name__, value)
- return value
-
- def __repr__(self):
- return f"<cached_property_no_none '{self.func.__name__}'>"
-
-
-class IdeProtocolServer(AbstractIdeProtocolServer):
- websocket: WebSocket
- session_manager: SessionManager
- sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue()
- session_id: Union[str, None] = None
-
- ide_info: Optional[Dict] = None
-
- def __init__(self, session_manager: SessionManager, websocket: WebSocket):
- self.websocket = websocket
- self.session_manager = session_manager
-
- workspace_directory: str = None
- unique_id: str = None
-
- async def initialize(self, session_id: str) -> List[str]:
- self.session_id = session_id
- await self._send_json("workspaceDirectory", {})
- await self._send_json("uniqueId", {})
- await self._send_json("ide", {})
- other_msgs = []
- while True:
- msg_string = await self.websocket.receive_text()
- message = json.loads(msg_string)
- if "messageType" not in message or "data" not in message:
- continue # <-- hey that's the name of this repo!
- message_type = message["messageType"]
- data = message["data"]
- logger.debug(f"Received message while initializing {message_type}")
- if message_type == "workspaceDirectory":
- self.workspace_directory = data["workspaceDirectory"]
- elif message_type == "uniqueId":
- self.unique_id = data["uniqueId"]
- elif message_type == "ide":
- self.ide_info = data
- else:
- other_msgs.append(msg_string)
-
- if self.workspace_directory is not None and self.unique_id is not None:
- break
- return other_msgs
-
- async def _send_json(self, message_type: str, data: Any):
- # TODO: You breakpointed here, set it to disconnected, and then saw
- # that even after reloading, it couldn't connect the server.
- # Is this because there is an IDE registered without a websocket?
- # This shouldn't count as registered in that case.
- try:
- if self.websocket.application_state == WebSocketState.DISCONNECTED:
- logger.debug(
- f"Tried to send message, but websocket is disconnected: {message_type}"
- )
- return
- # logger.debug(f"Sending IDE message: {message_type}")
- await self.websocket.send_json({"messageType": message_type, "data": data})
- except RuntimeError as e:
- logger.warning(f"Error sending IDE message, websocket probably closed: {e}")
-
- async def _receive_json(
- self, message_type: str, timeout: int = 20, message=None
- ) -> Any:
- try:
- return await asyncio.wait_for(
- self.sub_queue.get(message_type), timeout=timeout
- )
- except asyncio.TimeoutError:
- raise ContinueCustomException(
- title=f"IDE Protocol _receive_json timed out after 20 seconds: {message_type}",
- message=f"IDE Protocol _receive_json timed out after 20 seconds. The message sent was: {message or ''}",
- )
-
- async def _send_and_receive_json(
- self, data: Any, resp_model: Type[T], message_type: str
- ) -> T:
- await self._send_json(message_type, data)
- resp = await self._receive_json(message_type, message=data)
- return resp_model.parse_obj(resp)
-
- async def handle_json(self, message_type: str, data: Any):
- if message_type == "getSessionId":
- await self.getSessionId()
- elif message_type == "setFileOpen":
- await self.setFileOpen(data["filepath"], data["open"])
- elif message_type == "setSuggestionsLocked":
- await self.setSuggestionsLocked(data["filepath"], data["locked"])
- elif message_type == "fileEdits":
- fileEdits = list(
- map(lambda d: FileEditWithFullContents.parse_obj(d), data["fileEdits"])
- )
- self.onFileEdits(fileEdits)
- elif message_type == "highlightedCodePush":
- self.onHighlightedCodeUpdate(
- [RangeInFileWithContents(**rif) for rif in data["highlightedCode"]],
- edit=data.get("edit", False),
- )
- elif message_type == "commandOutput":
- output = data["output"]
- self.onCommandOutput(output)
- elif message_type == "debugTerminal":
- content = data["contents"]
- self.onDebugTerminal(content)
- elif message_type == "acceptRejectSuggestion":
- self.onAcceptRejectSuggestion(data["accepted"])
- elif message_type == "acceptRejectDiff":
- self.onAcceptRejectDiff(data["accepted"], data["stepIndex"])
- elif message_type == "mainUserInput":
- self.onMainUserInput(data["input"])
- elif message_type == "deleteAtIndex":
- self.onDeleteAtIndex(data["index"])
- elif message_type in [
- "highlightedCode",
- "openFiles",
- "visibleFiles",
- "readFile",
- "editFile",
- "getUserSecret",
- "runCommand",
- "getTerminalContents",
- "listDirectoryContents",
- "fileExists",
- ]:
- self.sub_queue.post(message_type, data)
- elif message_type == "workspaceDirectory":
- self.workspace_directory = data["workspaceDirectory"]
- elif message_type == "uniqueId":
- self.unique_id = data["uniqueId"]
- elif message_type == "ide":
- self.ide_info = data
- elif message_type == "filesCreated":
- self.onFilesCreated(data["filepaths"])
- elif message_type == "filesDeleted":
- self.onFilesDeleted(data["filepaths"])
- elif message_type == "filesRenamed":
- self.onFilesRenamed(data["old_filepaths"], data["new_filepaths"])
- elif message_type == "fileSaved":
- self.onFileSaved(data["filepath"], data["contents"])
- else:
- raise ValueError("Unknown message type", message_type)
-
- async def showSuggestion(self, file_edit: FileEdit):
- await self._send_json("showSuggestion", {"edit": file_edit.dict()})
-
- async def showDiff(self, filepath: str, replacement: str, step_index: int):
- await self._send_json(
- "showDiff",
- {
- "filepath": filepath,
- "replacement": replacement,
- "step_index": step_index,
- },
- )
-
- async def setFileOpen(self, filepath: str, open: bool = True):
- # Autopilot needs access to this.
- await self._send_json("setFileOpen", {"filepath": filepath, "open": open})
-
- async def showMessage(self, message: str):
- await self._send_json("showMessage", {"message": message})
-
- async def showVirtualFile(self, name: str, contents: str):
- await self._send_json("showVirtualFile", {"name": name, "contents": contents})
-
- async def setSuggestionsLocked(self, filepath: str, locked: bool = True):
- # Lock suggestions in the file so they don't ruin the offset before others are inserted
- await self._send_json(
- "setSuggestionsLocked", {"filepath": filepath, "locked": locked}
- )
-
- async def getSessionId(self):
- new_session = await asyncio.wait_for(
- self.session_manager.new_session(self, self.session_id), timeout=5
- )
- session_id = new_session.session_id
- logger.debug(f"Sending session id: {session_id}")
- await self._send_json("getSessionId", {"sessionId": session_id})
-
- async def highlightCode(self, range_in_file: RangeInFile, color: str = "#00ff0022"):
- await self._send_json(
- "highlightCode", {"rangeInFile": range_in_file.dict(), "color": color}
- )
-
- async def runCommand(self, command: str) -> str:
- return (
- await self._send_and_receive_json(
- {"command": command}, RunCommandResponse, "runCommand"
- )
- ).output
-
- async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool:
- ids = [str(uuid.uuid4()) for _ in suggestions]
- for i in range(len(suggestions)):
- self._send_json(
- "showSuggestion", {"suggestion": suggestions[i], "suggestionId": ids[i]}
- )
- responses = await asyncio.gather(
- *[
- self._receive_json(ShowSuggestionResponse)
- for i in range(len(suggestions))
- ]
- ) # WORKING ON THIS FLOW HERE. Fine now to just await for response, instead of doing something fancy with a "waiting" state on the autopilot.
- # Just need connect the suggestionId to the IDE (and the gui)
- return any([r.accepted for r in responses])
-
- def on_error(self, e: Exception) -> Coroutine:
- err_msg = "\n".join(traceback.format_exception(e))
- e_title = e.__str__() or e.__repr__()
- return self.showMessage(f"Error in Continue server: {e_title}\n {err_msg}")
-
- def onAcceptRejectSuggestion(self, accepted: bool):
- posthog_logger.capture_event("accept_reject_suggestion", {"accepted": accepted})
- dev_data_logger.capture("accept_reject_suggestion", {"accepted": accepted})
-
- def onAcceptRejectDiff(self, accepted: bool, step_index: int):
- posthog_logger.capture_event("accept_reject_diff", {"accepted": accepted})
- dev_data_logger.capture("accept_reject_diff", {"accepted": accepted})
-
- if not accepted:
- if autopilot := self.__get_autopilot():
- create_async_task(
- autopilot.reject_diff(step_index),
- self.on_error,
- )
-
- def onFileSystemUpdate(self, update: FileSystemEdit):
- # Access to Autopilot (so SessionManager)
- pass
-
- def onCloseGUI(self, session_id: str):
- # Access to SessionManager
- pass
-
- def onOpenGUIRequest(self):
- pass
-
- def __get_autopilot(self):
- if self.session_id not in self.session_manager.sessions:
- return None
-
- autopilot = self.session_manager.sessions[self.session_id].autopilot
- return autopilot if autopilot.started else None
-
- def onFileEdits(self, edits: List[FileEditWithFullContents]):
- if autopilot := self.__get_autopilot():
- pass
-
- def onDeleteAtIndex(self, index: int):
- if autopilot := self.__get_autopilot():
- create_async_task(autopilot.delete_at_index(index), self.on_error)
-
- def onCommandOutput(self, output: str):
- if autopilot := self.__get_autopilot():
- create_async_task(autopilot.handle_command_output(output), self.on_error)
-
- def onDebugTerminal(self, content: str):
- if autopilot := self.__get_autopilot():
- create_async_task(autopilot.handle_debug_terminal(content), self.on_error)
-
- def onHighlightedCodeUpdate(
- self,
- range_in_files: List[RangeInFileWithContents],
- edit: Optional[bool] = False,
- ):
- if autopilot := self.__get_autopilot():
- create_async_task(
- autopilot.handle_highlighted_code(range_in_files, edit), self.on_error
- )
-
- ## Subscriptions ##
-
- _files_created_callbacks = []
- _files_deleted_callbacks = []
- _files_renamed_callbacks = []
- _file_saved_callbacks = []
-
- def call_callback(self, callback, *args, **kwargs):
- if asyncio.iscoroutinefunction(callback):
- create_async_task(callback(*args, **kwargs), self.on_error)
- else:
- callback(*args, **kwargs)
-
- def subscribeToFilesCreated(self, callback: Callable[[List[str]], None]):
- self._files_created_callbacks.append(callback)
-
- def subscribeToFilesDeleted(self, callback: Callable[[List[str]], None]):
- self._files_deleted_callbacks.append(callback)
-
- def subscribeToFilesRenamed(self, callback: Callable[[List[str], List[str]], None]):
- self._files_renamed_callbacks.append(callback)
-
- def subscribeToFileSaved(self, callback: Callable[[str, str], None]):
- self._file_saved_callbacks.append(callback)
-
- def onFilesCreated(self, filepaths: List[str]):
- for callback in self._files_created_callbacks:
- self.call_callback(callback, filepaths)
-
- def onFilesDeleted(self, filepaths: List[str]):
- for callback in self._files_deleted_callbacks:
- self.call_callback(callback, filepaths)
-
- def onFilesRenamed(self, old_filepaths: List[str], new_filepaths: List[str]):
- for callback in self._files_renamed_callbacks:
- self.call_callback(callback, old_filepaths, new_filepaths)
-
- def onFileSaved(self, filepath: str, contents: str):
- for callback in self._file_saved_callbacks:
- self.call_callback(callback, filepath, contents)
-
- ## END Subscriptions ##
-
- def onMainUserInput(self, input: str):
- if autopilot := self.__get_autopilot():
- create_async_task(autopilot.accept_user_input(input), self.on_error)
-
- # Request information. Session doesn't matter.
- async def getOpenFiles(self) -> List[str]:
- resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles")
- return resp.openFiles
-
- async def getVisibleFiles(self) -> List[str]:
- resp = await self._send_and_receive_json(
- {}, VisibleFilesResponse, "visibleFiles"
- )
- return resp.visibleFiles
-
- async def getTerminalContents(self, commands: int = -1) -> str:
- """Get the contents of the terminal, up to the last 'commands' commands, or all if commands is -1"""
- resp = await self._send_and_receive_json(
- {"commands": commands}, TerminalContentsResponse, "getTerminalContents"
- )
- return resp.contents.strip()
-
- async def getHighlightedCode(self) -> List[RangeInFile]:
- resp = await self._send_and_receive_json(
- {}, HighlightedCodeResponse, "highlightedCode"
- )
- return resp.highlightedCode
-
- async def readFile(self, filepath: str) -> str:
- """Read a file"""
- resp = await self._send_and_receive_json(
- {"filepath": filepath}, ReadFileResponse, "readFile"
- )
- return resp.contents
-
- async def fileExists(self, filepath: str) -> str:
- """Check whether file exists"""
- resp = await self._send_and_receive_json(
- {"filepath": filepath}, FileExistsResponse, "fileExists"
- )
- return resp.exists
-
- async def getUserSecret(self, key: str) -> str:
- """Get a user secret"""
- try:
- resp = await self._send_and_receive_json(
- {"key": key}, GetUserSecretResponse, "getUserSecret"
- )
- return resp.value
- except Exception as e:
- logger.debug(f"Error getting user secret: {e}")
- return ""
-
- async def saveFile(self, filepath: str):
- """Save a file"""
- await self._send_json("saveFile", {"filepath": filepath})
-
- async def readRangeInFile(self, range_in_file: RangeInFile) -> str:
- """Read a range in a file"""
- full_contents = await self.readFile(range_in_file.filepath)
- return FileSystem.read_range_in_str(full_contents, range_in_file.range)
-
- async def editFile(self, edit: FileEdit) -> FileEditWithFullContents:
- """Edit a file"""
- resp = await self._send_and_receive_json(
- {"edit": edit.dict()}, EditFileResponse, "editFile"
- )
- return resp.fileEdit
-
- async def listDirectoryContents(
- self, directory: str, recursive: bool = False
- ) -> List[str]:
- """List the contents of a directory"""
- resp = await self._send_and_receive_json(
- {"directory": directory, "recursive": recursive},
- ListDirectoryContentsResponse,
- "listDirectoryContents",
- )
- return resp.contents
-
- async def applyFileSystemEdit(self, edit: FileSystemEdit) -> EditDiff:
- """Apply a file edit"""
- backward = None
- fs = RealFileSystem()
- if isinstance(edit, FileEdit):
- file_edit = await self.editFile(edit)
- _, diff = FileSystem.apply_edit_to_str(
- file_edit.fileContents, file_edit.fileEdit
- )
- backward = diff.backward
- elif isinstance(edit, AddFile):
- fs.write(edit.filepath, edit.content)
- backward = DeleteFile(filepath=edit.filepath)
- elif isinstance(edit, DeleteFile):
- contents = await self.readFile(edit.filepath)
- backward = AddFile(filepath=edit.filepath, content=contents)
- fs.delete_file(edit.filepath)
- elif isinstance(edit, RenameFile):
- fs.rename_file(edit.filepath, edit.new_filepath)
- backward = RenameFile(
- filepath=edit.new_filepath, new_filepath=edit.filepath
- )
- elif isinstance(edit, AddDirectory):
- fs.add_directory(edit.path)
- backward = DeleteDirectory(path=edit.path)
- elif isinstance(edit, DeleteDirectory):
- # This isn't atomic!
- backward_edits = []
- for root, dirs, files in os.walk(edit.path, topdown=False):
- for f in files:
- path = os.path.join(root, f)
- edit_diff = await self.applyFileSystemEdit(
- DeleteFile(filepath=path)
- )
- backward_edits.append(edit_diff)
- for d in dirs:
- path = os.path.join(root, d)
- edit_diff = await self.applyFileSystemEdit(
- DeleteDirectory(path=path)
- )
- backward_edits.append(edit_diff)
-
- edit_diff = await self.applyFileSystemEdit(DeleteDirectory(path=edit.path))
- backward_edits.append(edit_diff)
- backward_edits.reverse()
- backward = SequentialFileSystemEdit(edits=backward_edits)
- elif isinstance(edit, RenameDirectory):
- fs.rename_directory(edit.path, edit.new_path)
- backward = RenameDirectory(path=edit.new_path, new_path=edit.path)
- elif isinstance(edit, FileSystemEdit):
- diffs = []
- for edit in edit.next_edit():
- edit_diff = await self.applyFileSystemEdit(edit)
- diffs.append(edit_diff)
- backward = EditDiff.from_sequence(diffs=diffs).backward
- else:
- raise TypeError("Unknown FileSystemEdit type: " + str(type(edit)))
-
- return EditDiff(forward=edit, backward=backward)
-
-
-@router.websocket("/ws")
-async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
- try:
- # Accept the websocket connection
- await websocket.accept()
- logger.debug(f"Accepted websocket connection from {websocket.client}")
- await websocket.send_json({"messageType": "connected", "data": {}})
-
- # Message handler
- def handle_msg(msg):
- try:
- message = json.loads(msg)
- except json.JSONDecodeError:
- logger.critical(f"Error decoding json: {msg}")
- return
-
- if "messageType" not in message or "data" not in message:
- return
- message_type = message["messageType"]
- data = message["data"]
-
- # logger.debug(f"Received IDE message: {message_type}")
- create_async_task(
- ideProtocolServer.handle_json(message_type, data),
- ideProtocolServer.on_error,
- )
-
- # Initialize the IDE Protocol Server
- ideProtocolServer = IdeProtocolServer(session_manager, websocket)
- if session_id is not None:
- session_manager.registered_ides[session_id] = ideProtocolServer
- other_msgs = await ideProtocolServer.initialize(session_id)
- posthog_logger.capture_event(
- "session_started", {"session_id": ideProtocolServer.session_id}
- )
-
- for other_msg in other_msgs:
- handle_msg(other_msg)
-
- # Handle messages
- while AppStatus.should_exit is False:
- message = await websocket.receive_text()
- handle_msg(message)
-
- except WebSocketDisconnect:
- logger.debug("IDE websocket disconnected")
- except Exception as e:
- logger.debug(f"Error in ide websocket: {e}")
- err_msg = "\n".join(traceback.format_exception(e))
- posthog_logger.capture_event(
- "gui_error",
- {"error_title": e.__str__() or e.__repr__(), "error_message": err_msg},
- )
-
- if session_id is not None and session_id in session_manager.sessions:
- await session_manager.sessions[session_id].autopilot.continue_sdk.run_step(
- DisplayErrorStep.from_exception(e)
- )
- elif ideProtocolServer is not None:
- await ideProtocolServer.showMessage(f"Error in Continue server: {err_msg}")
-
- raise e
- finally:
- logger.debug("Closing ide websocket")
- if websocket.client_state != WebSocketState.DISCONNECTED:
- await websocket.close()
-
- posthog_logger.capture_event(
- "session_ended", {"session_id": ideProtocolServer.session_id}
- )
- if ideProtocolServer.session_id in session_manager.registered_ides:
- session_manager.registered_ides.pop(ideProtocolServer.session_id)
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
deleted file mode 100644
index 832dd338..00000000
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ /dev/null
@@ -1,170 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, Union
-
-from fastapi import WebSocket
-
-from ..models.filesystem import RangeInFile, RangeInFileWithContents
-from ..models.filesystem_edit import EditDiff, FileEdit, FileSystemEdit
-
-
-class AbstractIdeProtocolServer(ABC):
- websocket: WebSocket
- session_id: Union[str, None]
- ide_info: Optional[Dict] = None
-
- @abstractmethod
- async def handle_json(self, data: Any):
- """Handle a json message"""
-
- @abstractmethod
- def showSuggestion(self, file_edit: FileEdit):
- """Show a suggestion to the user"""
-
- @abstractmethod
- async def setFileOpen(self, filepath: str, open: bool = True):
- """Set whether a file is open"""
-
- @abstractmethod
- async def showMessage(self, message: str):
- """Show a message to the user"""
-
- @abstractmethod
- async def showVirtualFile(self, name: str, contents: str):
- """Show a virtual file"""
-
- @abstractmethod
- async def setSuggestionsLocked(self, filepath: str, locked: bool = True):
- """Set whether suggestions are locked"""
-
- @abstractmethod
- async def getSessionId(self):
- """Get a new session ID"""
-
- @abstractmethod
- async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool:
- """Show suggestions to the user and wait for a response"""
-
- @abstractmethod
- def onAcceptRejectSuggestion(self, accepted: bool):
- """Called when the user accepts or rejects a suggestion"""
-
- @abstractmethod
- def onFileSystemUpdate(self, update: FileSystemEdit):
- """Called when a file system update is received"""
-
- @abstractmethod
- def onCloseGUI(self, session_id: str):
- """Called when a GUI is closed"""
-
- @abstractmethod
- def onOpenGUIRequest(self):
- """Called when a GUI is requested to be opened"""
-
- @abstractmethod
- async def getOpenFiles(self) -> List[str]:
- """Get a list of open files"""
-
- @abstractmethod
- async def getVisibleFiles(self) -> List[str]:
- """Get a list of visible files"""
-
- @abstractmethod
- async def getHighlightedCode(self) -> List[RangeInFile]:
- """Get a list of highlighted code"""
-
- @abstractmethod
- async def readFile(self, filepath: str) -> str:
- """Read a file"""
-
- @abstractmethod
- async def readRangeInFile(self, range_in_file: RangeInFile) -> str:
- """Read a range in a file"""
-
- @abstractmethod
- async def editFile(self, edit: FileEdit):
- """Edit a file"""
-
- @abstractmethod
- async def applyFileSystemEdit(self, edit: FileSystemEdit) -> EditDiff:
- """Apply a file edit"""
-
- @abstractmethod
- async def saveFile(self, filepath: str):
- """Save a file"""
-
- @abstractmethod
- async def getUserSecret(self, key: str):
- """Get a user secret"""
-
- @abstractmethod
- async def highlightCode(self, range_in_file: RangeInFile, color: str):
- """Highlight code"""
-
- @abstractmethod
- async def runCommand(self, command: str) -> str:
- """Run a command"""
-
- @abstractmethod
- def onHighlightedCodeUpdate(
- self,
- range_in_files: List[RangeInFileWithContents],
- edit: Optional[bool] = False,
- ):
- """Called when highlighted code is updated"""
-
- @abstractmethod
- def onDeleteAtIndex(self, index: int):
- """Called when a step is deleted at a given index"""
-
- @abstractmethod
- async def showDiff(self, filepath: str, replacement: str, step_index: int):
- """Show a diff"""
-
- @abstractmethod
- def subscribeToFilesCreated(self, callback: Callable[[List[str]], None]):
- """Subscribe to files created event"""
-
- @abstractmethod
- def subscribeToFilesDeleted(self, callback: Callable[[List[str]], None]):
- """Subscribe to files deleted event"""
-
- @abstractmethod
- def subscribeToFilesRenamed(self, callback: Callable[[List[str], List[str]], None]):
- """Subscribe to files renamed event"""
-
- @abstractmethod
- def subscribeToFileSaved(self, callback: Callable[[str, str], None]):
- """Subscribe to file saved event"""
-
- @abstractmethod
- def onFilesCreated(self, filepaths: List[str]):
- """Called when files are created"""
-
- @abstractmethod
- def onFilesDeleted(self, filepaths: List[str]):
- """Called when files are deleted"""
-
- @abstractmethod
- def onFilesRenamed(self, old_filepaths: List[str], new_filepaths: List[str]):
- """Called when files are renamed"""
-
- @abstractmethod
- def onFileSaved(self, filepath: str, contents: str):
- """Called when a file is saved"""
-
- @abstractmethod
- async def listDirectoryContents(
- self, directory: str, recursive: bool = False
- ) -> List[str]:
- """List directory contents"""
-
- @abstractmethod
- async def fileExists(self, filepath: str) -> str:
- """Check if a file exists"""
-
- @abstractmethod
- async def getTerminalContents(self, commands: int = -1) -> str:
- """Get the terminal contents"""
-
- workspace_directory: str
- unique_id: str
diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py
deleted file mode 100644
index c5540d7d..00000000
--- a/continuedev/src/continuedev/server/main.py
+++ /dev/null
@@ -1,109 +0,0 @@
-import argparse
-import asyncio
-import atexit
-from contextlib import asynccontextmanager
-from typing import Optional
-
-import uvicorn
-from fastapi import FastAPI
-from fastapi.middleware.cors import CORSMiddleware
-
-from ..libs.util.create_async_task import create_async_task
-from ..libs.util.logging import logger
-from .gui import router as gui_router
-from .ide import router as ide_router
-from .meilisearch_server import start_meilisearch, stop_meilisearch
-from .session_manager import router as sessions_router
-from .session_manager import session_manager
-
-meilisearch_url_global = None
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- async def on_err(e):
- logger.warning(f"Error starting MeiliSearch: {e}")
-
- try:
- # start meilisearch without blocking server startup
- create_async_task(start_meilisearch(url=meilisearch_url_global), on_err)
- except Exception as e:
- logger.warning(f"Error starting MeiliSearch: {e}")
-
- yield
- stop_meilisearch()
-
-
-app = FastAPI(lifespan=lifespan)
-
-app.include_router(ide_router)
-app.include_router(gui_router)
-app.include_router(sessions_router)
-
-# Add CORS support
-app.add_middleware(
- CORSMiddleware,
- allow_origins=["*"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
-)
-
-
-@app.get("/health")
-def health():
- logger.debug("Health check")
- return {"status": "ok"}
-
-
-def run_server(
- port: int = 65432, host: str = "127.0.0.1", meilisearch_url: Optional[str] = None
-):
- try:
- global meilisearch_url_global
-
- meilisearch_url_global = meilisearch_url
-
- config = uvicorn.Config(app, host=host, port=port)
- server = uvicorn.Server(config)
- server.run()
- except PermissionError as e:
- logger.critical(
- f"Error starting Continue server: {e}. "
- f"This means that port {port} is already in use, and is usually caused by another instance of the Continue server already running."
- )
- cleanup()
- raise e
-
- except Exception as e:
- logger.critical(f"Error starting Continue server: {e}")
- cleanup()
- raise e
-
-
-async def cleanup_coroutine():
- logger.debug("------ Cleaning Up ------")
- for session_id in session_manager.sessions:
- await session_manager.persist_session(session_id)
-
-
-def cleanup():
- loop = asyncio.new_event_loop()
- loop.run_until_complete(cleanup_coroutine())
- loop.close()
-
-
-atexit.register(cleanup)
-
-if __name__ == "__main__":
- try:
- # add cli arg for server port
- parser = argparse.ArgumentParser()
- parser.add_argument("-p", "--port", help="server port", type=int, default=65432)
- parser.add_argument("--host", help="server host", type=str, default="127.0.0.1")
- args = parser.parse_args()
- except Exception as e:
- logger.critical(f"Error parsing command line arguments: {e}")
- raise e
-
- run_server(args.port, args.host)
diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py
deleted file mode 100644
index 6ce4d61c..00000000
--- a/continuedev/src/continuedev/server/meilisearch_server.py
+++ /dev/null
@@ -1,195 +0,0 @@
-import asyncio
-import os
-import shutil
-import subprocess
-from typing import Optional
-
-import aiofiles
-import aiohttp
-import psutil
-from meilisearch_python_async import Client
-
-from ..libs.util.logging import logger
-from ..libs.util.paths import getMeilisearchExePath, getServerFolderPath
-
-
-async def download_file(url: str, filename: str):
- async with aiohttp.ClientSession() as session:
- async with session.get(url) as resp:
- if resp.status == 200:
- f = await aiofiles.open(filename, mode="wb")
- await f.write(await resp.read())
- await f.close()
-
-
-async def download_meilisearch():
- """
- Downloads MeiliSearch.
- """
-
- serverPath = getServerFolderPath()
- logger.debug("Downloading MeiliSearch...")
-
- if os.name == "nt":
- download_url = "https://github.com/meilisearch/meilisearch/releases/download/v1.3.2/meilisearch-windows-amd64.exe"
- download_path = getMeilisearchExePath()
- if not os.path.exists(download_path):
- await download_file(download_url, download_path)
- # subprocess.run(
- # f"curl -L {download_url} -o {download_path}",
- # shell=True,
- # check=True,
- # cwd=serverPath,
- # )
- else:
- subprocess.run(
- "curl -L https://install.meilisearch.com | sh",
- shell=True,
- check=True,
- cwd=serverPath,
- )
-
-
-async def ensure_meilisearch_installed() -> bool:
- """
- Checks if MeiliSearch is installed.
-
- Returns a bool indicating whether it was installed to begin with.
- """
- serverPath = getServerFolderPath()
- meilisearchPath = getMeilisearchExePath()
- dumpsPath = os.path.join(serverPath, "dumps")
- dataMsPath = os.path.join(serverPath, "data.ms")
-
- paths = [meilisearchPath, dumpsPath, dataMsPath]
-
- existing_paths = set()
- non_existing_paths = set()
- for path in paths:
- if os.path.exists(path):
- existing_paths.add(path)
- else:
- non_existing_paths.add(path)
-
- if len(non_existing_paths) > 0:
- # Clear the meilisearch binary
- if meilisearchPath in existing_paths:
- try:
- os.remove(meilisearchPath)
- except:
- pass
- existing_paths.remove(meilisearchPath)
-
- await download_meilisearch()
-
- # Clear the existing directories
- for p in existing_paths:
- shutil.rmtree(p, ignore_errors=True)
-
- return False
-
- return True
-
-
-meilisearch_process = None
-DEFAULT_MEILISEARCH_URL = "http://localhost:7700"
-meilisearch_url = DEFAULT_MEILISEARCH_URL
-
-
-def get_meilisearch_url():
- return meilisearch_url
-
-
-async def check_meilisearch_running() -> bool:
- """
- Checks if MeiliSearch is running.
- """
-
- try:
- async with Client(meilisearch_url) as client:
- try:
- resp = await client.health()
- if resp.status != "available":
- return False
- return True
- except Exception:
- return False
- except Exception:
- return False
-
-
-async def poll_meilisearch_running(frequency: int = 0.1) -> bool:
- """
- Polls MeiliSearch to see if it is running.
- """
- while True:
- if await check_meilisearch_running():
- return True
- await asyncio.sleep(frequency)
-
-
-async def start_meilisearch(url: Optional[str] = None):
- """
- Starts the MeiliSearch server, wait for it.
- """
- global meilisearch_process, meilisearch_url
-
- if url is not None:
- logger.debug("Using MeiliSearch at URL: " + url)
- meilisearch_url = url
- return
-
- serverPath = getServerFolderPath()
-
- # Check if MeiliSearch is installed, if not download
- was_already_installed = await ensure_meilisearch_installed()
-
- # Check if MeiliSearch is running
- if not await check_meilisearch_running() or not was_already_installed:
- logger.debug("Starting MeiliSearch...")
- binary_name = "meilisearch" if os.name == "nt" else "./meilisearch"
- meilisearch_process = subprocess.Popen(
- [binary_name, "--no-analytics"],
- cwd=serverPath,
- stdout=subprocess.DEVNULL,
- stderr=subprocess.STDOUT,
- close_fds=True,
- start_new_session=True,
- shell=True,
- )
-
- logger.debug("Meilisearch started")
-
-
-def stop_meilisearch():
- """
- Stops the MeiliSearch server.
- """
- global meilisearch_process
- if meilisearch_process is not None:
- meilisearch_process.terminate()
- meilisearch_process.wait()
- meilisearch_process = None
-
-
-def kill_proc(port):
- for proc in psutil.process_iter():
- try:
- for conns in proc.connections(kind="inet"):
- if conns.laddr.port == port:
- proc.send_signal(psutil.signal.SIGTERM) # or SIGKILL
- except psutil.AccessDenied:
- logger.warning(f"Failed to kill process on port {port} (access denied)")
- return
- except psutil.ZombieProcess:
- logger.warning(f"Failed to kill process on port {port} (zombie process)")
- return
- except psutil.NoSuchProcess:
- logger.warning(f"Failed to kill process on port {port} (no such process)")
- return
-
-
-async def restart_meilisearch():
- stop_meilisearch()
- kill_proc(7700)
- await start_meilisearch(url=meilisearch_url)
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
deleted file mode 100644
index f0080104..00000000
--- a/continuedev/src/continuedev/server/session_manager.py
+++ /dev/null
@@ -1,192 +0,0 @@
-import json
-import os
-import traceback
-from typing import Any, Coroutine, Dict, Optional, Union
-from uuid import uuid4
-
-from fastapi import APIRouter, WebSocket
-from fastapi.websockets import WebSocketState
-
-from ..core.autopilot import Autopilot
-from ..core.config import ContinueConfig
-from ..core.main import FullState
-from ..libs.util.create_async_task import create_async_task
-from ..libs.util.logging import logger
-from ..libs.util.paths import (
- getSessionFilePath,
- getSessionsFolderPath,
- getSessionsListFilePath,
-)
-from .ide_protocol import AbstractIdeProtocolServer
-
-router = APIRouter(prefix="/sessions", tags=["sessions"])
-
-
-class Session:
- session_id: str
- autopilot: Autopilot
- # The GUI websocket for the session
- ws: Union[WebSocket, None]
-
- def __init__(self, session_id: str, autopilot: Autopilot):
- self.session_id = session_id
- self.autopilot = autopilot
- self.ws = None
-
-
-class SessionManager:
- sessions: Dict[str, Session] = {}
- # Mapping of session_id to IDE, where the IDE is still alive
- registered_ides: Dict[str, AbstractIdeProtocolServer] = {}
-
- async def get_session(self, session_id: str) -> Session:
- if session_id not in self.sessions:
- # Check then whether it is persisted by listing all files in the sessions folder
- # And only if the IDE is still alive
- sessions_folder = getSessionsFolderPath()
- session_files = os.listdir(sessions_folder)
- if (
- f"{session_id}.json" in session_files
- and session_id in self.registered_ides
- ):
- if self.registered_ides[session_id].session_id is not None:
- return await self.new_session(
- self.registered_ides[session_id], session_id=session_id
- )
-
- raise KeyError("Session ID not recognized", session_id)
- return self.sessions[session_id]
-
- async def new_session(
- self,
- ide: AbstractIdeProtocolServer,
- session_id: Optional[str] = None,
- config: Optional[ContinueConfig] = None,
- ) -> Session:
- logger.debug(f"New session: {session_id}")
-
- # Load the persisted state (not being used right now)
- full_state = None
- if session_id is not None and os.path.exists(getSessionFilePath(session_id)):
- with open(getSessionFilePath(session_id), "r") as f:
- full_state = FullState(**json.load(f))
-
- # Register the session and ide (do this first so that the autopilot can access the session)
- autopilot = Autopilot(ide=ide)
- session_id = session_id or str(uuid4())
- ide.session_id = session_id
- session = Session(session_id=session_id, autopilot=autopilot)
- self.sessions[session_id] = session
- self.registered_ides[session_id] = ide
-
- # Set up the autopilot to update the GUI
- async def on_update(state: FullState):
- await session_manager.send_ws_data(
- session_id, "state_update", {"state": state.dict()}
- )
-
- autopilot.on_update(on_update)
-
- # Start the autopilot (must be after session is added to sessions) and the policy
- try:
- await autopilot.start(full_state=full_state, config=config)
- except Exception as e:
- await ide.on_error(e)
-
- def on_error(e: Exception) -> Coroutine:
- err_msg = "\n".join(traceback.format_exception(e))
- return ide.showMessage(f"Error in Continue server: {err_msg}")
-
- create_async_task(autopilot.run_policy(), on_error)
- return session
-
- async def remove_session(self, session_id: str):
- logger.debug(f"Removing session: {session_id}")
- if session_id in self.sessions:
- if (
- session_id in self.registered_ides
- and self.registered_ides[session_id] is not None
- ):
- ws_to_close = self.registered_ides[session_id].websocket
- if (
- ws_to_close is not None
- and ws_to_close.client_state != WebSocketState.DISCONNECTED
- ):
- await self.sessions[session_id].autopilot.ide.websocket.close()
-
- del self.sessions[session_id]
-
- async def persist_session(self, session_id: str):
- """Save the session's FullState as a json file"""
- full_state = await self.sessions[session_id].autopilot.get_full_state()
- if full_state.session_info is None:
- return
-
- with open(getSessionFilePath(session_id), "w") as f:
- json.dump(full_state.dict(), f)
-
- # Read and update the sessions list
- with open(getSessionsListFilePath(), "r") as f:
- try:
- sessions_list = json.load(f)
- except json.JSONDecodeError:
- raise Exception(
- f"It looks like there is a JSON formatting error in your sessions.json file ({getSessionsListFilePath()}). Please fix this before creating a new session."
- )
-
- session_ids = [s["session_id"] for s in sessions_list]
- if session_id not in session_ids:
- sessions_list.append(full_state.session_info.dict())
-
- for session_info in sessions_list:
- if "workspace_directory" not in session_info:
- session_info["workspace_directory"] = ""
-
- with open(getSessionsListFilePath(), "w") as f:
- json.dump(sessions_list, f)
-
- async def load_session(
- self, old_session_id: str, new_session_id: Optional[str] = None
- ) -> str:
- """Load the session's FullState from a json file"""
-
- # First persist the current state
- await self.persist_session(old_session_id)
-
- # Delete the old session, but keep the IDE
- ide = self.registered_ides[old_session_id]
- del self.registered_ides[old_session_id]
-
- # Start the new session
- new_session = await self.new_session(ide, session_id=new_session_id)
- return new_session.session_id
-
- def register_websocket(self, session_id: str, ws: WebSocket):
- self.sessions[session_id].ws = ws
- logger.debug(f"Registered websocket for session {session_id}")
-
- async def send_ws_data(self, session_id: str, message_type: str, data: Any):
- if session_id not in self.sessions:
- logger.warning(f"Session {session_id} not found")
- return
- if self.sessions[session_id].ws is None:
- return
-
- await self.sessions[session_id].ws.send_json(
- {"messageType": message_type, "data": data}
- )
-
-
-session_manager = SessionManager()
-
-
-@router.get("/list")
-async def list_sessions():
- """List all sessions"""
- sessions_list_file = getSessionsListFilePath()
- if not os.path.exists(sessions_list_file):
- print("Returning empty sessions list")
- return []
- sessions = json.load(open(sessions_list_file, "r"))
- print("Returning sessions list: ", sessions)
- return sessions