summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/server
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-09-01 18:31:33 -0700
committerNate Sesti <sestinj@gmail.com>2023-09-01 18:31:33 -0700
commit2f792f46026a6bb3c3580f2521b01ecb8c68117c (patch)
tree36cecf6d218bc166c0a8c1c78261b4feac7f01cd /continuedev/src/continuedev/server
parent5c8b28b7fddf5b214de61102c768ef44d4087870 (diff)
downloadsncontinue-2f792f46026a6bb3c3580f2521b01ecb8c68117c.tar.gz
sncontinue-2f792f46026a6bb3c3580f2521b01ecb8c68117c.tar.bz2
sncontinue-2f792f46026a6bb3c3580f2521b01ecb8c68117c.zip
feat: :sparkles: improved model dropdown
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r--continuedev/src/continuedev/server/gui.py80
-rw-r--r--continuedev/src/continuedev/server/gui_protocol.py40
2 files changed, 40 insertions, 80 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 2c4f2e4d..49541b76 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -23,7 +23,6 @@ from ..libs.util.queue import AsyncSubscriptionQueue
from ..libs.util.telemetry import posthog_logger
from ..plugins.steps.core.core import DisplayErrorStep
from ..plugins.steps.setup_model import SetupModelStep
-from .gui_protocol import AbstractGUIProtocolServer
from .session_manager import Session, session_manager
router = APIRouter(prefix="/gui", tags=["gui"])
@@ -54,7 +53,7 @@ T = TypeVar("T", bound=BaseModel)
# You should probably abstract away the websocket stuff into a separate class
-class GUIProtocolServer(AbstractGUIProtocolServer):
+class GUIProtocolServer:
websocket: WebSocket
session: Session
sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue()
@@ -118,8 +117,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.set_system_message(data["message"])
elif message_type == "set_temperature":
self.set_temperature(float(data["temperature"]))
- elif message_type == "set_model_for_role":
- self.set_model_for_role(data["role"], data["model_class"], data["model"])
+ elif message_type == "add_model_for_role":
+ self.add_model_for_role(data["role"], data["model_class"], data["model"])
+ elif message_type == "set_model_for_role_from_index":
+ self.set_model_for_role_from_index(data["role"], data["index"])
elif message_type == "save_context_group":
self.save_context_group(
data["title"], [ContextItem(**item) for item in data["context_items"]]
@@ -230,51 +231,50 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.on_error,
)
- def set_model_for_role(self, role: str, model_class: str, model: Any):
+ def set_model_for_role_from_index(self, role: str, index: int):
+ async def async_stuff():
+ models = self.session.autopilot.continue_sdk.config.models
+
+ # Set models in SDK
+ temp = models.default
+ models.default = models.unused[index]
+ models.unused[index] = temp
+ await self.session.autopilot.continue_sdk.start_model(models.default)
+
+ # Set models in config.py
+ JOINER = ", "
+ models_args = {
+ "unused": f"[{JOINER.join([display_llm_class(llm) for llm in models.unused])}]",
+ ("default" if role == "*" else role): display_llm_class(models.default),
+ }
+
+ await self.session.autopilot.set_config_attr(
+ ["models"],
+ create_obj_node("Models", models_args),
+ )
+
+ for other_role in ALL_MODEL_ROLES:
+ if other_role != "default":
+ models.__setattr__(other_role, models.default)
+
+ await self.session.autopilot.continue_sdk.update_ui()
+
+ create_async_task(async_stuff(), self.on_error)
+
+ def add_model_for_role(self, role: str, model_class: str, model: Any):
models = self.session.autopilot.continue_sdk.config.models
unused_models = models.unused
if role == "*":
async def async_stuff():
- # Clear all of the models and store them in unused_models
- # NOTE: There will be duplicates
for role in ALL_MODEL_ROLES:
- prev_model = models.__getattribute__(role)
models.__setattr__(role, None)
- if prev_model is not None:
- exists = False
- for other in unused_models:
- if (
- prev_model.__class__.__name__
- == other.__class__.__name__
- and (
- other.name is not None
- and (
- not other.name.startswith("gpt")
- or prev_model.name == other.name
- )
- )
- ):
- exists = True
- break
- if not exists:
- unused_models.append(prev_model)
-
- # Replace default with either new one or existing from unused_models
- for unused_model in unused_models:
- if model_class == unused_model.__class__.__name__ and (
- "model" not in model
- or model["model"] == unused_model.model
- and model["model"].startswith("gpt")
- ):
- models.default = unused_model
# Set and start the default model if didn't already exist from unused
- if models.default is None:
- models.default = MODEL_CLASSES[model_class](**model)
- await self.session.autopilot.continue_sdk.run_step(
- SetupModelStep(model_class=model_class)
- )
+ models.default = MODEL_CLASSES[model_class](**model)
+ await self.session.autopilot.continue_sdk.run_step(
+ SetupModelStep(model_class=model_class)
+ )
await self.session.autopilot.continue_sdk.start_model(models.default)
diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py
deleted file mode 100644
index d079475c..00000000
--- a/continuedev/src/continuedev/server/gui_protocol.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Any
-
-
-class AbstractGUIProtocolServer(ABC):
- @abstractmethod
- async def handle_json(self, data: Any):
- """Handle a json message"""
-
- @abstractmethod
- def on_main_input(self, input: str):
- """Called when the user inputs something"""
-
- @abstractmethod
- def on_reverse_to_index(self, index: int):
- """Called when the user requests reverse to a previous index"""
-
- @abstractmethod
- def on_refinement_input(self, input: str, index: int):
- """Called when the user inputs a refinement"""
-
- @abstractmethod
- def on_step_user_input(self, input: str, index: int):
- """Called when the user inputs a step"""
-
- @abstractmethod
- def on_retry_at_index(self, index: int):
- """Called when the user requests a retry at a previous index"""
-
- @abstractmethod
- def on_clear_history(self):
- """Called when the user requests to clear the history"""
-
- @abstractmethod
- def on_delete_at_index(self, index: int):
- """Called when the user requests to delete a step at a given index"""
-
- @abstractmethod
- def select_context_item(self, id: str, query: str):
- """Called when user selects an item from the dropdown"""