diff options
author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-08-22 13:12:58 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-22 13:12:58 -0700 |
commit | 5eec484dc79bb56dabf9a56af0dbe6bc95227d39 (patch) | |
tree | 07edc01475cfe0ba69372f537d36aa3294680b7d /continuedev/src/continuedev/server | |
parent | b6435e1e479edb1e4f049098dc8522e944317f2a (diff) | |
download | sncontinue-5eec484dc79bb56dabf9a56af0dbe6bc95227d39.tar.gz sncontinue-5eec484dc79bb56dabf9a56af0dbe6bc95227d39.tar.bz2 sncontinue-5eec484dc79bb56dabf9a56af0dbe6bc95227d39.zip |
Config UI (#399)
* feat: :sparkles: UI for config!
* feat: :sparkles: (latent) edit models in settings
Diffstat (limited to 'continuedev/src/continuedev/server')
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 5589284a..bdcaad47 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -10,6 +10,11 @@ from uvicorn.main import Server from ..core.main import ContextItem from ..libs.util.create_async_task import create_async_task +from ..libs.util.edit_config import ( + create_float_node, + create_obj_node, + create_string_node, +) from ..libs.util.logging import logger from ..libs.util.queue import AsyncSubscriptionQueue from ..libs.util.telemetry import posthog_logger @@ -105,6 +110,12 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.load_session(data.get("session_id", None)) elif message_type == "edit_step_at_index": self.edit_step_at_index(data.get("user_input", ""), data["index"]) + elif message_type == "set_system_message": + self.set_system_message(data["message"]) + elif message_type == "set_temperature": + self.set_temperature(float(data["temperature"])) + elif message_type == "set_model_for_role": + self.set_model_for_role(data["role"], data["model_class"], data["model"]) elif message_type == "save_context_group": self.save_context_group( data["title"], [ContextItem(**item) for item in data["context_items"]] @@ -195,6 +206,38 @@ class GUIProtocolServer(AbstractGUIProtocolServer): posthog_logger.capture_event("load_session", {"session_id": session_id}) + def set_system_message(self, message: str): + self.session.autopilot.continue_sdk.config.system_message = message + self.session.autopilot.continue_sdk.models.set_system_message(message) + + create_async_task( + self.session.autopilot.set_config_attr( + ["system_message"], create_string_node(message) + ), + self.on_error, + ) + + def set_temperature(self, temperature: float): + self.session.autopilot.continue_sdk.config.temperature = temperature + create_async_task( + self.session.autopilot.set_config_attr( + ["temperature"], create_float_node(temperature) + ), + self.on_error, + ) + + def set_model_for_role(self, role: str, model_class: str, model: Any): + prev_model = self.session.autopilot.continue_sdk.models.__getattr__(role) + if prev_model is not None: + prev_model.update(model) + self.session.autopilot.continue_sdk.models.__setattr__(role, model) + create_async_task( + self.session.autopilot.set_config_attr( + ["models", role], create_obj_node(model_class, {**model}) + ), + self.on_error, + ) + def save_context_group(self, title: str, context_items: List[ContextItem]): create_async_task( self.session.autopilot.save_context_group(title, context_items), |