author     Nate Sesti <sestinj@gmail.com>    2023-06-26 13:15:59 -0700
committer  Nate Sesti <sestinj@gmail.com>    2023-06-26 13:15:59 -0700
commit     64f202c376a572d9c53a3e87607a7216124c8b35 (patch)
tree       4ec6b5834856bdf14dec4ba25781a996f55af863 /continuedev/src
parent     a2a6f4547b591c90a62c830b92a7b3920bb13b9f (diff)
token counting with functions
Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py         |  9
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py   | 11
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py  | 27
-rw-r--r--  continuedev/src/continuedev/steps/core/core.py         | 14
4 files changed, 38 insertions(+), 23 deletions(-)
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 3024ae61..a3ca5c80 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -36,7 +36,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
async for chunk in await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], with_history, prompt, with_functions=False),
+ args["model"], with_history, prompt, functions=None),
**args,
):
if "content" in chunk.choices[0].delta:
@@ -56,7 +56,7 @@ class OpenAI(LLM):
async for chunk in await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], messages, with_functions=args["model"].endswith("0613")),
+ args["model"], messages, functions=args.get("functions", None)),
**args,
):
yield chunk.choices[0].delta
@@ -67,12 +67,13 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
resp = (await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], with_history, prompt, with_functions=False),
+ args["model"], with_history, prompt, functions=None),
**args,
)).choices[0].message.content
else:
resp = (await openai.Completion.acreate(
- prompt=prune_raw_prompt_from_top(args["model"], prompt),
+ prompt=prune_raw_prompt_from_top(
+ args["model"], prompt, args["max_tokens"]),
**args,
)).choices[0].text
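
The openai.py changes above replace the with_functions boolean with an explicit functions argument and pass the completion budget (args["max_tokens"]) into prompt pruning. A minimal usage sketch of the chat path, assuming an instance of the OpenAI class named llm and a caller-supplied list of function schemas (neither appears in this diff):

async def stream_edit(llm, messages, functions):
    # `functions` is forwarded through kwargs; compile_chat_messages receives it
    # and budgets tokens for the schemas instead of keying off the "0613" suffix.
    async for delta in llm.stream_chat(messages, functions=functions):
        if "content" in delta:
            print(delta["content"], end="")
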
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 9fe6e811..ccdb2002 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -33,7 +33,7 @@ class ProxyServer(LLM):
async with aiohttp.ClientSession() as session:
async with session.post(f"{SERVER_URL}/complete", json={
- "messages": compile_chat_messages(args["model"], with_history, prompt, with_functions=False),
+ "messages": compile_chat_messages(args["model"], with_history, prompt, functions=None),
"unique_id": self.unique_id,
**args
}) as resp:
@@ -45,7 +45,7 @@ class ProxyServer(LLM):
async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
args = self.default_args | kwargs
messages = compile_chat_messages(
- self.default_model, messages, None, with_functions=args["model"].endswith("0613"))
+ self.default_model, messages, None, functions=args.get("functions", None))
async with aiohttp.ClientSession() as session:
async with session.post(f"{SERVER_URL}/stream_chat", json={
@@ -59,14 +59,17 @@ class ProxyServer(LLM):
try:
json_chunk = line[0].decode("utf-8")
json_chunk = "{}" if json_chunk == "" else json_chunk
- yield json.loads(json_chunk)
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ yield json.loads(chunk)
except:
raise Exception(str(line[0]))
async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args | kwargs
messages = compile_chat_messages(
- self.default_model, with_history, prompt, with_functions=args["model"].endswith("0613"))
+ self.default_model, with_history, prompt, functions=args.get("functions", None))
async with aiohttp.ClientSession() as session:
async with session.post(f"{SERVER_URL}/stream_complete", json={
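
The stream_chat fix in the hunk above copes with the proxy batching several newline-delimited JSON objects into a single read. The same parsing strategy as a standalone sketch (the function name is illustrative):

import json

def iter_json_chunks(raw: str):
    # The server may deliver several newline-delimited JSON objects in one read;
    # treat an empty payload as an empty object and skip blank lines.
    raw = "{}" if raw == "" else raw
    for chunk in raw.split("\n"):
        if chunk.strip() != "":
            yield json.loads(chunk)

# Two deltas arriving in a single read:
for delta in iter_json_chunks('{"content": "He"}\n{"content": "llo"}'):
    print(delta["content"], end="")   # prints "Hello"
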
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 154af5e1..047a47e4 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -1,3 +1,4 @@
+import json
from typing import Dict, List, Union
from ...core.main import ChatMessage
import tiktoken
@@ -5,10 +6,10 @@ import tiktoken
aliases = {}
DEFAULT_MAX_TOKENS = 2048
MAX_TOKENS_FOR_MODEL = {
- "gpt-3.5-turbo": 4096 - DEFAULT_MAX_TOKENS,
- "gpt-3.5-turbo-0613": 4096 - DEFAULT_MAX_TOKENS,
- "gpt-3.5-turbo-16k": 16384 - DEFAULT_MAX_TOKENS,
- "gpt-4": 8192 - DEFAULT_MAX_TOKENS
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k": 16384,
+ "gpt-4": 8192
}
CHAT_MODELS = {
"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
@@ -28,8 +29,9 @@ def count_tokens(model: str, text: Union[str, None]):
return len(encoding.encode(text, disallowed_special=()))
-def prune_raw_prompt_from_top(model: str, prompt: str):
- max_tokens = MAX_TOKENS_FOR_MODEL.get(model, DEFAULT_MAX_TOKENS)
+def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: int):
+ max_tokens = MAX_TOKENS_FOR_MODEL.get(
+ model, DEFAULT_MAX_TOKENS) - tokens_for_completion
encoding = encoding_for_model(model)
tokens = encoding.encode(prompt, disallowed_special=())
if len(tokens) <= max_tokens:
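
The hunk above is cut off before the function's tail. A hedged reconstruction of the whole helper, assuming the untouched tail keeps the most recent tokens and using tiktoken directly in place of the module's encoding_for_model helper:

import tiktoken

def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: int) -> str:
    # Reserve room for the completion, then keep only the newest max_tokens tokens.
    max_tokens = MAX_TOKENS_FOR_MODEL.get(model, DEFAULT_MAX_TOKENS) - tokens_for_completion
    encoding = tiktoken.encoding_for_model(model)
    tokens = encoding.encode(prompt, disallowed_special=())
    if len(tokens) <= max_tokens:
        return prompt
    return encoding.decode(tokens[-max_tokens:])   # assumed tail: drop tokens from the top
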
@@ -59,8 +61,8 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
# 3. Truncate message in the last 5
i = 0
- while total_tokens > max_tokens and len(chat_history) > 0:
- message = chat_history[0]
+ while total_tokens > max_tokens and len(chat_history) > 0 and i < len(chat_history):
+ message = chat_history[i]
total_tokens -= count_tokens(model, message.content)
total_tokens += count_tokens(model, message.summary)
message.content = message.summary
@@ -74,8 +76,12 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
return chat_history
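
The loop above now indexes the history with i and bounds the walk by its length, so the summarisation pass moves forward through the messages rather than staying on chat_history[0]. A simplified sketch of just that pass, using count_tokens from this module and omitting the surrounding pruning steps:

def summarize_until_under_budget(model, chat_history, max_tokens, total_tokens):
    # Replace each message's content with its shorter summary until the token
    # budget is met or every message has been visited.
    i = 0
    while total_tokens > max_tokens and i < len(chat_history):
        message = chat_history[i]
        total_tokens -= count_tokens(model, message.content)
        total_tokens += count_tokens(model, message.summary)
        message.content = message.summary
        i += 1
    return chat_history, total_tokens
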
-def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str, None] = None, with_functions: bool = False, system_message: Union[str, None] = None) -> List[Dict]:
+def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
prompt_tokens = count_tokens(model, prompt)
+ if functions is not None:
+ for function in functions:
+ prompt_tokens += count_tokens(model, json.dumps(function))
+
msgs = prune_chat_history(model,
msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + 1000 + count_tokens(model, system_message))
history = []
@@ -84,7 +90,8 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str
"role": "system",
"content": system_message
})
- history += [msg.to_dict(with_functions=with_functions) for msg in msgs]
+ history += [msg.to_dict(with_functions=functions is not None)
+ for msg in msgs]
if prompt:
history.append({
"role": "user",
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 24f00d36..0d82b228 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -10,7 +10,7 @@ from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullConten
from ...models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
from ...core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
from ...core.main import Step, SequentialStep
-from ...libs.util.count_tokens import MAX_TOKENS_FOR_MODEL
+from ...libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS
import difflib
@@ -211,14 +211,18 @@ class DefaultModelEditCodeStep(Step):
return cur_start_line, cur_end_line
+ # We don't know here all of the functions that will be passed in.
+ # This matters because if the prompt itself goes over the limit, the entire message has to be cut from the completion:
+ # the context window won't overflow, but prune_chat_history in count_tokens.py will cut out this whole message instead of us trimming only as many lines as we need.
+ BUFFER_FOR_FUNCTIONS = 200
+ total_tokens = model_to_use.count_tokens(
+ full_file_contents + self._prompt + self.user_input) + DEFAULT_MAX_TOKENS + BUFFER_FOR_FUNCTIONS
+
model_to_use = sdk.models.default
if model_to_use.name == "gpt-3.5-turbo":
- if sdk.models.gpt35.count_tokens(full_file_contents) > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
+ if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
model_to_use = sdk.models.gpt3516k
- total_tokens = model_to_use.count_tokens(
- full_file_contents + self._prompt + self.user_input)
-
cur_start_line, cur_end_line = cut_context(
model_to_use, total_tokens, cur_start_line, cur_end_line)
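
A hedged sketch of the budgeting introduced above, with illustrative stand-ins for the sdk objects and the model assignment placed before the estimate: reserve the default completion budget plus a fixed buffer for function schemas that are not known at this point, and fall back to gpt-3.5-turbo-16k when the estimate exceeds gpt-3.5-turbo's window.

BUFFER_FOR_FUNCTIONS = 200   # headroom for function schemas counted later in compile_chat_messages

def choose_edit_model(sdk, full_file_contents, prompt, user_input):
    model_to_use = sdk.models.default
    total_tokens = model_to_use.count_tokens(
        full_file_contents + prompt + user_input
    ) + DEFAULT_MAX_TOKENS + BUFFER_FOR_FUNCTIONS
    if model_to_use.name == "gpt-3.5-turbo" and total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
        model_to_use = sdk.models.gpt3516k
    return model_to_use, total_tokens
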