Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py  |  2
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py | 75
-rw-r--r--  continuedev/src/continuedev/steps/chat.py             | 19
3 files changed, 63 insertions, 33 deletions
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 91b5842a..18e0e6f4 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -58,7 +58,7 @@ class ProxyServer(LLM):
async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
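
The fix above compiles messages with the merged `args["model"]` rather than always `self.default_model`, so a per-call `model=` override passed through `**kwargs` is the one used for prompt compilation and token budgeting. A minimal sketch of the dict-merge behavior (the values here are illustrative, not the repo's actual defaults):

```python
# Later keys win in a dict merge, so per-call kwargs override instance defaults.
default_args = {"model": "gpt-3.5-turbo", "max_tokens": 1024}

def resolve_args(**kwargs):
    return {**default_args, **kwargs}

print(resolve_args()["model"])               # gpt-3.5-turbo
print(resolve_args(model="gpt-4")["model"])  # gpt-4
```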
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 13de7990..c81d8aa4 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -46,9 +46,17 @@ def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: in
return encoding.decode(tokens[-max_tokens:])
+def count_chat_message_tokens(model: str, chat_message: ChatMessage) -> int:
+ # Doing simpler, safer version of what is here:
+ # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+ # every message follows <|start|>{role/name}\n{content}<|end|>\n
+ TOKENS_PER_MESSAGE = 4
+ return count_tokens(model, chat_message.content) + TOKENS_PER_MESSAGE
+
+
def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
total_tokens = tokens_for_completion + \
- sum(count_tokens(model, message.content)
+ sum(count_chat_message_tokens(model, message)
for message in chat_history)
# 1. Replace beyond last 5 messages with summary
@@ -74,37 +82,58 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
message.content = message.summary
i += 1
- # 4. Remove entire messages in the last 5
- while total_tokens > max_tokens and len(chat_history) > 0:
+ # 4. Remove entire messages in the last 5, except last 1
+ while total_tokens > max_tokens and len(chat_history) > 1:
message = chat_history.pop(0)
total_tokens -= count_tokens(model, message.content)
+ # 5. Truncate last message
+ if total_tokens > max_tokens and len(chat_history) > 0:
+ message = chat_history[0]
+ message.content = prune_raw_prompt_from_top(
+ model, message.content, tokens_for_completion)
+ total_tokens = max_tokens
+
return chat_history
+# In case we've missed weird edge cases
+TOKEN_BUFFER_FOR_SAFETY = 100
+
+
def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
- prompt_tokens = count_tokens(model, prompt)
+ """
+ The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message
+ """
+ if prompt is not None:
+ prompt_msg = ChatMessage(role="user", content=prompt, summary=prompt)
+ msgs += [prompt_msg]
+
+ if system_message is not None:
+ # NOTE: System message takes second precedence to user prompt, so it is placed just before
+ # but move back to start after processing
+ rendered_system_message = render_system_message(system_message)
+ system_chat_msg = ChatMessage(
+ role="system", content=rendered_system_message, summary=rendered_system_message)
+ # insert at second-to-last position
+ msgs.insert(-1, system_chat_msg)
+
+ # Add tokens from functions
+ function_tokens = 0
if functions is not None:
for function in functions:
- prompt_tokens += count_tokens(model, json.dumps(function))
-
- rendered_system_message = render_system_message(system_message)
-
- msgs = prune_chat_history(model,
- msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, rendered_system_message))
- history = []
- if system_message:
- history.append({
- "role": "system",
- "content": rendered_system_message
- })
- history += [msg.to_dict(with_functions=functions is not None)
- for msg in msgs]
- if prompt:
- history.append({
- "role": "user",
- "content": prompt
- })
+ function_tokens += count_tokens(model, json.dumps(function))
+
+ msgs = prune_chat_history(
+ model, msgs, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
+
+ history = [msg.to_dict(with_functions=functions is not None)
+ for msg in msgs]
+
+ # Move system message back to start
+ if system_message is not None and len(history) >= 2 and history[-2]["role"] == "system":
+ system_message_dict = history.pop(-2)
+ history.insert(0, system_message_dict)
return history
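
The new `count_chat_message_tokens` follows the OpenAI cookbook's accounting: each chat message carries roughly four tokens of `<|start|>`/`<|end|>` framing on top of its content. A rough standalone sketch of that estimate using tiktoken (the repo's own `count_tokens` helper and model table may differ in details):

```python
import tiktoken

TOKENS_PER_MESSAGE = 4  # framing: <|start|>{role/name}\n{content}<|end|>\n

def count_tokens(model: str, text: str) -> int:
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Fall back to the encoding used by the gpt-3.5/gpt-4 family.
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))

def count_chat_message_tokens(model: str, content: str) -> int:
    # Content tokens plus a flat allowance for the chat-message wrapper.
    return count_tokens(model, content) + TOKENS_PER_MESSAGE

print(count_chat_message_tokens("gpt-3.5-turbo", "Hello, world"))
```

The `compile_chat_messages` rewrite also treats the user prompt as just another message and parks the system message in the second-to-last slot so pruning drops older history first, then moves it back to the front before the dicts are returned.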
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index 7c6b42db..8c03969e 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -35,23 +35,24 @@ class SimpleChatStep(Step):
if sdk.current_step_was_deleted():
# So that the message doesn't disappear
self.hide = False
- return
+ break
if "content" in chunk:
self.description += chunk["content"]
completion += chunk["content"]
await sdk.update_ui()
finally:
- await generator.aclose()
+ self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete(
+ f"Write a short title for the following chat message: {self.description}"))
- self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete(
- f"Write a short title for the following chat message: {self.description}"))
+ self.chat_context.append(ChatMessage(
+ role="assistant",
+ content=completion,
+ summary=self.name
+ ))
- self.chat_context.append(ChatMessage(
- role="assistant",
- content=completion,
- summary=self.name
- ))
+ # TODO: Never actually closing.
+ await generator.aclose()
class AddFileStep(Step):
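
In chat.py, the title generation and `chat_context` append move into the `finally` block and the early `return` becomes a `break`, so a partially streamed reply is still titled and recorded when the step is deleted mid-stream; the generator close is deferred to the end of the block (the diff's own TODO notes it may still not run). A toy sketch of that control-flow pattern, with a stand-in stream rather than the SDK's API:

```python
import asyncio

async def fake_stream():
    # Stand-in for the LLM chunk stream.
    for chunk in ({"content": "Hello"}, {"content": " world"}):
        yield chunk

async def run(deleted_after: int) -> str:
    completion = ""
    generator = fake_stream()
    try:
        seen = 0
        async for chunk in generator:
            if seen == deleted_after:
                break  # break, not return: the finally block must still record the text
            completion += chunk["content"]
            seen += 1
    finally:
        # Bookkeeping runs whether the loop finished or was broken out of.
        print(f"recorded: {completion!r}")
        await generator.aclose()
    return completion

asyncio.run(run(deleted_after=1))
```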