-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py   |  2
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py  | 75
-rw-r--r--  continuedev/src/continuedev/steps/chat.py              | 19
3 files changed, 63 insertions(+), 33 deletions(-)
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 91b5842a..18e0e6f4 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -58,7 +58,7 @@ class ProxyServer(LLM):
     async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+            args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
 
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
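This fix matters because `stream_chat` merges per-call kwargs over `self.default_args`, so a caller-supplied model was previously ignored when compiling messages. A minimal sketch of the resolution order (the default values here are illustrative, not the repo's actual defaults):

```python
# Sketch: per-call kwargs override instance defaults via dict merging, so the
# compiled prompt must use the resolved args["model"], not self.default_model.
default_args = {"model": "gpt-3.5-turbo", "max_tokens": 1024}  # illustrative defaults
kwargs = {"model": "gpt-4"}  # caller override for this one request

args = {**default_args, **kwargs}  # later entries win on key collisions
assert args["model"] == "gpt-4"    # the old code would still have compiled for gpt-3.5-turbo
```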
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 13de7990..c81d8aa4 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -46,9 +46,17 @@ def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: in
         return encoding.decode(tokens[-max_tokens:])
 
 
+def count_chat_message_tokens(model: str, chat_message: ChatMessage) -> int:
+    # Doing simpler, safer version of what is here:
+    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+    # every message follows <|start|>{role/name}\n{content}<|end|>\n
+    TOKENS_PER_MESSAGE = 4
+    return count_tokens(model, chat_message.content) + TOKENS_PER_MESSAGE
+
+
 def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
     total_tokens = tokens_for_completion + \
-        sum(count_tokens(model, message.content)
+        sum(count_chat_message_tokens(model, message)
             for message in chat_history)
 
     # 1. Replace beyond last 5 messages with summary
@@ -74,37 +82,58 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
         message.content = message.summary
         i += 1
 
-    # 4. Remove entire messages in the last 5
-    while total_tokens > max_tokens and len(chat_history) > 0:
+    # 4. Remove entire messages in the last 5, except last 1
+    while total_tokens > max_tokens and len(chat_history) > 1:
         message = chat_history.pop(0)
         total_tokens -= count_tokens(model, message.content)
 
+    # 5. Truncate last message
+    if total_tokens > max_tokens and len(chat_history) > 0:
+        message = chat_history[0]
+        message.content = prune_raw_prompt_from_top(
+            model, message.content, tokens_for_completion)
+        total_tokens = max_tokens
+
     return chat_history
 
 
+# In case we've missed weird edge cases
+TOKEN_BUFFER_FOR_SAFETY = 100
+
+
 def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
-    prompt_tokens = count_tokens(model, prompt)
+    """
+    The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message
+    """
+    if prompt is not None:
+        prompt_msg = ChatMessage(role="user", content=prompt, summary=prompt)
+        msgs += [prompt_msg]
+
+    if system_message is not None:
+        # NOTE: System message takes second precedence to user prompt, so it is placed just before
+        # but move back to start after processing
+        rendered_system_message = render_system_message(system_message)
+        system_chat_msg = ChatMessage(
+            role="system", content=rendered_system_message, summary=rendered_system_message)
+        # insert at second-to-last position
+        msgs.insert(-1, system_chat_msg)
+
+    # Add tokens from functions
+    function_tokens = 0
     if functions is not None:
         for function in functions:
-            prompt_tokens += count_tokens(model, json.dumps(function))
-
-    rendered_system_message = render_system_message(system_message)
-
-    msgs = prune_chat_history(model,
-                              msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, rendered_system_message))
-    history = []
-    if system_message:
-        history.append({
-            "role": "system",
-            "content": rendered_system_message
-        })
-    history += [msg.to_dict(with_functions=functions is not None)
-                for msg in msgs]
-    if prompt:
-        history.append({
-            "role": "user",
-            "content": prompt
-        })
+            function_tokens += count_tokens(model, json.dumps(function))
+
+    msgs = prune_chat_history(
+        model, msgs, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
+
+    history = [msg.to_dict(with_functions=functions is not None)
+               for msg in msgs]
+
+    # Move system message back to start
+    if system_message is not None and len(history) >= 2 and history[-2]["role"] == "system":
+        system_message_dict = history.pop(-2)
+        history.insert(0, system_message_dict)
 
     return history
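For reference, the new `count_chat_message_tokens` trades the cookbook's per-field accounting for a flat 4-token overhead per message. A standalone sketch of the same idea using `tiktoken` directly (the encoding fallback below is an assumption, not necessarily how the repo's `count_tokens` resolves models):

```python
# Standalone sketch of the flat per-message overhead added in this commit.
# Assumes tiktoken is installed; the fallback encoding is an assumption.
import tiktoken

TOKENS_PER_MESSAGE = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n


def count_tokens(model: str, text: str) -> int:
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")  # unknown model name
    return len(encoding.encode(text))


def count_chat_message_tokens(model: str, content: str) -> int:
    # Simpler, safer than the cookbook's exact per-field accounting
    return count_tokens(model, content) + TOKENS_PER_MESSAGE


print(count_chat_message_tokens("gpt-3.5-turbo", "Hello, world"))
```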
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index 7c6b42db..8c03969e 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -35,23 +35,24 @@ class SimpleChatStep(Step):
                 if sdk.current_step_was_deleted():
                     # So that the message doesn't disappear
                     self.hide = False
-                    return
+                    break
 
                 if "content" in chunk:
                     self.description += chunk["content"]
                     completion += chunk["content"]
                     await sdk.update_ui()
         finally:
-            await generator.aclose()
+            self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete(
+                f"Write a short title for the following chat message: {self.description}"))
 
-        self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete(
-            f"Write a short title for the following chat message: {self.description}"))
+            self.chat_context.append(ChatMessage(
+                role="assistant",
+                content=completion,
+                summary=self.name
+            ))
 
-        self.chat_context.append(ChatMessage(
-            role="assistant",
-            content=completion,
-            summary=self.name
-        ))
+            # TODO: Never actually closing.
+            await generator.aclose()
 
 
 class AddFileStep(Step):
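This hunk moves title generation and the context append into `finally` so they run even when the step is deleted mid-stream (hence `return` becoming `break`), and the TODO flags that `aclose()` now sits after awaits that could raise, so the generator still may not be closed. A self-contained sketch of that control flow (the streaming generator here is a stand-in, not the repo's SDK):

```python
import asyncio


async def token_stream():
    # Stand-in for the streaming LLM response
    for chunk in ("Hello", ", ", "world"):
        yield chunk


async def run_step(deleted_after: int = 2) -> str:
    completion = ""
    generator = token_stream()
    try:
        i = 0
        async for chunk in generator:
            if i >= deleted_after:
                break  # analogous to the step being deleted mid-stream
            completion += chunk
            i += 1
    finally:
        # Runs whether the loop completed, broke, or raised -- like the
        # title generation and chat_context append in the hunk above.
        title = completion.strip() or "(empty)"
        # If an await placed before this line raised, aclose() would be
        # skipped: that is the "Never actually closing" TODO.
        await generator.aclose()
    return title


print(asyncio.run(run_step()))
```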

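One subtlety in the count_tokens.py hunk: the system message is inserted second-to-last so that pruning drops older history before it or the user prompt, then it is moved back to the front of the compiled history. The list mechanics, reduced to plain strings for illustration:

```python
# The insert(-1)/pop(-2) dance from compile_chat_messages, with plain strings.
msgs = ["earlier msg 1", "earlier msg 2", "user prompt"]
msgs.insert(-1, "system")  # just before the user prompt
assert msgs == ["earlier msg 1", "earlier msg 2", "system", "user prompt"]

# ...prune_chat_history trims from the front of the list here...

history = list(msgs)
if len(history) >= 2 and history[-2] == "system":
    history.insert(0, history.pop(-2))  # move system message back to start
assert history == ["system", "earlier msg 1", "earlier msg 2", "user prompt"]
```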