summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/libs
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev/src/continuedev/libs')
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py6
-rw-r--r--continuedev/src/continuedev/libs/llm/proxy_server.py6
-rw-r--r--continuedev/src/continuedev/libs/util/count_tokens.py4
3 files changed, 8 insertions, 8 deletions
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index c4e4139f..f0877d90 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -37,7 +37,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
async for chunk in await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], with_history, prompt, functions=None),
+ args["model"], with_history, args["max_tokens"], prompt, functions=None),
**args,
):
if "content" in chunk.choices[0].delta:
@@ -58,7 +58,7 @@ class OpenAI(LLM):
async for chunk in await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], messages, functions=args.get("functions", None)),
+ args["model"], messages, args["max_tokens"], functions=args.get("functions", None)),
**args,
):
yield chunk.choices[0].delta
@@ -69,7 +69,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
resp = (await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], with_history, prompt, functions=None),
+ args["model"], with_history, args["max_tokens"], prompt, functions=None),
**args,
)).choices[0].message.content
else:
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 05ece394..eab6e441 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -38,7 +38,7 @@ class ProxyServer(LLM):
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/complete", json={
- "messages": compile_chat_messages(args["model"], with_history, prompt, functions=None),
+ "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
"unique_id": self.unique_id,
**args
}) as resp:
@@ -50,7 +50,7 @@ class ProxyServer(LLM):
async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.default_model, messages, None, functions=args.get("functions", None))
+ self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None))
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/stream_chat", json={
@@ -74,7 +74,7 @@ class ProxyServer(LLM):
async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.default_model, with_history, prompt, functions=args.get("functions", None))
+ self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/stream_complete", json={
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 8b06fef9..73be0717 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -76,14 +76,14 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
return chat_history
-def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
+def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
prompt_tokens = count_tokens(model, prompt)
if functions is not None:
for function in functions:
prompt_tokens += count_tokens(model, json.dumps(function))
msgs = prune_chat_history(model,
- msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + DEFAULT_MAX_TOKENS + count_tokens(model, system_message))
+ msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, system_message))
history = []
if system_message:
history.append({