diff options
Diffstat (limited to 'continuedev/src/continuedev/libs/util')
| -rw-r--r-- | continuedev/src/continuedev/libs/util/count_tokens.py | 27 |
1 files changed, 17 insertions, 10 deletions
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index 154af5e1..047a47e4 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -1,3 +1,4 @@ +import json from typing import Dict, List, Union from ...core.main import ChatMessage import tiktoken @@ -5,10 +6,10 @@ import tiktoken aliases = {} DEFAULT_MAX_TOKENS = 2048 MAX_TOKENS_FOR_MODEL = { - "gpt-3.5-turbo": 4096 - DEFAULT_MAX_TOKENS, - "gpt-3.5-turbo-0613": 4096 - DEFAULT_MAX_TOKENS, - "gpt-3.5-turbo-16k": 16384 - DEFAULT_MAX_TOKENS, - "gpt-4": 8192 - DEFAULT_MAX_TOKENS + "gpt-3.5-turbo": 4096, + "gpt-3.5-turbo-0613": 4096, + "gpt-3.5-turbo-16k": 16384, + "gpt-4": 8192 } CHAT_MODELS = { "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613" @@ -28,8 +29,9 @@ def count_tokens(model: str, text: Union[str, None]): return len(encoding.encode(text, disallowed_special=())) -def prune_raw_prompt_from_top(model: str, prompt: str): - max_tokens = MAX_TOKENS_FOR_MODEL.get(model, DEFAULT_MAX_TOKENS) +def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: int): + max_tokens = MAX_TOKENS_FOR_MODEL.get( + model, DEFAULT_MAX_TOKENS) - tokens_for_completion encoding = encoding_for_model(model) tokens = encoding.encode(prompt, disallowed_special=()) if len(tokens) <= max_tokens: @@ -59,8 +61,8 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: # 3. Truncate message in the last 5 i = 0 - while total_tokens > max_tokens and len(chat_history) > 0: - message = chat_history[0] + while total_tokens > max_tokens and len(chat_history) > 0 and i < len(chat_history): + message = chat_history[i] total_tokens -= count_tokens(model, message.content) total_tokens += count_tokens(model, message.summary) message.content = message.summary @@ -74,8 +76,12 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: return chat_history -def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str, None] = None, with_functions: bool = False, system_message: Union[str, None] = None) -> List[Dict]: +def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]: prompt_tokens = count_tokens(model, prompt) + if functions is not None: + for function in functions: + prompt_tokens += count_tokens(model, json.dumps(function)) + msgs = prune_chat_history(model, msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + 1000 + count_tokens(model, system_message)) history = [] @@ -84,7 +90,8 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str "role": "system", "content": system_message }) - history += [msg.to_dict(with_functions=with_functions) for msg in msgs] + history += [msg.to_dict(with_functions=functions is not None) + for msg in msgs] if prompt: history.append({ "role": "user", |
