diff options
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/src/continuedev/libs/util/count_tokens.py | 24 |
1 file changed, 18 insertions, 6 deletions
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index d0a6bbdc..2663aa1c 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -1,9 +1,6 @@ import json from typing import Dict, List, Union -import tiktoken -from tiktoken_ext import openai_public # noqa: F401 - from ...core.main import ChatMessage from .templating import render_templated_string @@ -27,15 +24,25 @@ DEFAULT_ARGS = { def encoding_for_model(model_name: str): try: - return tiktoken.encoding_for_model(aliases.get(model_name, model_name)) - except: - return tiktoken.encoding_for_model("gpt-3.5-turbo") + import tiktoken + from tiktoken_ext import openai_public # noqa: F401 + + try: + return tiktoken.encoding_for_model(aliases.get(model_name, model_name)) + except: + return tiktoken.encoding_for_model("gpt-3.5-turbo") + except Exception as e: + print("Error importing tiktoken", e) + return None def count_tokens(model_name: str, text: Union[str, None]): if text is None: return 0 encoding = encoding_for_model(model_name) + if encoding is None: + # Make a safe estimate given that tokens are typically ~4 characters on average + return len(text) // 2 return len(encoding.encode(text, disallowed_special=())) @@ -52,6 +59,11 @@ def prune_raw_prompt_from_top( ): max_tokens = context_length - tokens_for_completion encoding = encoding_for_model(model_name) + + if encoding is None: + desired_length_in_chars = max_tokens * 2 + return prompt[-desired_length_in_chars:] + tokens = encoding.encode(prompt, disallowed_special=()) if len(tokens) <= max_tokens: return prompt |