author | Nate Sesti <sestinj@gmail.com> | 2023-06-26 13:15:59 -0700
---|---|---
committer | Nate Sesti <sestinj@gmail.com> | 2023-06-26 13:15:59 -0700
commit | 64f202c376a572d9c53a3e87607a7216124c8b35 (patch) |
tree | 4ec6b5834856bdf14dec4ba25781a996f55af863 /continuedev/src |
parent | a2a6f4547b591c90a62c830b92a7b3920bb13b9f (diff) |
token counting with functions
Diffstat (limited to 'continuedev/src')
-rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 9
-rw-r--r-- | continuedev/src/continuedev/libs/llm/proxy_server.py | 11
-rw-r--r-- | continuedev/src/continuedev/libs/util/count_tokens.py | 27
-rw-r--r-- | continuedev/src/continuedev/steps/core/core.py | 14
4 files changed, 38 insertions, 23 deletions
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 3024ae61..a3ca5c80 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -36,7 +36,7 @@ class OpenAI(LLM):
         if args["model"] in CHAT_MODELS:
             async for chunk in await openai.ChatCompletion.acreate(
                 messages=compile_chat_messages(
-                    args["model"], with_history, prompt, with_functions=False),
+                    args["model"], with_history, prompt, functions=None),
                 **args,
             ):
                 if "content" in chunk.choices[0].delta:
@@ -56,7 +56,7 @@ class OpenAI(LLM):
 
         async for chunk in await openai.ChatCompletion.acreate(
             messages=compile_chat_messages(
-                args["model"], messages, with_functions=args["model"].endswith("0613")),
+                args["model"], messages, functions=args.get("functions", None)),
             **args,
         ):
             yield chunk.choices[0].delta
@@ -67,12 +67,13 @@ class OpenAI(LLM):
         if args["model"] in CHAT_MODELS:
             resp = (await openai.ChatCompletion.acreate(
                 messages=compile_chat_messages(
-                    args["model"], with_history, prompt, with_functions=False),
+                    args["model"], with_history, prompt, functions=None),
                 **args,
             )).choices[0].message.content
         else:
             resp = (await openai.Completion.acreate(
-                prompt=prune_raw_prompt_from_top(args["model"], prompt),
+                prompt=prune_raw_prompt_from_top(
+                    args["model"], prompt, args["max_tokens"]),
                 **args,
             )).choices[0].text
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 9fe6e811..ccdb2002 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -33,7 +33,7 @@ class ProxyServer(LLM):
 
         async with aiohttp.ClientSession() as session:
             async with session.post(f"{SERVER_URL}/complete", json={
-                "messages": compile_chat_messages(args["model"], with_history, prompt, with_functions=False),
+                "messages": compile_chat_messages(args["model"], with_history, prompt, functions=None),
                 "unique_id": self.unique_id,
                 **args
             }) as resp:
@@ -45,7 +45,7 @@ class ProxyServer(LLM):
     async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
         args = self.default_args | kwargs
         messages = compile_chat_messages(
-            self.default_model, messages, None, with_functions=args["model"].endswith("0613"))
+            self.default_model, messages, None, functions=args.get("functions", None))
 
         async with aiohttp.ClientSession() as session:
             async with session.post(f"{SERVER_URL}/stream_chat", json={
@@ -59,14 +59,17 @@ class ProxyServer(LLM):
                     try:
                         json_chunk = line[0].decode("utf-8")
                         json_chunk = "{}" if json_chunk == "" else json_chunk
-                        yield json.loads(json_chunk)
+                        chunks = json_chunk.split("\n")
+                        for chunk in chunks:
+                            if chunk.strip() != "":
+                                yield json.loads(chunk)
                     except:
                         raise Exception(str(line[0]))
 
     async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args | kwargs
         messages = compile_chat_messages(
-            self.default_model, with_history, prompt, with_functions=args["model"].endswith("0613"))
+            self.default_model, with_history, prompt, functions=args.get("functions", None))
 
         async with aiohttp.ClientSession() as session:
             async with session.post(f"{SERVER_URL}/stream_complete", json={
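The proxy_server.py change also fixes a streaming pitfall: a single read from the response can carry several newline-separated JSON objects, so parsing the whole chunk with one json.loads fails. A minimal standalone sketch of the same split-then-parse step, using made-up chunk data rather than anything from the commit:

```python
import json
from typing import Iterable, Iterator


def parse_stream_chunks(raw_chunks: Iterable[bytes]) -> Iterator[dict]:
    """Yield one dict per JSON object, even when a single read holds several objects."""
    for raw in raw_chunks:
        text = raw.decode("utf-8")
        # A read may contain zero, one, or many newline-separated JSON objects.
        for piece in text.split("\n"):
            if piece.strip() != "":
                yield json.loads(piece)


# Hypothetical chunks: the second read delivers two objects at once.
chunks = [b'{"content": "Hel"}\n', b'{"content": "lo"}\n{"content": "!"}\n']
for delta in parse_stream_chunks(chunks):
    print(delta["content"], end="")  # prints "Hello!"
```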
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 154af5e1..047a47e4 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -1,3 +1,4 @@
+import json
 from typing import Dict, List, Union
 from ...core.main import ChatMessage
 import tiktoken
@@ -5,10 +6,10 @@ import tiktoken
 aliases = {}
 DEFAULT_MAX_TOKENS = 2048
 MAX_TOKENS_FOR_MODEL = {
-    "gpt-3.5-turbo": 4096 - DEFAULT_MAX_TOKENS,
-    "gpt-3.5-turbo-0613": 4096 - DEFAULT_MAX_TOKENS,
-    "gpt-3.5-turbo-16k": 16384 - DEFAULT_MAX_TOKENS,
-    "gpt-4": 8192 - DEFAULT_MAX_TOKENS
+    "gpt-3.5-turbo": 4096,
+    "gpt-3.5-turbo-0613": 4096,
+    "gpt-3.5-turbo-16k": 16384,
+    "gpt-4": 8192
 }
 CHAT_MODELS = {
     "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
@@ -28,8 +29,9 @@ def count_tokens(model: str, text: Union[str, None]):
     return len(encoding.encode(text, disallowed_special=()))
 
 
-def prune_raw_prompt_from_top(model: str, prompt: str):
-    max_tokens = MAX_TOKENS_FOR_MODEL.get(model, DEFAULT_MAX_TOKENS)
+def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: int):
+    max_tokens = MAX_TOKENS_FOR_MODEL.get(
+        model, DEFAULT_MAX_TOKENS) - tokens_for_completion
     encoding = encoding_for_model(model)
     tokens = encoding.encode(prompt, disallowed_special=())
     if len(tokens) <= max_tokens:
@@ -59,8 +61,8 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
     # 3. Truncate message in the last 5
     i = 0
-    while total_tokens > max_tokens and len(chat_history) > 0:
-        message = chat_history[0]
+    while total_tokens > max_tokens and len(chat_history) > 0 and i < len(chat_history):
+        message = chat_history[i]
         total_tokens -= count_tokens(model, message.content)
         total_tokens += count_tokens(model, message.summary)
         message.content = message.summary
@@ -74,8 +76,12 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
     return chat_history
 
 
-def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str, None] = None, with_functions: bool = False, system_message: Union[str, None] = None) -> List[Dict]:
+def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
     prompt_tokens = count_tokens(model, prompt)
+    if functions is not None:
+        for function in functions:
+            prompt_tokens += count_tokens(model, json.dumps(function))
+
     msgs = prune_chat_history(model,
                               msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + 1000 + count_tokens(model, system_message))
     history = []
@@ -84,7 +90,8 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str
             "role": "system",
             "content": system_message
         })
-    history += [msg.to_dict(with_functions=with_functions) for msg in msgs]
+    history += [msg.to_dict(with_functions=functions is not None)
+                for msg in msgs]
     if prompt:
         history.append({
             "role": "user",
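The count_tokens.py change is the heart of the commit: function definitions sent to the chat API consume context-window tokens too, so their JSON form is counted alongside the prompt, and the completion reserve is now subtracted explicitly rather than baked into the per-model limits. A rough sketch of the same budgeting idea, assuming tiktoken is installed; the function schema and helper names below are illustrative, not the commit's own:

```python
import json
import tiktoken

CONTEXT_WINDOW = {"gpt-3.5-turbo": 4096, "gpt-3.5-turbo-16k": 16384, "gpt-4": 8192}


def count_tokens(model: str, text: str) -> int:
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text, disallowed_special=()))


def prompt_budget(model: str, prompt: str, functions: list, tokens_for_completion: int) -> int:
    """Tokens left for chat history after the prompt, the function schemas,
    and the reserved completion length are accounted for."""
    used = count_tokens(model, prompt)
    # Function definitions are sent as JSON, so count their serialized form.
    for function in functions:
        used += count_tokens(model, json.dumps(function))
    return CONTEXT_WINDOW[model] - tokens_for_completion - used


# Hypothetical function schema, only for illustration.
edit_fn = {"name": "edit_file", "parameters": {"type": "object", "properties": {}}}
print(prompt_budget("gpt-3.5-turbo", "Refactor this file.", [edit_fn], 2048))
```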
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 24f00d36..0d82b228 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -10,7 +10,7 @@ from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullConten
 from ...models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
 from ...core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
 from ...core.main import Step, SequentialStep
-from ...libs.util.count_tokens import MAX_TOKENS_FOR_MODEL
+from ...libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS
 import difflib
@@ -211,14 +211,18 @@ class DefaultModelEditCodeStep(Step):
 
             return cur_start_line, cur_end_line
 
+        # We don't know here all of the functions being passed in.
+        # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
+        # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
+        BUFFER_FOR_FUNCTIONS = 200
+        total_tokens = model_to_use.count_tokens(
+            full_file_contents + self._prompt + self.user_input) + DEFAULT_MAX_TOKENS + BUFFER_FOR_FUNCTIONS
+
         model_to_use = sdk.models.default
         if model_to_use.name == "gpt-3.5-turbo":
-            if sdk.models.gpt35.count_tokens(full_file_contents) > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
+            if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
                 model_to_use = sdk.models.gpt3516k
 
-        total_tokens = model_to_use.count_tokens(
-            full_file_contents + self._prompt + self.user_input)
-
         cur_start_line, cur_end_line = cut_context(
             model_to_use, total_tokens, cur_start_line, cur_end_line)
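The core.py change estimates the full request size up front (file contents, prompt, user input, the default completion reserve, and a small buffer for function definitions that aren't known at this point) and switches to the 16k-context model when the estimate would overflow gpt-3.5-turbo's window. A simplified sketch of that decision, with a plain tiktoken-based count standing in for the SDK's model objects; the numbers mirror the commit and the pick_model helper is illustrative:

```python
import tiktoken

MAX_TOKENS_FOR_MODEL = {"gpt-3.5-turbo": 4096, "gpt-3.5-turbo-16k": 16384}
DEFAULT_MAX_TOKENS = 2048       # room reserved for the completion
BUFFER_FOR_FUNCTIONS = 200      # rough allowance for function schemas not known here


def count_tokens(model: str, text: str) -> int:
    return len(tiktoken.encoding_for_model(model).encode(text, disallowed_special=()))


def pick_model(full_file_contents: str, prompt: str, user_input: str) -> str:
    """Fall back to the 16k-context model when the estimated request would overflow."""
    total_tokens = (
        count_tokens("gpt-3.5-turbo", full_file_contents + prompt + user_input)
        + DEFAULT_MAX_TOKENS
        + BUFFER_FOR_FUNCTIONS
    )
    if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
        return "gpt-3.5-turbo-16k"
    return "gpt-3.5-turbo"


print(pick_model("def add(a, b):\n    return a + b\n", "Rewrite this file.", "Add type hints."))
```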