summary | refs | log | tree | commit | diff
path: root/continuedev
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_tgi.py             | 4
-rw-r--r--  continuedev/src/continuedev/libs/llm/text_gen_interface.py | 7
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py      | 8
3 files changed, 13 insertions, 6 deletions
diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py
index 8d16198d..a3672fe2 100644
--- a/continuedev/src/continuedev/libs/llm/hf_tgi.py
+++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py
@@ -6,7 +6,7 @@ from pydantic import Field
from ...core.main import ChatMessage
from ..llm import LLM, CompletionOptions
-from .prompts.chat import code_llama_template_messages
+from .prompts.chat import llama2_template_messages
from .prompts.edit import simplified_edit_prompt
@@ -16,7 +16,7 @@ class HuggingFaceTGI(LLM):
"http://localhost:8080", description="URL of your TGI server"
)
- template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages
+ template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages
prompt_templates = {
"edit": simplified_edit_prompt,
diff --git a/continuedev/src/continuedev/libs/llm/text_gen_interface.py b/continuedev/src/continuedev/libs/llm/text_gen_interface.py
index 28b2bfae..1ff9feb7 100644
--- a/continuedev/src/continuedev/libs/llm/text_gen_interface.py
+++ b/continuedev/src/continuedev/libs/llm/text_gen_interface.py
@@ -1,11 +1,12 @@
import json
-from typing import Any, List, Optional
+from typing import Any, Callable, Dict, List, Union
import websockets
from pydantic import Field
from ...core.main import ChatMessage
from . import LLM
+from .prompts.chat import llama2_template_messages
from .prompts.edit import simplest_edit_prompt
@@ -40,6 +41,10 @@ class TextGenUI(LLM):
"edit": simplest_edit_prompt,
}
+ template_messages: Union[
+ Callable[[List[Dict[str, str]]], str], None
+ ] = llama2_template_messages
+
class Config:
arbitrary_types_allowed = True
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 4def3198..d895a2cf 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -23,6 +23,9 @@ already_saw_import_err = False
def encoding_for_model(model_name: str):
global already_saw_import_err
+ if already_saw_import_err:
+ return None
+
try:
import tiktoken
from tiktoken_ext import openai_public # noqa: F401
@@ -32,9 +35,8 @@ def encoding_for_model(model_name: str):
except Exception as _:
return tiktoken.encoding_for_model("gpt-3.5-turbo")
except Exception as e:
- if not already_saw_import_err:
- print("Error importing tiktoken", e)
- already_saw_import_err = True
+ print("Error importing tiktoken", e)
+ already_saw_import_err = True
return None