diff options
author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-09-29 20:20:45 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-29 20:20:45 -0700 |
commit | 0dfdd4c52a9d686af54346ade35e0bcff226c8b9 (patch) | |
tree | d4f98c7809ddfc7ed14e3be36fe921cc418a8917 /continuedev | |
parent | 64558321addcc80de9137cf9c9ef1bf7ed85ffa5 (diff) | |
download | sncontinue-0dfdd4c52a9d686af54346ade35e0bcff226c8b9.tar.gz sncontinue-0dfdd4c52a9d686af54346ade35e0bcff226c8b9.tar.bz2 sncontinue-0dfdd4c52a9d686af54346ade35e0bcff226c8b9.zip |
Model config UI (#522)
* feat: :sparkles: improved model selection
* feat: :sparkles: add max_tokens option to LLM class
* docs: :memo: update reference with max_tokens
* feat: :loud_sound: add context to dev data loggign
* feat: :sparkles: final work on model config ui
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 21 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/__init__.py | 10 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/util/edit_config.py | 4 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 21 |
4 files changed, 45 insertions, 11 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 9ebf288b..9f2338ff 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -2,6 +2,7 @@ import json import os import time import traceback +import uuid from functools import cached_property from typing import Callable, Coroutine, Dict, List, Optional @@ -380,11 +381,27 @@ class Autopilot(ContinueBaseModel): # last_depth = self.history.timeline[i].depth # i -= 1 + # Log the context and step to dev data + context_used = await self.context_manager.get_selected_items() posthog_logger.capture_event( "step run", {"step_name": step.name, "params": step.dict()} ) + step_id = uuid.uuid4().hex dev_data_logger.capture( - "step_run", {"step_name": step.name, "params": step.dict()} + "step_run", + {"step_name": step.name, "params": step.dict(), "step_id": step_id}, + ) + dev_data_logger.capture( + "context_used", + { + "context": list( + map( + lambda item: item.dict(), + context_used, + ) + ), + "step_id": step_id, + }, ) if not is_future_step: @@ -402,7 +419,7 @@ class Autopilot(ContinueBaseModel): step=step, observation=None, depth=self._step_depth, - context_used=await self.context_manager.get_selected_items(), + context_used=context_used, ) ) diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 28f614c7..e6a90ef7 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -71,6 +71,10 @@ class LLM(ContinueBaseModel): ..., description="The name of the model to be used (e.g. gpt-4, codellama)" ) + max_tokens: int = Field( + DEFAULT_MAX_TOKENS, description="The maximum number of tokens to generate." + ) + stop_tokens: Optional[List[str]] = Field( None, description="Tokens that will stop the completion." ) @@ -237,7 +241,7 @@ class LLM(ContinueBaseModel): presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, stop=stop or self.stop_tokens, - max_tokens=max_tokens, + max_tokens=max_tokens or self.max_tokens, functions=functions, ) @@ -288,7 +292,7 @@ class LLM(ContinueBaseModel): presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, stop=stop or self.stop_tokens, - max_tokens=max_tokens, + max_tokens=max_tokens or self.max_tokens, functions=functions, ) @@ -337,7 +341,7 @@ class LLM(ContinueBaseModel): presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, stop=stop or self.stop_tokens, - max_tokens=max_tokens, + max_tokens=max_tokens or self.max_tokens, functions=functions, ) diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py index f4285bc9..c77eb2e3 100644 --- a/continuedev/src/continuedev/libs/util/edit_config.py +++ b/continuedev/src/continuedev/libs/util/edit_config.py @@ -98,7 +98,9 @@ def display_llm_class(llm, new: bool = False): [ f"{k}={display_val(v)}" for k, v in llm.dict().items() - if k not in filtered_attrs and v is not None + if k not in filtered_attrs + and v is not None + and not v == llm.__fields__[k].default ] ) return f"{llm.__class__.__name__}(\n\t\t\t{args}\n\t\t)" diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 10f6974f..cc6bc911 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -10,6 +10,7 @@ from uvicorn.main import Server from ..core.main import ContextItem from ..core.models import ALL_MODEL_ROLES, MODEL_CLASSES, MODEL_MODULE_NAMES +from ..libs.llm.prompts.chat import llama2_template_messages, template_alpaca_messages from ..libs.util.create_async_task import create_async_task from ..libs.util.edit_config import ( add_config_import, @@ -323,7 +324,22 @@ class GUIProtocolServer: existing_saved_models.add(display_llm_class(val)) models.__setattr__(role, None) + # Add the requisite import to config.py + add_config_import( + f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}" + ) + if "template_messages" in model: + add_config_import( + f"from continuedev.src.continuedev.libs.llm.prompts.chat import {model['template_messages']}" + ) + # Set and start the new default model + + if "template_messages" in model: + model["template_messages"] = { + "llama2_template_messages": llama2_template_messages, + "template_alpaca_messages": template_alpaca_messages, + }[model["template_messages"]] new_model = MODEL_CLASSES[model_class](**model) models.default = new_model await self.session.autopilot.continue_sdk.start_model(models.default) @@ -343,11 +359,6 @@ class GUIProtocolServer: create_obj_node("Models", models_args), ) - # Add the requisite import to config.py - add_config_import( - f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}" - ) - # Set all roles (in-memory) to the new default model for role in ALL_MODEL_ROLES: if role != "default": |