diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-08-26 10:10:06 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-08-26 10:10:06 -0700 |
commit | 044b7caa6b26a5d78ae52faa0ae675abc8c4e161 (patch) | |
tree | 5727da03069cc6bf83b9dd385003f615124e6fe8 /continuedev/src/continuedev/server | |
parent | 631d141dbd26edb0de3e0e3ed194dbfd3641059f (diff) | |
download | sncontinue-044b7caa6b26a5d78ae52faa0ae675abc8c4e161.tar.gz sncontinue-044b7caa6b26a5d78ae52faa0ae675abc8c4e161.tar.bz2 sncontinue-044b7caa6b26a5d78ae52faa0ae675abc8c4e161.zip |
feat: :sparkles: select model from dropdown
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 83 |
1 files changed, 73 insertions, 10 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index bdcaad47..55a5f3b4 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -9,16 +9,20 @@ from starlette.websockets import WebSocketDisconnect, WebSocketState from uvicorn.main import Server from ..core.main import ContextItem +from ..core.models import ALL_MODEL_ROLES, MODEL_CLASSES, MODEL_MODULE_NAMES from ..libs.util.create_async_task import create_async_task from ..libs.util.edit_config import ( + add_config_import, create_float_node, create_obj_node, create_string_node, + display_llm_class, ) from ..libs.util.logging import logger from ..libs.util.queue import AsyncSubscriptionQueue from ..libs.util.telemetry import posthog_logger from ..plugins.steps.core.core import DisplayErrorStep +from ..plugins.steps.setup_model import SetupModelStep from .gui_protocol import AbstractGUIProtocolServer from .session_manager import Session, session_manager @@ -227,16 +231,75 @@ class GUIProtocolServer(AbstractGUIProtocolServer): ) def set_model_for_role(self, role: str, model_class: str, model: Any): - prev_model = self.session.autopilot.continue_sdk.models.__getattr__(role) - if prev_model is not None: - prev_model.update(model) - self.session.autopilot.continue_sdk.models.__setattr__(role, model) - create_async_task( - self.session.autopilot.set_config_attr( - ["models", role], create_obj_node(model_class, {**model}) - ), - self.on_error, - ) + models = self.session.autopilot.continue_sdk.config.models + unused_models = models.unused + if role == "*": + + async def async_stuff(): + # Clear all of the models and store them in unused_models + # NOTE: There will be duplicates + for role in ALL_MODEL_ROLES: + prev_model = models.__getattribute__(role) + models.__setattr__(role, None) + if prev_model is not None: + exists = False + for other in unused_models: + if display_llm_class(prev_model) == display_llm_class( + other + ): + exists = True + break + if not exists: + unused_models.append(prev_model) + + # Replace default with either new one or existing from unused_models + for unused_model in unused_models: + if model_class == unused_model.__class__.__name__ and ( + "model" not in model or model["model"] == unused_model.model + ): + models.default = unused_model + + # Set and start the default model if didn't already exist from unused + if models.default is None: + models.default = MODEL_CLASSES[model_class](**model) + await self.session.autopilot.continue_sdk.start_model( + models.default + ) + await self.session.autopilot.continue_sdk.run_step( + SetupModelStep(model_class=model_class) + ) + + models_args = {} + + for role in ALL_MODEL_ROLES: + val = models.__getattribute__(role) + if val is None: + continue # no pun intended + + models_args[role] = display_llm_class(val) + + JOINER = ", " + models_args[ + "unused" + ] = f"[{JOINER.join([display_llm_class(llm) for llm in unused_models])}]" + + await self.session.autopilot.set_config_attr( + ["models"], + create_obj_node("Models", models_args), + ) + + add_config_import( + f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}" + ) + + for role in ALL_MODEL_ROLES: + if role != "default": + models.__setattr__(role, models.default) + + create_async_task(async_stuff(), self.on_error) + else: + # TODO + pass def save_context_group(self, title: str, context_items: List[ContextItem]): create_async_task( |