author     Nate Sesti <33237525+sestinj@users.noreply.github.com>  2023-09-29 20:20:45 -0700
committer  GitHub <noreply@github.com>                             2023-09-29 20:20:45 -0700
commit     0dfdd4c52a9d686af54346ade35e0bcff226c8b9 (patch)
tree       d4f98c7809ddfc7ed14e3be36fe921cc418a8917 /continuedev
parent     64558321addcc80de9137cf9c9ef1bf7ed85ffa5 (diff)
Model config UI (#522)
* feat: :sparkles: improved model selection
* feat: :sparkles: add max_tokens option to LLM class
* docs: :memo: update reference with max_tokens
* feat: :loud_sound: add context to dev data logging
* feat: :sparkles: final work on model config UI
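For context, here is what the new max_tokens option looks like from the user's side: a minimal config.py sketch. The model-class import path follows the pattern this commit writes via add_config_import; the ContinueConfig, Models, and OpenAI names and the api_key parameter are assumptions for illustration, not taken from this diff, and the fragment requires the continuedev package of this era on the path.

    # Hypothetical config.py fragment; only max_tokens is introduced by this commit.
    from continuedev.src.continuedev.core.config import ContinueConfig
    from continuedev.src.continuedev.core.models import Models
    from continuedev.src.continuedev.libs.llm.openai import OpenAI

    config = ContinueConfig(
        models=Models(
            # max_tokens is the new per-model generation cap on the LLM base class
            default=OpenAI(model="gpt-4", api_key="sk-...", max_tokens=1024),
        ),
    )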
Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/src/continuedev/core/autopilot.py         | 21
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py      | 10
-rw-r--r--  continuedev/src/continuedev/libs/util/edit_config.py  |  4
-rw-r--r--  continuedev/src/continuedev/server/gui.py             | 21
4 files changed, 45 insertions(+), 11 deletions(-)
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 9ebf288b..9f2338ff 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -2,6 +2,7 @@ import json
 import os
 import time
 import traceback
+import uuid
 
 from functools import cached_property
 from typing import Callable, Coroutine, Dict, List, Optional
@@ -380,11 +381,27 @@ class Autopilot(ContinueBaseModel):
         # last_depth = self.history.timeline[i].depth
         # i -= 1
 
+        # Log the context and step to dev data
+        context_used = await self.context_manager.get_selected_items()
         posthog_logger.capture_event(
             "step run", {"step_name": step.name, "params": step.dict()}
         )
+        step_id = uuid.uuid4().hex
         dev_data_logger.capture(
-            "step_run", {"step_name": step.name, "params": step.dict()}
+            "step_run",
+            {"step_name": step.name, "params": step.dict(), "step_id": step_id},
+        )
+        dev_data_logger.capture(
+            "context_used",
+            {
+                "context": list(
+                    map(
+                        lambda item: item.dict(),
+                        context_used,
+                    )
+                ),
+                "step_id": step_id,
+            },
         )
 
         if not is_future_step:
@@ -402,7 +419,7 @@ class Autopilot(ContinueBaseModel):
                 step=step,
                 observation=None,
                 depth=self._step_depth,
-                context_used=await self.context_manager.get_selected_items(),
+                context_used=context_used,
             )
         )
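The change above correlates two dev-data events through a shared step_id instead of leaving them unlinked. A minimal self-contained sketch of the pattern; DevDataLogger here is a hypothetical stand-in for the real dev_data_logger, which just records (event, payload) pairs:

    import uuid
    from typing import Any, Dict, List, Tuple

    class DevDataLogger:
        """Stand-in logger: capture() appends (event, payload) to a list."""

        def __init__(self) -> None:
            self.events: List[Tuple[str, Dict[str, Any]]] = []

        def capture(self, event: str, payload: Dict[str, Any]) -> None:
            self.events.append((event, payload))

    dev_data_logger = DevDataLogger()

    # One id ties the step-run event to the context items it used.
    step_id = uuid.uuid4().hex
    dev_data_logger.capture(
        "step_run",
        {"step_name": "ExampleStep", "params": {}, "step_id": step_id},
    )
    dev_data_logger.capture(
        "context_used",
        {"context": [{"name": "main.py", "content": "..."}], "step_id": step_id},
    )

    # Downstream, the two event streams can be joined on step_id.
    assert dev_data_logger.events[0][1]["step_id"] == dev_data_logger.events[1][1]["step_id"]

Caching context_used once also guarantees the logged items match what the step actually received: the hunk at +419 reuses the same list rather than calling get_selected_items() a second time.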
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 28f614c7..e6a90ef7 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -71,6 +71,10 @@ class LLM(ContinueBaseModel):
         ..., description="The name of the model to be used (e.g. gpt-4, codellama)"
     )
 
+    max_tokens: int = Field(
+        DEFAULT_MAX_TOKENS, description="The maximum number of tokens to generate."
+    )
+
     stop_tokens: Optional[List[str]] = Field(
         None, description="Tokens that will stop the completion."
     )
@@ -237,7 +241,7 @@ class LLM(ContinueBaseModel):
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
             stop=stop or self.stop_tokens,
-            max_tokens=max_tokens,
+            max_tokens=max_tokens or self.max_tokens,
             functions=functions,
         )
 
@@ -288,7 +292,7 @@ class LLM(ContinueBaseModel):
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
             stop=stop or self.stop_tokens,
-            max_tokens=max_tokens,
+            max_tokens=max_tokens or self.max_tokens,
             functions=functions,
         )
 
@@ -337,7 +341,7 @@ class LLM(ContinueBaseModel):
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
             stop=stop or self.stop_tokens,
-            max_tokens=max_tokens,
+            max_tokens=max_tokens or self.max_tokens,
             functions=functions,
         )
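All three call sites apply the same max_tokens or self.max_tokens fallback: an explicit per-call value wins, otherwise the new instance-level field applies. A small sketch of the behavior; DEFAULT_MAX_TOKENS is assumed to be 1024 here, and resolve_max_tokens is a hypothetical helper isolating the expression:

    from typing import Optional
    from pydantic import BaseModel, Field

    DEFAULT_MAX_TOKENS = 1024  # assumed value; the real constant lives elsewhere in the package

    class LLM(BaseModel):
        model: str
        max_tokens: int = Field(
            DEFAULT_MAX_TOKENS, description="The maximum number of tokens to generate."
        )

        def resolve_max_tokens(self, max_tokens: Optional[int] = None) -> int:
            # Mirrors the `max_tokens or self.max_tokens` expression in the diff.
            return max_tokens or self.max_tokens

    llm = LLM(model="gpt-4")
    assert llm.resolve_max_tokens() == 1024    # falls back to the field default
    assert llm.resolve_max_tokens(256) == 256  # explicit argument wins

One quirk of `or`: a caller passing max_tokens=0 would fall through to the default, since 0 is falsy; `max_tokens if max_tokens is not None else self.max_tokens` would distinguish the two, though a zero-token limit is not meaningful in practice.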
diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py
index f4285bc9..c77eb2e3 100644
--- a/continuedev/src/continuedev/libs/util/edit_config.py
+++ b/continuedev/src/continuedev/libs/util/edit_config.py
@@ -98,7 +98,9 @@ def display_llm_class(llm, new: bool = False):
         [
             f"{k}={display_val(v)}"
             for k, v in llm.dict().items()
-            if k not in filtered_attrs and v is not None
+            if k not in filtered_attrs
+            and v is not None
+            and not v == llm.__fields__[k].default
         ]
     )
    return f"{llm.__class__.__name__}(\n\t\t\t{args}\n\t\t)"
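The extra condition keeps the generated config.py minimal: a field is only serialized when it differs from its pydantic default. A standalone sketch with a hypothetical model class; the real filtered_attrs set and display_val helper live in edit_config.py and are simplified here:

    from typing import Optional
    from pydantic import BaseModel

    class ExampleLLM(BaseModel):
        model: str
        max_tokens: int = 1024
        api_key: Optional[str] = None

    llm = ExampleLLM(model="gpt-4", max_tokens=2048)

    filtered_attrs = {"title"}  # placeholder for the real exclusion set
    args = ", ".join(
        f"{k}={v!r}"
        for k, v in llm.dict().items()
        if k not in filtered_attrs
        and v is not None
        and not v == llm.__fields__[k].default
    )
    print(f"{llm.__class__.__name__}({args})")
    # -> ExampleLLM(model='gpt-4', max_tokens=2048)
    # max_tokens would be omitted if it were still 1024; api_key is omitted as None.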
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 10f6974f..cc6bc911 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -10,6 +10,7 @@ from uvicorn.main import Server
 
 from ..core.main import ContextItem
 from ..core.models import ALL_MODEL_ROLES, MODEL_CLASSES, MODEL_MODULE_NAMES
+from ..libs.llm.prompts.chat import llama2_template_messages, template_alpaca_messages
 from ..libs.util.create_async_task import create_async_task
 from ..libs.util.edit_config import (
     add_config_import,
@@ -323,7 +324,22 @@ class GUIProtocolServer:
                 existing_saved_models.add(display_llm_class(val))
                 models.__setattr__(role, None)
 
+        # Add the requisite import to config.py
+        add_config_import(
+            f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}"
+        )
+        if "template_messages" in model:
+            add_config_import(
+                f"from continuedev.src.continuedev.libs.llm.prompts.chat import {model['template_messages']}"
+            )
+
         # Set and start the new default model
+
+        if "template_messages" in model:
+            model["template_messages"] = {
+                "llama2_template_messages": llama2_template_messages,
+                "template_alpaca_messages": template_alpaca_messages,
+            }[model["template_messages"]]
         new_model = MODEL_CLASSES[model_class](**model)
         models.default = new_model
         await self.session.autopilot.continue_sdk.start_model(models.default)
@@ -343,11 +359,6 @@ class GUIProtocolServer:
create_obj_node("Models", models_args),
)
- # Add the requisite import to config.py
- add_config_import(
- f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}"
- )
-
# Set all roles (in-memory) to the new default model
for role in ALL_MODEL_ROLES:
if role != "default":
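Because the GUI serializes the model config as JSON, template_messages arrives as a plain string; the server maps it back to a callable before instantiating the model class, and the moved add_config_import block writes the matching import so the generated config.py resolves the same name. A standalone sketch of that lookup; the two template functions are simplified stand-ins for the real llama2_template_messages and template_alpaca_messages:

    from typing import Callable, Dict, List

    # Simplified stand-ins for the prompt templates imported in gui.py.
    def llama2_template_messages(msgs: List[Dict[str, str]]) -> str:
        return "".join(f"[INST] {m['content']} [/INST]" for m in msgs)

    def template_alpaca_messages(msgs: List[Dict[str, str]]) -> str:
        return "\n\n".join(f"### Instruction:\n{m['content']}" for m in msgs)

    TEMPLATE_FUNCTIONS: Dict[str, Callable] = {
        "llama2_template_messages": llama2_template_messages,
        "template_alpaca_messages": template_alpaca_messages,
    }

    # As received over the wire: the template is just a name.
    model = {"model": "codellama", "template_messages": "llama2_template_messages"}

    if "template_messages" in model:
        # Swap the serialized name for the function it refers to, as the
        # hunk at +324 does before calling MODEL_CLASSES[model_class](**model).
        model["template_messages"] = TEMPLATE_FUNCTIONS[model["template_messages"]]

    assert callable(model["template_messages"])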