Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py  |  2
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py | 75
-rw-r--r--  continuedev/src/continuedev/steps/chat.py             | 19
3 files changed, 63 insertions, 33 deletions
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 91b5842a..18e0e6f4 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -58,7 +58,7 @@ class ProxyServer(LLM):
     async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+            args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
 
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 13de7990..c81d8aa4 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -46,9 +46,17 @@ def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: in
     return encoding.decode(tokens[-max_tokens:])
 
 
+def count_chat_message_tokens(model: str, chat_message: ChatMessage) -> int:
+    # Doing simpler, safer version of what is here:
+    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+    # every message follows <|start|>{role/name}\n{content}<|end|>\n
+    TOKENS_PER_MESSAGE = 4
+    return count_tokens(model, chat_message.content) + TOKENS_PER_MESSAGE
+
+
 def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
     total_tokens = tokens_for_completion + \
-        sum(count_tokens(model, message.content)
+        sum(count_chat_message_tokens(model, message)
             for message in chat_history)
 
     # 1. Replace beyond last 5 messages with summary
@@ -74,37 +82,58 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
             message.content = message.summary
         i += 1
 
-    # 4. Remove entire messages in the last 5
-    while total_tokens > max_tokens and len(chat_history) > 0:
+    # 4. Remove entire messages in the last 5, except last 1
+    while total_tokens > max_tokens and len(chat_history) > 1:
         message = chat_history.pop(0)
         total_tokens -= count_tokens(model, message.content)
 
+    # 5. Truncate last message
+    if total_tokens > max_tokens and len(chat_history) > 0:
+        message = chat_history[0]
+        message.content = prune_raw_prompt_from_top(
+            model, message.content, tokens_for_completion)
+        total_tokens = max_tokens
+
     return chat_history
 
 
+# In case we've missed weird edge cases
+TOKEN_BUFFER_FOR_SAFETY = 100
+
+
 def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
-    prompt_tokens = count_tokens(model, prompt)
+    """
+    The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message
+    """
+    if prompt is not None:
+        prompt_msg = ChatMessage(role="user", content=prompt, summary=prompt)
+        msgs += [prompt_msg]
+
+    if system_message is not None:
+        # NOTE: System message takes second precedence to user prompt, so it is placed just before
+        # but move back to start after processing
+        rendered_system_message = render_system_message(system_message)
+        system_chat_msg = ChatMessage(
+            role="system", content=rendered_system_message, summary=rendered_system_message)
+        # insert at second-to-last position
+        msgs.insert(-1, system_chat_msg)
+
+    # Add tokens from functions
+    function_tokens = 0
     if functions is not None:
         for function in functions:
-            prompt_tokens += count_tokens(model, json.dumps(function))
-
-    rendered_system_message = render_system_message(system_message)
-
-    msgs = prune_chat_history(model,
-                              msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, rendered_system_message))
-    history = []
-    if system_message:
-        history.append({
-            "role": "system",
-            "content": rendered_system_message
-        })
-    history += [msg.to_dict(with_functions=functions is not None)
-                for msg in msgs]
-    if prompt:
-        history.append({
-            "role": "user",
-            "content": prompt
-        })
+            function_tokens += count_tokens(model, json.dumps(function))
+
+    msgs = prune_chat_history(
+        model, msgs, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
+
+    history = [msg.to_dict(with_functions=functions is not None)
+               for msg in msgs]
+
+    # Move system message back to start
+    if system_message is not None and len(history) >= 2 and history[-2]["role"] == "system":
+        system_message_dict = history.pop(-2)
+        history.insert(0, system_message_dict)
 
     return history
 
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index 7c6b42db..8c03969e 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -35,23 +35,24 @@ class SimpleChatStep(Step):
                 if sdk.current_step_was_deleted():
                     # So that the message doesn't disappear
                     self.hide = False
-                    return
+                    break
 
                 if "content" in chunk:
                     self.description += chunk["content"]
                     completion += chunk["content"]
                     await sdk.update_ui()
         finally:
-            await generator.aclose()
+            self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete(
+                f"Write a short title for the following chat message: {self.description}"))
 
-        self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete(
-            f"Write a short title for the following chat message: {self.description}"))
+            self.chat_context.append(ChatMessage(
+                role="assistant",
+                content=completion,
+                summary=self.name
+            ))
 
-        self.chat_context.append(ChatMessage(
-            role="assistant",
-            content=completion,
-            summary=self.name
-        ))
+            # TODO: Never actually closing.
+            await generator.aclose()
 
 
 class AddFileStep(Step):
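
A note on the token accounting introduced in count_tokens.py: the patch charges a flat per-message overhead on top of the content tokens, a simplification of the OpenAI cookbook recipe it links to. Below is a hypothetical standalone sketch of the same idea built directly on tiktoken; the cl100k_base fallback and the function's exact shape are assumptions for illustration, not code from this commit.

    import tiktoken

    # Per the patched comment, every chat message is serialized roughly as
    # <|start|>{role/name}\n{content}<|end|>\n, so a flat 4-token overhead
    # is charged per message on top of the content tokens.
    TOKENS_PER_MESSAGE = 4


    def count_chat_message_tokens(model: str, content: str) -> int:
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            # Unknown model name; fall back to a common encoding (an assumption).
            encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode(content)) + TOKENS_PER_MESSAGE


    print(count_chat_message_tokens("gpt-3.5-turbo", "Hello there!"))  # likely 7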
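
The trickiest part of the count_tokens.py change is the reordering in compile_chat_messages: the user prompt is appended last, the system message is inserted second-to-last so that pruning (which drops from the front of the history) spares both, and the system message is then moved back to index 0 in the final dict list. Here is a self-contained sketch of just that mechanism, with a simplified stand-in for the real ChatMessage class:

    from dataclasses import dataclass
    from typing import Dict, List


    @dataclass
    class ChatMessage:
        # Simplified stand-in for continuedev's ChatMessage, for illustration only.
        role: str
        content: str
        summary: str

        def to_dict(self) -> Dict[str, str]:
            return {"role": self.role, "content": self.content}


    def compile_messages(msgs: List[ChatMessage], prompt: str, system_message: str) -> List[Dict[str, str]]:
        # Append the prompt last, then slot the system message just before it:
        # pruning removes from the front, so the two highest-priority messages
        # sit at the end, where they are dropped last.
        msgs = msgs + [ChatMessage("user", prompt, prompt)]
        msgs.insert(-1, ChatMessage("system", system_message, system_message))

        # (Pruning would run here in the real code.)
        history = [m.to_dict() for m in msgs]

        # Move the system message back to the start, as the patch does.
        if len(history) >= 2 and history[-2]["role"] == "system":
            history.insert(0, history.pop(-2))
        return history


    if __name__ == "__main__":
        prior = [ChatMessage("user", "hi", "hi"), ChatMessage("assistant", "hello", "hello")]
        for m in compile_messages(prior, "What changed?", "You are a coding assistant."):
            print(m["role"], "->", m["content"])
        # system -> You are a coding assistant.
        # user -> hi
        # assistant -> hello
        # user -> What changed?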