Diffstat (limited to 'server/continuedev/libs/util/count_tokens.py')
-rw-r--r-- server/continuedev/libs/util/count_tokens.py | 206
1 file changed, 206 insertions(+), 0 deletions(-)
diff --git a/server/continuedev/libs/util/count_tokens.py b/server/continuedev/libs/util/count_tokens.py
new file mode 100644
index 00000000..d895a2cf
--- /dev/null
+++ b/server/continuedev/libs/util/count_tokens.py
@@ -0,0 +1,206 @@
+import json
+from typing import Dict, List, Union
+
+from ...core.main import ChatMessage
+from .templating import render_templated_string
+
+# TODO: move many of these into a specific LLM.properties() function that
+# contains max tokens, whether it's a chat model or not, and default args (not
+# all models want to be run at 0.5 temp). It would also let custom models made
+# for long contexts (like LLongMA) live here.
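+# Map model names that lack their own tiktoken encoding to a comparable OpenAI tokenizer.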
+aliases = {
+ "ggml": "gpt-3.5-turbo",
+ "claude-2": "gpt-3.5-turbo",
+}
+DEFAULT_MAX_TOKENS = 1024
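+# Default completion arguments; see the TODO above about moving these onto each LLM.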
+DEFAULT_ARGS = {
+ "max_tokens": DEFAULT_MAX_TOKENS,
+ "temperature": 0.5,
+}
+
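+# Remember a failed tiktoken import so it is only attempted (and warned about) once.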
+already_saw_import_err = False
+
+
+def encoding_for_model(model_name: str):
+ global already_saw_import_err
+ if already_saw_import_err:
+ return None
+
+ try:
+ import tiktoken
+ from tiktoken_ext import openai_public # noqa: F401
+
+ try:
+ return tiktoken.encoding_for_model(aliases.get(model_name, model_name))
+        except Exception:
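+            # Unrecognized model name: approximate with the gpt-3.5-turbo encoding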
+ return tiktoken.encoding_for_model("gpt-3.5-turbo")
+ except Exception as e:
+ print("Error importing tiktoken", e)
+ already_saw_import_err = True
+ return None
+
+
+def count_tokens(model_name: str, text: Union[str, None]):
+ if text is None:
+ return 0
+ encoding = encoding_for_model(model_name)
+ if encoding is None:
+        # Tokens are typically ~4 characters on average, so dividing by 2 gives a safe overestimate
+ return len(text) // 2
+ return len(encoding.encode(text, disallowed_special=()))
+
+
+def count_chat_message_tokens(model_name: str, chat_message: ChatMessage) -> int:
+    # A simpler, safer version of what is done here:
+ # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+ # every message follows <|start|>{role/name}\n{content}<|end|>\n
+ TOKENS_PER_MESSAGE = 4
+ return count_tokens(model_name, chat_message.content) + TOKENS_PER_MESSAGE
+
+
+def prune_raw_prompt_from_top(
+ model_name: str, context_length: int, prompt: str, tokens_for_completion: int
+):
+ max_tokens = context_length - tokens_for_completion
+ encoding = encoding_for_model(model_name)
+
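+    # Without a tokenizer, fall back to a 2-characters-per-token estimate and keep the end of the prompt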
+ if encoding is None:
+ desired_length_in_chars = max_tokens * 2
+ return prompt[-desired_length_in_chars:]
+
+ tokens = encoding.encode(prompt, disallowed_special=())
+ if len(tokens) <= max_tokens:
+ return prompt
+ else:
+ return encoding.decode(tokens[-max_tokens:])
+
+
+def prune_chat_history(
+ model_name: str,
+ chat_history: List[ChatMessage],
+ context_length: int,
+ tokens_for_completion: int,
+):
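+    # Prune in increasingly aggressive passes: summarize the oldest messages, drop
+    # them entirely, then summarize and drop all but the last message, and finally
+    # truncate the last message itself if the history still doesn't fit.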
+ total_tokens = tokens_for_completion + sum(
+ count_chat_message_tokens(model_name, message) for message in chat_history
+ )
+
+    # 1. Replace messages beyond the last 5 with their summaries
+    i = 0
+    while total_tokens > context_length and i < len(chat_history) - 5:
+        message = chat_history[i]
+ total_tokens -= count_tokens(model_name, message.content)
+ total_tokens += count_tokens(model_name, message.summary)
+ message.content = message.summary
+ i += 1
+
+    # 2. Remove entire messages until the last 5
+    while len(chat_history) > 5 and total_tokens > context_length:
+ message = chat_history.pop(0)
+ total_tokens -= count_tokens(model_name, message.content)
+
+    # 3. Summarize messages among the last 5, except the last one
+ i = 0
+ while (
+ total_tokens > context_length
+ and len(chat_history) > 0
+ and i < len(chat_history) - 1
+ ):
+ message = chat_history[i]
+ total_tokens -= count_tokens(model_name, message.content)
+ total_tokens += count_tokens(model_name, message.summary)
+ message.content = message.summary
+ i += 1
+
+    # 4. Remove entire messages among the last 5, except the last one
+ while total_tokens > context_length and len(chat_history) > 1:
+ message = chat_history.pop(0)
+ total_tokens -= count_tokens(model_name, message.content)
+
+ # 5. Truncate last message
+ if total_tokens > context_length and len(chat_history) > 0:
+ message = chat_history[0]
+ message.content = prune_raw_prompt_from_top(
+ model_name, context_length, message.content, tokens_for_completion
+ )
+ total_tokens = context_length
+
+ return chat_history
+
+
+# In case we've missed weird edge cases
+TOKEN_BUFFER_FOR_SAFETY = 100
+
+
+def compile_chat_messages(
+ model_name: str,
+ msgs: Union[List[ChatMessage], None],
+ context_length: int,
+ max_tokens: int,
+ prompt: Union[str, None] = None,
+ functions: Union[List, None] = None,
+ system_message: Union[str, None] = None,
+) -> List[Dict]:
+ """
+ The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message
+ """
+
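+    # Work on deep copies so pruning never mutates the caller's message objects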
+ msgs_copy = [msg.copy(deep=True) for msg in msgs] if msgs is not None else []
+
+ if prompt is not None:
+ prompt_msg = ChatMessage(role="user", content=prompt, summary=prompt)
+ msgs_copy += [prompt_msg]
+
+ if system_message is not None and system_message.strip() != "":
+        # NOTE: The system message takes second precedence to the user prompt, so it is
+        # placed just before the prompt here and moved back to the start after pruning
+ rendered_system_message = render_templated_string(system_message)
+ system_chat_msg = ChatMessage(
+ role="system",
+ content=rendered_system_message,
+ summary=rendered_system_message,
+ )
+ # insert at second-to-last position
+ msgs_copy.insert(-1, system_chat_msg)
+
+ # Add tokens from functions
+ function_tokens = 0
+ if functions is not None:
+ for function in functions:
+ function_tokens += count_tokens(model_name, json.dumps(function))
+
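+    # The completion, function definitions, and safety buffer must leave room for some history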
+ if max_tokens + function_tokens + TOKEN_BUFFER_FOR_SAFETY >= context_length:
+ raise ValueError(
+ f"max_tokens ({max_tokens}) is too close to context_length ({context_length}), which doesn't leave room for chat history. This would cause incoherent responses. Try increasing the context_length parameter of the model in your config file."
+ )
+
+ msgs_copy = prune_chat_history(
+ model_name,
+ msgs_copy,
+ context_length,
+ function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY,
+ )
+
+ history = [msg.to_dict(with_functions=functions is not None) for msg in msgs_copy]
+
+ # Move system message back to start
+ if (
+ system_message is not None
+ and len(history) >= 2
+ and history[-2]["role"] == "system"
+ ):
+ system_message_dict = history.pop(-2)
+ history.insert(0, system_message_dict)
+
+ return history
+
+
+def format_chat_messages(messages: List[Dict]) -> str:
+ formatted = ""
+ for msg in messages:
+ formatted += f"<{msg['role'].capitalize()}>\n{msg['content']}\n\n"
+ return formatted
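+
+
+# Illustrative usage (a minimal sketch; assumes a model with a 4096-token context
+# window and the ChatMessage fields used above):
+#
+#     msgs = [ChatMessage(role="user", content="Hello!", summary="Hello!")]
+#     history = compile_chat_messages(
+#         model_name="gpt-3.5-turbo",
+#         msgs=msgs,
+#         context_length=4096,
+#         max_tokens=DEFAULT_MAX_TOKENS,
+#         system_message="You are a helpful assistant.",
+#     )
+#     prompt = format_chat_messages(history)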