| field | value | date |
|---|---|---|
| author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-09-29 20:20:45 -0700 |
| committer | GitHub <noreply@github.com> | 2023-09-29 20:20:45 -0700 |
| commit | 0dfdd4c52a9d686af54346ade35e0bcff226c8b9 (patch) | |
| tree | d4f98c7809ddfc7ed14e3be36fe921cc418a8917 /continuedev | |
| parent | 64558321addcc80de9137cf9c9ef1bf7ed85ffa5 (diff) | |
Model config UI (#522)
* feat: :sparkles: improved model selection
* feat: :sparkles: add max_tokens option to LLM class (see the config sketch after this list)
* docs: :memo: update reference with max_tokens
* feat: :loud_sound: add context to dev data logging
* feat: :sparkles: final work on model config ui
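
The new max_tokens option lands on the LLM base class in the diff below. As a rough illustration, a user's config.py could then pin it per model; this sketch assumes the OpenAI provider class, its module path, and the ContinueConfig/Models layout, none of which are shown in this commit:

```python
# Hypothetical config.py snippet. The OpenAI class, its import path, and the
# ContinueConfig/Models structure are assumptions for illustration; only the
# max_tokens field itself is introduced by this commit.
from continuedev.src.continuedev.core.config import ContinueConfig
from continuedev.src.continuedev.core.models import Models
from continuedev.src.continuedev.libs.llm.openai import OpenAI

config = ContinueConfig(
    models=Models(
        default=OpenAI(
            model="gpt-4",
            api_key="<YOUR_API_KEY>",
            max_tokens=1024,  # new per-model cap on generated tokens
        ),
    ),
)
```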
Diffstat (limited to 'continuedev')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 21 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/__init__.py | 10 |
| -rw-r--r-- | continuedev/src/continuedev/libs/util/edit_config.py | 4 |
| -rw-r--r-- | continuedev/src/continuedev/server/gui.py | 21 |
4 files changed, 45 insertions, 11 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 9ebf288b..9f2338ff 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -2,6 +2,7 @@ import json
 import os
 import time
 import traceback
+import uuid
 from functools import cached_property
 from typing import Callable, Coroutine, Dict, List, Optional
@@ -380,11 +381,27 @@ class Autopilot(ContinueBaseModel):
         #     last_depth = self.history.timeline[i].depth
         #     i -= 1

+        # Log the context and step to dev data
+        context_used = await self.context_manager.get_selected_items()
         posthog_logger.capture_event(
             "step run", {"step_name": step.name, "params": step.dict()}
         )
+        step_id = uuid.uuid4().hex
         dev_data_logger.capture(
-            "step_run", {"step_name": step.name, "params": step.dict()}
+            "step_run",
+            {"step_name": step.name, "params": step.dict(), "step_id": step_id},
+        )
+        dev_data_logger.capture(
+            "context_used",
+            {
+                "context": list(
+                    map(
+                        lambda item: item.dict(),
+                        context_used,
+                    )
+                ),
+                "step_id": step_id,
+            },
         )

         if not is_future_step:
@@ -402,7 +419,7 @@ class Autopilot(ContinueBaseModel):
                 step=step,
                 observation=None,
                 depth=self._step_depth,
-                context_used=await self.context_manager.get_selected_items(),
+                context_used=context_used,
             )
         )
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 28f614c7..e6a90ef7 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -71,6 +71,10 @@ class LLM(ContinueBaseModel):
         ..., description="The name of the model to be used (e.g. gpt-4, codellama)"
     )

+    max_tokens: int = Field(
+        DEFAULT_MAX_TOKENS, description="The maximum number of tokens to generate."
+    )
+
     stop_tokens: Optional[List[str]] = Field(
         None, description="Tokens that will stop the completion."
     )
@@ -237,7 +241,7 @@ class LLM(ContinueBaseModel):
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
             stop=stop or self.stop_tokens,
-            max_tokens=max_tokens,
+            max_tokens=max_tokens or self.max_tokens,
             functions=functions,
         )
@@ -288,7 +292,7 @@
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
             stop=stop or self.stop_tokens,
-            max_tokens=max_tokens,
+            max_tokens=max_tokens or self.max_tokens,
             functions=functions,
         )
@@ -337,7 +341,7 @@
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
             stop=stop or self.stop_tokens,
-            max_tokens=max_tokens,
+            max_tokens=max_tokens or self.max_tokens,
             functions=functions,
         )
diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py
index f4285bc9..c77eb2e3 100644
--- a/continuedev/src/continuedev/libs/util/edit_config.py
+++ b/continuedev/src/continuedev/libs/util/edit_config.py
@@ -98,7 +98,9 @@ def display_llm_class(llm, new: bool = False):
         [
             f"{k}={display_val(v)}"
             for k, v in llm.dict().items()
-            if k not in filtered_attrs and v is not None
+            if k not in filtered_attrs
+            and v is not None
+            and not v == llm.__fields__[k].default
         ]
     )
     return f"{llm.__class__.__name__}(\n\t\t\t{args}\n\t\t)"
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 10f6974f..cc6bc911 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -10,6 +10,7 @@ from uvicorn.main import Server

 from ..core.main import ContextItem
 from ..core.models import ALL_MODEL_ROLES, MODEL_CLASSES, MODEL_MODULE_NAMES
+from ..libs.llm.prompts.chat import llama2_template_messages, template_alpaca_messages
 from ..libs.util.create_async_task import create_async_task
 from ..libs.util.edit_config import (
     add_config_import,
@@ -323,7 +324,22 @@ class GUIProtocolServer:
                         existing_saved_models.add(display_llm_class(val))
                     models.__setattr__(role, None)

+                # Add the requisite import to config.py
+                add_config_import(
+                    f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}"
+                )
+                if "template_messages" in model:
+                    add_config_import(
+                        f"from continuedev.src.continuedev.libs.llm.prompts.chat import {model['template_messages']}"
+                    )
+
                 # Set and start the new default model
+
+                if "template_messages" in model:
+                    model["template_messages"] = {
+                        "llama2_template_messages": llama2_template_messages,
+                        "template_alpaca_messages": template_alpaca_messages,
+                    }[model["template_messages"]]
                 new_model = MODEL_CLASSES[model_class](**model)
                 models.default = new_model
                 await self.session.autopilot.continue_sdk.start_model(models.default)
@@ -343,11 +359,6 @@ class GUIProtocolServer:
                     create_obj_node("Models", models_args),
                 )

-                # Add the requisite import to config.py
-                add_config_import(
-                    f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}"
-                )
-
                 # Set all roles (in-memory) to the new default model
                 for role in ALL_MODEL_ROLES:
                     if role != "default":
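
The autopilot.py hunk correlates the existing "step_run" dev-data event with the new "context_used" event by stamping both with one freshly generated step_id. Below is a minimal, self-contained sketch of that pattern; the capture function, step name, and payloads are stand-ins, not the project's dev_data_logger:

```python
import uuid
from typing import Any, Dict, List


def capture(event_name: str, payload: Dict[str, Any]) -> None:
    """Stand-in for dev_data_logger.capture(event_name, payload)."""
    print(event_name, payload)


def log_step_with_context(
    step_name: str,
    params: Dict[str, Any],
    context_items: List[Dict[str, Any]],
) -> None:
    # One id shared by both events lets them be joined later during analysis.
    step_id = uuid.uuid4().hex
    capture("step_run", {"step_name": step_name, "params": params, "step_id": step_id})
    capture("context_used", {"context": context_items, "step_id": step_id})


log_step_with_context(
    "ExampleStep",  # illustrative step name and payload
    {"user_input": "add a docstring"},
    [{"name": "main.py", "content": "..."}],
)
```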

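In gui.py, the GUI sends template_messages as a plain string (the function's name), which keeps the message from the GUI JSON-serializable; the server then resolves it to the actual callable before instantiating the model class, and separately writes the matching import into config.py via add_config_import. A sketch of the resolution step, reusing the two template functions imported in the diff; the shape of the incoming model dict is an assumption for illustration:

```python
# Resolve the GUI's string value for "template_messages" into the real function
# before the kwargs are passed to MODEL_CLASSES[model_class](**model).
from continuedev.src.continuedev.libs.llm.prompts.chat import (
    llama2_template_messages,
    template_alpaca_messages,
)

TEMPLATE_MESSAGES_BY_NAME = {
    "llama2_template_messages": llama2_template_messages,
    "template_alpaca_messages": template_alpaca_messages,
}


def resolve_template_messages(model: dict) -> dict:
    # The incoming dict shape is illustrative; only the lookup mirrors the diff.
    if "template_messages" in model:
        model["template_messages"] = TEMPLATE_MESSAGES_BY_NAME[model["template_messages"]]
    return model
```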