diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-09-15 10:46:24 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-09-15 10:46:24 -0700 |
commit | c56e24d2a5f2b40702e4b495fa3f28d554eaa3ab (patch) | |
tree | e28a5f159654402d32c3244f069cdc0a82e2d1b0 | |
parent | ffc242ef2ab4ef16306afb237a781937a6c6d52c (diff) | |
download | sncontinue-c56e24d2a5f2b40702e4b495fa3f28d554eaa3ab.tar.gz sncontinue-c56e24d2a5f2b40702e4b495fa3f28d554eaa3ab.tar.bz2 sncontinue-c56e24d2a5f2b40702e4b495fa3f28d554eaa3ab.zip |
fix: :bug: fixes to templating messages
-rw-r--r-- | continuedev/src/continuedev/libs/llm/hf_tgi.py | 4 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/text_gen_interface.py | 7 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/util/count_tokens.py | 8 |
3 files changed, 13 insertions, 6 deletions
diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py index 8d16198d..a3672fe2 100644 --- a/continuedev/src/continuedev/libs/llm/hf_tgi.py +++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py @@ -6,7 +6,7 @@ from pydantic import Field from ...core.main import ChatMessage from ..llm import LLM, CompletionOptions -from .prompts.chat import code_llama_template_messages +from .prompts.chat import llama2_template_messages from .prompts.edit import simplified_edit_prompt @@ -16,7 +16,7 @@ class HuggingFaceTGI(LLM): "http://localhost:8080", description="URL of your TGI server" ) - template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages + template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages prompt_templates = { "edit": simplified_edit_prompt, diff --git a/continuedev/src/continuedev/libs/llm/text_gen_interface.py b/continuedev/src/continuedev/libs/llm/text_gen_interface.py index 28b2bfae..1ff9feb7 100644 --- a/continuedev/src/continuedev/libs/llm/text_gen_interface.py +++ b/continuedev/src/continuedev/libs/llm/text_gen_interface.py @@ -1,11 +1,12 @@ import json -from typing import Any, List, Optional +from typing import Any, Callable, Dict, List, Union import websockets from pydantic import Field from ...core.main import ChatMessage from . import LLM +from .prompts.chat import llama2_template_messages from .prompts.edit import simplest_edit_prompt @@ -40,6 +41,10 @@ class TextGenUI(LLM): "edit": simplest_edit_prompt, } + template_messages: Union[ + Callable[[List[Dict[str, str]]], str], None + ] = llama2_template_messages + class Config: arbitrary_types_allowed = True diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index 4def3198..d895a2cf 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -23,6 +23,9 @@ already_saw_import_err = False def encoding_for_model(model_name: str): global already_saw_import_err + if already_saw_import_err: + return None + try: import tiktoken from tiktoken_ext import openai_public # noqa: F401 @@ -32,9 +35,8 @@ def encoding_for_model(model_name: str): except Exception as _: return tiktoken.encoding_for_model("gpt-3.5-turbo") except Exception as e: - if not already_saw_import_err: - print("Error importing tiktoken", e) - already_saw_import_err = True + print("Error importing tiktoken", e) + already_saw_import_err = True return None |