diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-09-01 18:31:33 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-09-01 18:31:33 -0700 |
commit | 2f792f46026a6bb3c3580f2521b01ecb8c68117c (patch) | |
tree | 36cecf6d218bc166c0a8c1c78261b4feac7f01cd /continuedev/src/continuedev/server | |
parent | 5c8b28b7fddf5b214de61102c768ef44d4087870 (diff) | |
download | sncontinue-2f792f46026a6bb3c3580f2521b01ecb8c68117c.tar.gz sncontinue-2f792f46026a6bb3c3580f2521b01ecb8c68117c.tar.bz2 sncontinue-2f792f46026a6bb3c3580f2521b01ecb8c68117c.zip |
feat: :sparkles: improved model dropdown
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 80 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/gui_protocol.py | 40 |
2 files changed, 40 insertions, 80 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 2c4f2e4d..49541b76 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -23,7 +23,6 @@ from ..libs.util.queue import AsyncSubscriptionQueue from ..libs.util.telemetry import posthog_logger from ..plugins.steps.core.core import DisplayErrorStep from ..plugins.steps.setup_model import SetupModelStep -from .gui_protocol import AbstractGUIProtocolServer from .session_manager import Session, session_manager router = APIRouter(prefix="/gui", tags=["gui"]) @@ -54,7 +53,7 @@ T = TypeVar("T", bound=BaseModel) # You should probably abstract away the websocket stuff into a separate class -class GUIProtocolServer(AbstractGUIProtocolServer): +class GUIProtocolServer: websocket: WebSocket session: Session sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue() @@ -118,8 +117,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.set_system_message(data["message"]) elif message_type == "set_temperature": self.set_temperature(float(data["temperature"])) - elif message_type == "set_model_for_role": - self.set_model_for_role(data["role"], data["model_class"], data["model"]) + elif message_type == "add_model_for_role": + self.add_model_for_role(data["role"], data["model_class"], data["model"]) + elif message_type == "set_model_for_role_from_index": + self.set_model_for_role_from_index(data["role"], data["index"]) elif message_type == "save_context_group": self.save_context_group( data["title"], [ContextItem(**item) for item in data["context_items"]] @@ -230,51 +231,50 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.on_error, ) - def set_model_for_role(self, role: str, model_class: str, model: Any): + def set_model_for_role_from_index(self, role: str, index: int): + async def async_stuff(): + models = self.session.autopilot.continue_sdk.config.models + + # Set models in SDK + temp = models.default + models.default = models.unused[index] + models.unused[index] = temp + await self.session.autopilot.continue_sdk.start_model(models.default) + + # Set models in config.py + JOINER = ", " + models_args = { + "unused": f"[{JOINER.join([display_llm_class(llm) for llm in models.unused])}]", + ("default" if role == "*" else role): display_llm_class(models.default), + } + + await self.session.autopilot.set_config_attr( + ["models"], + create_obj_node("Models", models_args), + ) + + for other_role in ALL_MODEL_ROLES: + if other_role != "default": + models.__setattr__(other_role, models.default) + + await self.session.autopilot.continue_sdk.update_ui() + + create_async_task(async_stuff(), self.on_error) + + def add_model_for_role(self, role: str, model_class: str, model: Any): models = self.session.autopilot.continue_sdk.config.models unused_models = models.unused if role == "*": async def async_stuff(): - # Clear all of the models and store them in unused_models - # NOTE: There will be duplicates for role in ALL_MODEL_ROLES: - prev_model = models.__getattribute__(role) models.__setattr__(role, None) - if prev_model is not None: - exists = False - for other in unused_models: - if ( - prev_model.__class__.__name__ - == other.__class__.__name__ - and ( - other.name is not None - and ( - not other.name.startswith("gpt") - or prev_model.name == other.name - ) - ) - ): - exists = True - break - if not exists: - unused_models.append(prev_model) - - # Replace default with either new one or existing from unused_models - for unused_model in unused_models: - if model_class == unused_model.__class__.__name__ and ( - "model" not in model - or model["model"] == unused_model.model - and model["model"].startswith("gpt") - ): - models.default = unused_model # Set and start the default model if didn't already exist from unused - if models.default is None: - models.default = MODEL_CLASSES[model_class](**model) - await self.session.autopilot.continue_sdk.run_step( - SetupModelStep(model_class=model_class) - ) + models.default = MODEL_CLASSES[model_class](**model) + await self.session.autopilot.continue_sdk.run_step( + SetupModelStep(model_class=model_class) + ) await self.session.autopilot.continue_sdk.start_model(models.default) diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py deleted file mode 100644 index d079475c..00000000 --- a/continuedev/src/continuedev/server/gui_protocol.py +++ /dev/null @@ -1,40 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any - - -class AbstractGUIProtocolServer(ABC): - @abstractmethod - async def handle_json(self, data: Any): - """Handle a json message""" - - @abstractmethod - def on_main_input(self, input: str): - """Called when the user inputs something""" - - @abstractmethod - def on_reverse_to_index(self, index: int): - """Called when the user requests reverse to a previous index""" - - @abstractmethod - def on_refinement_input(self, input: str, index: int): - """Called when the user inputs a refinement""" - - @abstractmethod - def on_step_user_input(self, input: str, index: int): - """Called when the user inputs a step""" - - @abstractmethod - def on_retry_at_index(self, index: int): - """Called when the user requests a retry at a previous index""" - - @abstractmethod - def on_clear_history(self): - """Called when the user requests to clear the history""" - - @abstractmethod - def on_delete_at_index(self, index: int): - """Called when the user requests to delete a step at a given index""" - - @abstractmethod - def select_context_item(self, id: str, query: str): - """Called when user selects an item from the dropdown""" |