Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/src/continuedev/libs/util/edit_config.py   4
-rw-r--r--  continuedev/src/continuedev/server/gui.py             21
2 files changed, 19 insertions, 6 deletions
diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py
index f4285bc9..c77eb2e3 100644
--- a/continuedev/src/continuedev/libs/util/edit_config.py
+++ b/continuedev/src/continuedev/libs/util/edit_config.py
@@ -98,7 +98,9 @@ def display_llm_class(llm, new: bool = False):
         [
             f"{k}={display_val(v)}"
             for k, v in llm.dict().items()
-            if k not in filtered_attrs and v is not None
+            if k not in filtered_attrs
+            and v is not None
+            and not v == llm.__fields__[k].default
         ]
     )
     return f"{llm.__class__.__name__}(\n\t\t\t{args}\n\t\t)"
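
The edit_config.py change tightens what gets serialized back into config.py: an attribute is now written only if it passes the filtered_attrs check, is not None, and differs from the Pydantic field default, so values that merely echo a class default no longer clutter the generated config. A minimal sketch of the effect, assuming Pydantic v1 (which provides .dict() and __fields__[k].default, the APIs the diff relies on); DummyLLM and the repr()-based formatting are stand-ins for continuedev's real LLM classes and its display_val helper:

# Sketch only: DummyLLM is a hypothetical stand-in, and repr() replaces display_val.
from typing import Optional

from pydantic import BaseModel


class DummyLLM(BaseModel):
    model: str = "gpt-4"
    temperature: Optional[float] = None
    api_key: Optional[str] = None


def display_llm_class(llm, filtered_attrs=frozenset({"api_key"})):
    args = ", ".join(
        f"{k}={v!r}"
        for k, v in llm.dict().items()
        if k not in filtered_attrs
        and v is not None
        and not v == llm.__fields__[k].default  # new condition: skip field defaults
    )
    return f"{llm.__class__.__name__}({args})"


# "model" equals its class default, so it is filtered out; only the explicitly
# overridden temperature is emitted.
print(display_llm_class(DummyLLM(model="gpt-4", temperature=0.9)))
# -> DummyLLM(temperature=0.9)
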
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 10f6974f..cc6bc911 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -10,6 +10,7 @@ from uvicorn.main import Server
 
 from ..core.main import ContextItem
 from ..core.models import ALL_MODEL_ROLES, MODEL_CLASSES, MODEL_MODULE_NAMES
+from ..libs.llm.prompts.chat import llama2_template_messages, template_alpaca_messages
 from ..libs.util.create_async_task import create_async_task
 from ..libs.util.edit_config import (
     add_config_import,
@@ -323,7 +324,22 @@ class GUIProtocolServer:
                         existing_saved_models.add(display_llm_class(val))
                     models.__setattr__(role, None)
 
+                # Add the requisite import to config.py
+                add_config_import(
+                    f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}"
+                )
+                if "template_messages" in model:
+                    add_config_import(
+                        f"from continuedev.src.continuedev.libs.llm.prompts.chat import {model['template_messages']}"
+                    )
+
                 # Set and start the new default model
+
+                if "template_messages" in model:
+                    model["template_messages"] = {
+                        "llama2_template_messages": llama2_template_messages,
+                        "template_alpaca_messages": template_alpaca_messages,
+                    }[model["template_messages"]]
                 new_model = MODEL_CLASSES[model_class](**model)
                 models.default = new_model
                 await self.session.autopilot.continue_sdk.start_model(models.default)
@@ -343,11 +359,6 @@ class GUIProtocolServer:
                     create_obj_node("Models", models_args),
                 )
 
-                # Add the requisite import to config.py
-                add_config_import(
-                    f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}"
-                )
-
                 # Set all roles (in-memory) to the new default model
                 for role in ALL_MODEL_ROLES:
                     if role != "default":
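
In gui.py, the GUI now sends template_messages as a string naming one of the chat prompt templates, so set_model has to do two things before constructing the model: write the matching import lines into config.py while the value is still a string (which is why the add_config_import block moved above the construction), and then swap the string for the actual callable that the model class expects. A self-contained sketch of that name-to-callable dispatch, with simplified stand-ins for llama2_template_messages and template_alpaca_messages (the real ones live in continuedev.src.continuedev.libs.llm.prompts.chat) and a resolve_template_messages helper introduced here only for illustration:

# Sketch only: the template bodies and resolve_template_messages are illustrative.
from typing import Callable, Dict, List


def llama2_template_messages(msgs: List[Dict[str, str]]) -> str:
    # stand-in: the real template wraps the conversation in [INST] ... [/INST]
    return "\n".join(m["content"] for m in msgs)


def template_alpaca_messages(msgs: List[Dict[str, str]]) -> str:
    # stand-in: the real template uses "### Instruction:" / "### Response:" sections
    return "\n\n".join(m["content"] for m in msgs)


TEMPLATE_FUNCTIONS: Dict[str, Callable] = {
    "llama2_template_messages": llama2_template_messages,
    "template_alpaca_messages": template_alpaca_messages,
}


def resolve_template_messages(model: dict) -> dict:
    # The wire format carries the function's name; the LLM class wants the
    # callable, so replace it in place just before MODEL_CLASSES[...](**model).
    if "template_messages" in model:
        model["template_messages"] = TEMPLATE_FUNCTIONS[model["template_messages"]]
    return model


model_args = {"model": "codellama:7b", "template_messages": "llama2_template_messages"}
print(resolve_template_messages(model_args)["template_messages"].__name__)
# -> llama2_template_messages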