From f09150617ed2454f3074bcf93f53aae5ae637d40 Mon Sep 17 00:00:00 2001
From: Nate Sesti <33237525+sestinj@users.noreply.github.com>
Date: Mon, 9 Oct 2023 18:37:27 -0700
Subject: Preview (#541)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Strong typing (#533)
* refactor: :recycle: get rid of continuedev.src.continuedev structure
* refactor: :recycle: switching back to server folder
* feat: :sparkles: make config.py imports shorter
* feat: :bookmark: publish as pre-release vscode extension
* refactor: :recycle: refactor and add more completion params to ui
* build: :building_construction: download from preview S3
* fix: :bug: fix paths
* fix: :green_heart: package:pre-release
* ci: :green_heart: more time for tests
* fix: :green_heart: fix build scripts
* fix: :bug: fix import in run.py
* fix: :bookmark: update version to try again
* ci: 💚 Update package.json version [skip ci]
* refactor: :fire: don't check for old extensions version
* fix: :bug: small bug fixes
* fix: :bug: fix config.py import paths
* ci: 💚 Update package.json version [skip ci]
* ci: :green_heart: platform-specific builds test #1
* feat: :green_heart: ship with binary
* fix: :green_heart: fix copy statement to include .exe for windows
* fix: :green_heart: cd extension before packaging
* chore: :loud_sound: count tokens generated
* fix: :green_heart: remove npm_config_arch
* fix: :green_heart: publish as pre-release!
* chore: :bookmark: update version
* perf: :green_heart: hardcode distro paths
* fix: :bug: fix yaml syntax error
* chore: :bookmark: update version
* fix: :green_heart: update permissions and version
* feat: :bug: kill old server if needed
* feat: :lipstick: update marketplace icon for pre-release
* ci: 💚 Update package.json version [skip ci]
* feat: :sparkles: auto-reload for config.py
* feat: :wrench: update default config.py imports
* feat: :sparkles: codelens in config.py
* feat: :sparkles: select model param count from UI
* ci: 💚 Update package.json version [skip ci]
* feat: :sparkles: more model options, ollama error handling
* perf: :zap: don't show server loading immediately
* fix: :bug: fixing small UI details
* ci: 💚 Update package.json version [skip ci]
* feat: :rocket: headers param on LLM class
* fix: :bug: fix headers for openai.py
* feat: :sparkles: highlight code on cmd+shift+L
* ci: 💚 Update package.json version [skip ci]
* feat: :lipstick: sticky top bar in gui.tsx
* fix: :loud_sound: websocket logging and horizontal scrollbar
* ci: 💚 Update package.json version [skip ci]
* feat: :sparkles: allow AzureOpenAI Service through GGML
* ci: 💚 Update package.json version [skip ci]
* fix: :bug: fix automigration
* ci: 💚 Update package.json version [skip ci]
* ci: :green_heart: upload binaries in ci, download apple silicon
* chore: :fire: remove notes
* fix: :green_heart: use curl to download binary
* fix: :green_heart: set permissions on apple silicon binary
* fix: :green_heart: testing
* fix: :green_heart: cleanup file
* fix: :green_heart: fix preview.yaml
* fix: :green_heart: only upload once per binary
* fix: :green_heart: install rosetta
* ci: :green_heart: download binary after tests
* ci: 💚 Update package.json version [skip ci]
* ci: :green_heart: prepare ci for merge to main

---------

Co-authored-by: GitHub Action
---
 server/continuedev/libs/util/count_tokens.py | 206 +++++++++++++++++++++++++++
 1 file changed, 206 insertions(+)
 create mode 100644 server/continuedev/libs/util/count_tokens.py

(limited to 'server/continuedev/libs/util/count_tokens.py')
diff --git a/server/continuedev/libs/util/count_tokens.py b/server/continuedev/libs/util/count_tokens.py
new file mode 100644
index 00000000..d895a2cf
--- /dev/null
+++ b/server/continuedev/libs/util/count_tokens.py
@@ -0,0 +1,206 @@
+import json
+from typing import Dict, List, Union
+
+from ...core.main import ChatMessage
+from .templating import render_templated_string
+
+# TODO move many of these into a specific LLM.properties() function that
+# contains max tokens, whether it is a chat model or not, and default args (not all
+# models want to be run at 0.5 temp). That would also let custom models made for
+# long contexts (like LLongMA) exist here.
+aliases = {
+    "ggml": "gpt-3.5-turbo",
+    "claude-2": "gpt-3.5-turbo",
+}
+DEFAULT_MAX_TOKENS = 1024
+DEFAULT_ARGS = {
+    "max_tokens": DEFAULT_MAX_TOKENS,
+    "temperature": 0.5,
+}
+
+already_saw_import_err = False
+
+
+def encoding_for_model(model_name: str):
+    global already_saw_import_err
+    if already_saw_import_err:
+        return None
+
+    try:
+        import tiktoken
+        from tiktoken_ext import openai_public  # noqa: F401
+
+        try:
+            return tiktoken.encoding_for_model(aliases.get(model_name, model_name))
+        except Exception as _:
+            return tiktoken.encoding_for_model("gpt-3.5-turbo")
+    except Exception as e:
+        print("Error importing tiktoken", e)
+        already_saw_import_err = True
+        return None
+
+
+def count_tokens(model_name: str, text: Union[str, None]):
+    if text is None:
+        return 0
+    encoding = encoding_for_model(model_name)
+    if encoding is None:
+        # Make a safe estimate, given that tokens are typically ~4 characters on average
+        return len(text) // 2
+    return len(encoding.encode(text, disallowed_special=()))
+
+
+def count_chat_message_tokens(model_name: str, chat_message: ChatMessage) -> int:
+    # Doing a simpler, safer version of what is here:
+    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+    # every message follows <|start|>{role/name}\n{content}<|end|>\n
+    TOKENS_PER_MESSAGE = 4
+    return count_tokens(model_name, chat_message.content) + TOKENS_PER_MESSAGE
+
+
+def prune_raw_prompt_from_top(
+    model_name: str, context_length: int, prompt: str, tokens_for_completion: int
+):
+    max_tokens = context_length - tokens_for_completion
+    encoding = encoding_for_model(model_name)
+
+    if encoding is None:
+        desired_length_in_chars = max_tokens * 2
+        return prompt[-desired_length_in_chars:]
+
+    tokens = encoding.encode(prompt, disallowed_special=())
+    if len(tokens) <= max_tokens:
+        return prompt
+    else:
+        return encoding.decode(tokens[-max_tokens:])
+
+
+def prune_chat_history(
+    model_name: str,
+    chat_history: List[ChatMessage],
+    context_length: int,
+    tokens_for_completion: int,
+):
+    total_tokens = tokens_for_completion + sum(
+        count_chat_message_tokens(model_name, message) for message in chat_history
+    )
+
+    # 1. Replace messages beyond the last 5 with their summaries
+    i = 0
+    while total_tokens > context_length and i < len(chat_history) - 5:
+        message = chat_history[i]
+        total_tokens -= count_tokens(model_name, message.content)
+        total_tokens += count_tokens(model_name, message.summary)
+        message.content = message.summary
+        i += 1
+
+    # 2. Remove entire messages until the last 5
+    while (
+        len(chat_history) > 5
+        and total_tokens > context_length
+        and len(chat_history) > 0
+    ):
+        message = chat_history.pop(0)
+        total_tokens -= count_tokens(model_name, message.content)
+
+    # 3. Truncate messages in the last 5, except the last 1
+    i = 0
+    while (
+        total_tokens > context_length
+        and len(chat_history) > 0
+        and i < len(chat_history) - 1
+    ):
+        message = chat_history[i]
+        total_tokens -= count_tokens(model_name, message.content)
+        total_tokens += count_tokens(model_name, message.summary)
+        message.content = message.summary
+        i += 1
+
+    # 4. Remove entire messages in the last 5, except the last 1
+    while total_tokens > context_length and len(chat_history) > 1:
+        message = chat_history.pop(0)
+        total_tokens -= count_tokens(model_name, message.content)
+
+    # 5. Truncate the last message
+    if total_tokens > context_length and len(chat_history) > 0:
+        message = chat_history[0]
+        message.content = prune_raw_prompt_from_top(
+            model_name, context_length, message.content, tokens_for_completion
+        )
+        total_tokens = context_length
+
+    return chat_history
+
+
+# In case we've missed weird edge cases
+TOKEN_BUFFER_FOR_SAFETY = 100
+
+
+def compile_chat_messages(
+    model_name: str,
+    msgs: Union[List[ChatMessage], None],
+    context_length: int,
+    max_tokens: int,
+    prompt: Union[str, None] = None,
+    functions: Union[List, None] = None,
+    system_message: Union[str, None] = None,
+) -> List[Dict]:
+    """
+    The total number of tokens is system_message + sum(msgs) + functions + prompt (after the prompt is converted to a message).
+    """
+
+    msgs_copy = [msg.copy(deep=True) for msg in msgs] if msgs is not None else []
+
+    if prompt is not None:
+        prompt_msg = ChatMessage(role="user", content=prompt, summary=prompt)
+        msgs_copy += [prompt_msg]
+
+    if system_message is not None and system_message.strip() != "":
+        # NOTE: The system message takes second precedence to the user prompt, so it is
+        # placed just before the prompt, then moved back to the start after pruning
+        rendered_system_message = render_templated_string(system_message)
+        system_chat_msg = ChatMessage(
+            role="system",
+            content=rendered_system_message,
+            summary=rendered_system_message,
+        )
+        # Insert at the second-to-last position
+        msgs_copy.insert(-1, system_chat_msg)
+
+    # Add tokens from functions
+    function_tokens = 0
+    if functions is not None:
+        for function in functions:
+            function_tokens += count_tokens(model_name, json.dumps(function))
+
+    if max_tokens + function_tokens + TOKEN_BUFFER_FOR_SAFETY >= context_length:
+        raise ValueError(
+            f"max_tokens ({max_tokens}) is too close to context_length ({context_length}), which doesn't leave room for chat history. This would cause incoherent responses. Try increasing the context_length parameter of the model in your config file."
+        )
+
+    msgs_copy = prune_chat_history(
+        model_name,
+        msgs_copy,
+        context_length,
+        function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY,
+    )
+
+    history = [msg.to_dict(with_functions=functions is not None) for msg in msgs_copy]
+
+    # Move the system message back to the start
+    if (
+        system_message is not None
+        and len(history) >= 2
+        and history[-2]["role"] == "system"
+    ):
+        system_message_dict = history.pop(-2)
+        history.insert(0, system_message_dict)
+
+    return history
+
+
+def format_chat_messages(messages: List[ChatMessage]) -> str:
+    formatted = ""
+    for msg in messages:
+        formatted += f"<{msg['role'].capitalize()}>\n{msg['content']}\n\n"
+    return formatted
--
cgit v1.2.3-70-g09d2
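
Editor's note: for orientation, here is a minimal usage sketch of the new helpers. It is not part of the patch above, and it assumes the package is importable as continuedev (with the server directory on the Python path), that ChatMessage only needs the role, content, and summary fields used in this file, and that tiktoken may be absent, in which case count_tokens falls back to the character-based estimate.

# Hypothetical usage sketch -- not part of the commit above.
from continuedev.core.main import ChatMessage
from continuedev.libs.util.count_tokens import compile_chat_messages, count_tokens

# Rough token estimate for a single string (len(text) // 2 when tiktoken is unavailable)
print(count_tokens("gpt-3.5-turbo", "def add(a, b):\n    return a + b"))

# Prior conversation, with summaries that pruning can fall back to
history = [
    ChatMessage(
        role="user",
        content="Explain this stack trace in detail: ...",
        summary="stack trace question",
    ),
]

# Compile prompt + history + system message into OpenAI-style dicts,
# pruned so that history, functions, and the reply fit a 4096-token context
messages = compile_chat_messages(
    model_name="gpt-3.5-turbo",
    msgs=history,
    context_length=4096,
    max_tokens=1024,
    prompt="Now suggest a fix.",
    system_message="You are a helpful coding assistant.",
)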