summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/server
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-08-26 10:10:06 -0700
committerNate Sesti <sestinj@gmail.com>2023-08-26 10:10:06 -0700
commit044b7caa6b26a5d78ae52faa0ae675abc8c4e161 (patch)
tree5727da03069cc6bf83b9dd385003f615124e6fe8 /continuedev/src/continuedev/server
parent631d141dbd26edb0de3e0e3ed194dbfd3641059f (diff)
downloadsncontinue-044b7caa6b26a5d78ae52faa0ae675abc8c4e161.tar.gz
sncontinue-044b7caa6b26a5d78ae52faa0ae675abc8c4e161.tar.bz2
sncontinue-044b7caa6b26a5d78ae52faa0ae675abc8c4e161.zip
feat: :sparkles: select model from dropdown
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r--continuedev/src/continuedev/server/gui.py83
1 files changed, 73 insertions, 10 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index bdcaad47..55a5f3b4 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -9,16 +9,20 @@ from starlette.websockets import WebSocketDisconnect, WebSocketState
from uvicorn.main import Server
from ..core.main import ContextItem
+from ..core.models import ALL_MODEL_ROLES, MODEL_CLASSES, MODEL_MODULE_NAMES
from ..libs.util.create_async_task import create_async_task
from ..libs.util.edit_config import (
+ add_config_import,
create_float_node,
create_obj_node,
create_string_node,
+ display_llm_class,
)
from ..libs.util.logging import logger
from ..libs.util.queue import AsyncSubscriptionQueue
from ..libs.util.telemetry import posthog_logger
from ..plugins.steps.core.core import DisplayErrorStep
+from ..plugins.steps.setup_model import SetupModelStep
from .gui_protocol import AbstractGUIProtocolServer
from .session_manager import Session, session_manager
@@ -227,16 +231,75 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
)
def set_model_for_role(self, role: str, model_class: str, model: Any):
- prev_model = self.session.autopilot.continue_sdk.models.__getattr__(role)
- if prev_model is not None:
- prev_model.update(model)
- self.session.autopilot.continue_sdk.models.__setattr__(role, model)
- create_async_task(
- self.session.autopilot.set_config_attr(
- ["models", role], create_obj_node(model_class, {**model})
- ),
- self.on_error,
- )
+ models = self.session.autopilot.continue_sdk.config.models
+ unused_models = models.unused
+ if role == "*":
+
+ async def async_stuff():
+ # Clear all of the models and store them in unused_models
+ # NOTE: There will be duplicates
+ for role in ALL_MODEL_ROLES:
+ prev_model = models.__getattribute__(role)
+ models.__setattr__(role, None)
+ if prev_model is not None:
+ exists = False
+ for other in unused_models:
+ if display_llm_class(prev_model) == display_llm_class(
+ other
+ ):
+ exists = True
+ break
+ if not exists:
+ unused_models.append(prev_model)
+
+ # Replace default with either new one or existing from unused_models
+ for unused_model in unused_models:
+ if model_class == unused_model.__class__.__name__ and (
+ "model" not in model or model["model"] == unused_model.model
+ ):
+ models.default = unused_model
+
+ # Set and start the default model if didn't already exist from unused
+ if models.default is None:
+ models.default = MODEL_CLASSES[model_class](**model)
+ await self.session.autopilot.continue_sdk.start_model(
+ models.default
+ )
+ await self.session.autopilot.continue_sdk.run_step(
+ SetupModelStep(model_class=model_class)
+ )
+
+ models_args = {}
+
+ for role in ALL_MODEL_ROLES:
+ val = models.__getattribute__(role)
+ if val is None:
+ continue # no pun intended
+
+ models_args[role] = display_llm_class(val)
+
+ JOINER = ", "
+ models_args[
+ "unused"
+ ] = f"[{JOINER.join([display_llm_class(llm) for llm in unused_models])}]"
+
+ await self.session.autopilot.set_config_attr(
+ ["models"],
+ create_obj_node("Models", models_args),
+ )
+
+ add_config_import(
+ f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}"
+ )
+
+ for role in ALL_MODEL_ROLES:
+ if role != "default":
+ models.__setattr__(role, models.default)
+
+ create_async_task(async_stuff(), self.on_error)
+ else:
+ # TODO
+ pass
def save_context_group(self, title: str, context_items: List[ContextItem]):
create_async_task(