author    Nate Sesti <sestinj@gmail.com>  2023-08-27 17:05:27 -0700
committer Nate Sesti <sestinj@gmail.com>  2023-08-27 17:05:27 -0700
commit    7006dbb3e38a837a2580a516791874f6815ac25f (patch)
tree      c4137ab472f4262255b2211bcdda0c1786280eb5 /continuedev
parent    774e2e683808de858b8153c0341ab4dff8358722 (diff)
fix: :bug: default to counting chars if tiktoken blocked
Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py | 24 ++++++++++++++++++------
1 file changed, 18 insertions(+), 6 deletions(-)
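
The fix moves the tiktoken import inside encoding_for_model(), so environments where tiktoken is unavailable or blocked degrade to a character-count heuristic instead of failing at module import. The heuristic divides the character count by 2: since tokens average roughly 4 characters, this overestimates the token count by about 2x, which is the safe direction when budgeting a context window. A minimal standalone sketch of that heuristic (estimate_tokens_by_chars is an illustrative name, not part of the commit):

def estimate_tokens_by_chars(text: str) -> int:
    # Tokens average ~4 characters, so dividing by 2 roughly doubles the
    # true count -- erring toward overestimation keeps prompts inside the
    # context window even when the guess is wrong.
    return len(text) // 2

print(estimate_tokens_by_chars("hello world"))  # 5; tiktoken would count 2
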
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index d0a6bbdc..2663aa1c 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -1,9 +1,6 @@
 import json
 from typing import Dict, List, Union
 
-import tiktoken
-from tiktoken_ext import openai_public  # noqa: F401
-
 from ...core.main import ChatMessage
 from .templating import render_templated_string
 
@@ -27,15 +24,25 @@ DEFAULT_ARGS = {
 
 def encoding_for_model(model_name: str):
     try:
-        return tiktoken.encoding_for_model(aliases.get(model_name, model_name))
-    except:
-        return tiktoken.encoding_for_model("gpt-3.5-turbo")
+        import tiktoken
+        from tiktoken_ext import openai_public  # noqa: F401
+
+        try:
+            return tiktoken.encoding_for_model(aliases.get(model_name, model_name))
+        except:
+            return tiktoken.encoding_for_model("gpt-3.5-turbo")
+    except Exception as e:
+        print("Error importing tiktoken", e)
+        return None
 
 
 def count_tokens(model_name: str, text: Union[str, None]):
     if text is None:
         return 0
     encoding = encoding_for_model(model_name)
+    if encoding is None:
+        # Make a safe estimate given that tokens are typically ~4 characters on average
+        return len(text) // 2
     return len(encoding.encode(text, disallowed_special=()))
 
 
@@ -52,6 +59,11 @@ def prune_raw_prompt_from_top(
 ):
     max_tokens = context_length - tokens_for_completion
     encoding = encoding_for_model(model_name)
+
+    if encoding is None:
+        desired_length_in_chars = max_tokens * 2
+        return prompt[-desired_length_in_chars:]
+
     tokens = encoding.encode(prompt, disallowed_special=())
     if len(tokens) <= max_tokens:
         return prompt
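
For pruning, the same ~2-characters-per-token budget is applied from the top of the prompt, keeping the most recent text. A self-contained sketch of the fallback path (the parameter names are taken from the diff context; the function name and this standalone version are illustrative, not the committed code):

def prune_from_top_by_chars(prompt: str, context_length: int, tokens_for_completion: int) -> str:
    # Budget the window in tokens, convert to characters at the same
    # conservative ~2 chars/token rate used by count_tokens, and keep the
    # tail of the prompt (the most recent context).
    max_tokens = context_length - tokens_for_completion
    desired_length_in_chars = max_tokens * 2
    return prompt[-desired_length_in_chars:]

# A 4096-token window reserving 1024 tokens for the completion keeps at
# most (4096 - 1024) * 2 = 6144 trailing characters.
print(len(prune_from_top_by_chars("x" * 10000, 4096, 1024)))  # 6144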