author | Nate Sesti <sestinj@gmail.com> | 2023-08-27 17:05:27 -0700
---|---|---
committer | Nate Sesti <sestinj@gmail.com> | 2023-08-27 17:05:27 -0700
commit | 7006dbb3e38a837a2580a516791874f6815ac25f (patch) |
tree | c4137ab472f4262255b2211bcdda0c1786280eb5 /continuedev |
parent | 774e2e683808de858b8153c0341ab4dff8358722 (diff) |
fix: :bug: default to counting chars if tiktoken blocked
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/src/continuedev/libs/util/count_tokens.py | 24 |
1 file changed, 18 insertions(+), 6 deletions(-)
```diff
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index d0a6bbdc..2663aa1c 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -1,9 +1,6 @@
 import json
 from typing import Dict, List, Union
 
-import tiktoken
-from tiktoken_ext import openai_public  # noqa: F401
-
 from ...core.main import ChatMessage
 from .templating import render_templated_string
 
@@ -27,15 +24,25 @@ DEFAULT_ARGS = {
 
 
 def encoding_for_model(model_name: str):
     try:
-        return tiktoken.encoding_for_model(aliases.get(model_name, model_name))
-    except:
-        return tiktoken.encoding_for_model("gpt-3.5-turbo")
+        import tiktoken
+        from tiktoken_ext import openai_public  # noqa: F401
+
+        try:
+            return tiktoken.encoding_for_model(aliases.get(model_name, model_name))
+        except:
+            return tiktoken.encoding_for_model("gpt-3.5-turbo")
+    except Exception as e:
+        print("Error importing tiktoken", e)
+        return None
 
 
 def count_tokens(model_name: str, text: Union[str, None]):
     if text is None:
         return 0
     encoding = encoding_for_model(model_name)
+    if encoding is None:
+        # Make a safe estimate, given that tokens are typically ~4 characters on average
+        return len(text) // 2
     return len(encoding.encode(text, disallowed_special=()))
 
@@ -52,6 +59,11 @@ def prune_raw_prompt_from_top(
 ):
     max_tokens = context_length - tokens_for_completion
     encoding = encoding_for_model(model_name)
+
+    if encoding is None:
+        desired_length_in_chars = max_tokens * 2
+        return prompt[-desired_length_in_chars:]
+
     tokens = encoding.encode(prompt, disallowed_special=())
     if len(tokens) <= max_tokens:
         return prompt
```
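For readers who want to try the fallback outside the repo, here is a minimal, self-contained sketch of the pattern this commit introduces: import tiktoken lazily, and fall back to a character-based estimate when the import is blocked. The `aliases` dict is a simplified stand-in for the module's real alias table; the rest mirrors the patched count_tokens.py.

```python
from typing import Union

# Assumption: a simplified stand-in for the module's real model-alias table.
aliases: dict = {}


def encoding_for_model(model_name: str):
    try:
        # Import lazily so that an environment where tiktoken is blocked
        # (e.g. the encoding files cannot be downloaded) can still load
        # this module and fall back gracefully.
        import tiktoken
        from tiktoken_ext import openai_public  # noqa: F401

        try:
            return tiktoken.encoding_for_model(aliases.get(model_name, model_name))
        except Exception:
            # Unknown model names fall back to a known encoding.
            return tiktoken.encoding_for_model("gpt-3.5-turbo")
    except Exception as e:
        print("Error importing tiktoken", e)
        return None


def count_tokens(model_name: str, text: Union[str, None]) -> int:
    if text is None:
        return 0
    encoding = encoding_for_model(model_name)
    if encoding is None:
        # Tokens average ~4 characters, so chars // 2 deliberately
        # overestimates the token count and stays on the safe side.
        return len(text) // 2
    return len(encoding.encode(text, disallowed_special=()))


if __name__ == "__main__":
    # Works whether or not tiktoken is importable in this environment.
    print(count_tokens("gpt-4", "hello world"))
```

Note the symmetry with the pruning path in the diff: since `len(text) // 2` overestimates the token count, keeping only the last `max_tokens * 2` characters of a prompt (as `prune_raw_prompt_from_top` does when `encoding is None`) is guaranteed to stay within the token budget.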