diff options
| -rw-r--r-- | continuedev/src/continuedev/libs/util/count_tokens.py | 24 | 
1 file changed, 18 insertions, 6 deletions
| diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index d0a6bbdc..2663aa1c 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -1,9 +1,6 @@  import json  from typing import Dict, List, Union -import tiktoken -from tiktoken_ext import openai_public  # noqa: F401 -  from ...core.main import ChatMessage  from .templating import render_templated_string @@ -27,15 +24,25 @@ DEFAULT_ARGS = {  def encoding_for_model(model_name: str):      try: -        return tiktoken.encoding_for_model(aliases.get(model_name, model_name)) -    except: -        return tiktoken.encoding_for_model("gpt-3.5-turbo") +        import tiktoken +        from tiktoken_ext import openai_public  # noqa: F401 + +        try: +            return tiktoken.encoding_for_model(aliases.get(model_name, model_name)) +        except: +            return tiktoken.encoding_for_model("gpt-3.5-turbo") +    except Exception as e: +        print("Error importing tiktoken", e) +        return None  def count_tokens(model_name: str, text: Union[str, None]):      if text is None:          return 0      encoding = encoding_for_model(model_name) +    if encoding is None: +        # Make a safe estimate given that tokens are typically ~4 characters on average +        return len(text) // 2      return len(encoding.encode(text, disallowed_special=())) @@ -52,6 +59,11 @@ def prune_raw_prompt_from_top(  ):      max_tokens = context_length - tokens_for_completion      encoding = encoding_for_model(model_name) + +    if encoding is None: +        desired_length_in_chars = max_tokens * 2 +        return prompt[-desired_length_in_chars:] +      tokens = encoding.encode(prompt, disallowed_special=())      if len(tokens) <= max_tokens:          return prompt | 
