summary | refs | log | tree | commit | diff
path: root/continuedev/src/continuedev/libs/util
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev/src/continuedev/libs/util')
-rw-r--r-- continuedev/src/continuedev/libs/util/count_tokens.py | 27
1 file changed, 17 insertions, 10 deletions
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 154af5e1..047a47e4 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -1,3 +1,4 @@
+import json
from typing import Dict, List, Union
from ...core.main import ChatMessage
import tiktoken
@@ -5,10 +6,10 @@ import tiktoken
aliases = {}
DEFAULT_MAX_TOKENS = 2048
MAX_TOKENS_FOR_MODEL = {
- "gpt-3.5-turbo": 4096 - DEFAULT_MAX_TOKENS,
- "gpt-3.5-turbo-0613": 4096 - DEFAULT_MAX_TOKENS,
- "gpt-3.5-turbo-16k": 16384 - DEFAULT_MAX_TOKENS,
- "gpt-4": 8192 - DEFAULT_MAX_TOKENS
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k": 16384,
+ "gpt-4": 8192
}
CHAT_MODELS = {
"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
@@ -28,8 +29,9 @@ def count_tokens(model: str, text: Union[str, None]):
return len(encoding.encode(text, disallowed_special=()))
-def prune_raw_prompt_from_top(model: str, prompt: str):
- max_tokens = MAX_TOKENS_FOR_MODEL.get(model, DEFAULT_MAX_TOKENS)
+def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: int):
+ max_tokens = MAX_TOKENS_FOR_MODEL.get(
+ model, DEFAULT_MAX_TOKENS) - tokens_for_completion
encoding = encoding_for_model(model)
tokens = encoding.encode(prompt, disallowed_special=())
if len(tokens) <= max_tokens:
@@ -59,8 +61,8 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
# 3. Truncate message in the last 5
i = 0
- while total_tokens > max_tokens and len(chat_history) > 0:
- message = chat_history[0]
+ while total_tokens > max_tokens and len(chat_history) > 0 and i < len(chat_history):
+ message = chat_history[i]
total_tokens -= count_tokens(model, message.content)
total_tokens += count_tokens(model, message.summary)
message.content = message.summary
@@ -74,8 +76,12 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
return chat_history
-def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str, None] = None, with_functions: bool = False, system_message: Union[str, None] = None) -> List[Dict]:
+def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
prompt_tokens = count_tokens(model, prompt)
+ if functions is not None:
+ for function in functions:
+ prompt_tokens += count_tokens(model, json.dumps(function))
+
msgs = prune_chat_history(model,
msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + 1000 + count_tokens(model, system_message))
history = []
@@ -84,7 +90,8 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str
"role": "system",
"content": system_message
})
- history += [msg.to_dict(with_functions=with_functions) for msg in msgs]
+ history += [msg.to_dict(with_functions=functions is not None)
+ for msg in msgs]
if prompt:
history.append({
"role": "user",