diff options
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/src/continuedev/libs/util/edit_config.py | 4 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 21 |
2 files changed, 19 insertions, 6 deletions
diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py index f4285bc9..c77eb2e3 100644 --- a/continuedev/src/continuedev/libs/util/edit_config.py +++ b/continuedev/src/continuedev/libs/util/edit_config.py @@ -98,7 +98,9 @@ def display_llm_class(llm, new: bool = False): [ f"{k}={display_val(v)}" for k, v in llm.dict().items() - if k not in filtered_attrs and v is not None + if k not in filtered_attrs + and v is not None + and not v == llm.__fields__[k].default ] ) return f"{llm.__class__.__name__}(\n\t\t\t{args}\n\t\t)" diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 10f6974f..cc6bc911 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -10,6 +10,7 @@ from uvicorn.main import Server from ..core.main import ContextItem from ..core.models import ALL_MODEL_ROLES, MODEL_CLASSES, MODEL_MODULE_NAMES +from ..libs.llm.prompts.chat import llama2_template_messages, template_alpaca_messages from ..libs.util.create_async_task import create_async_task from ..libs.util.edit_config import ( add_config_import, @@ -323,7 +324,22 @@ class GUIProtocolServer: existing_saved_models.add(display_llm_class(val)) models.__setattr__(role, None) + # Add the requisite import to config.py + add_config_import( + f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}" + ) + if "template_messages" in model: + add_config_import( + f"from continuedev.src.continuedev.libs.llm.prompts.chat import {model['template_messages']}" + ) + # Set and start the new default model + + if "template_messages" in model: + model["template_messages"] = { + "llama2_template_messages": llama2_template_messages, + "template_alpaca_messages": template_alpaca_messages, + }[model["template_messages"]] new_model = MODEL_CLASSES[model_class](**model) models.default = new_model await self.session.autopilot.continue_sdk.start_model(models.default) @@ -343,11 +359,6 @@ class GUIProtocolServer: create_obj_node("Models", models_args), ) - # Add the requisite import to config.py - add_config_import( - f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}" - ) - # Set all roles (in-memory) to the new default model for role in ALL_MODEL_ROLES: if role != "default": |