Diffstat (limited to 'continuedev/src/continuedev/libs/llm')
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py        6
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py  6
2 files changed, 6 insertions, 6 deletions
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index c4e4139f..f0877d90 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -37,7 +37,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
async for chunk in await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], with_history, prompt, functions=None),
+ args["model"], with_history, args["max_tokens"], prompt, functions=None),
**args,
):
if "content" in chunk.choices[0].delta:
@@ -58,7 +58,7 @@ class OpenAI(LLM):
         async for chunk in await openai.ChatCompletion.acreate(
             messages=compile_chat_messages(
-                args["model"], messages, functions=args.get("functions", None)),
+                args["model"], messages, args["max_tokens"], functions=args.get("functions", None)),
             **args,
         ):
             yield chunk.choices[0].delta
@@ -69,7 +69,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
resp = (await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], with_history, prompt, functions=None),
+ args["model"], with_history, args["max_tokens"], prompt, functions=None),
**args,
)).choices[0].message.content
else:
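
Each openai.py call site above now passes args["max_tokens"] as a new third positional argument of compile_chat_messages. Below is a minimal sketch of the signature those calls assume, with crude stand-ins for the token-counting helpers; the project's real implementation lives elsewhere in continuedev and may differ.

    # Sketch only: illustrates the signature the patched call sites assume.
    # count_tokens and context_length are stand-ins, not the project's
    # real token-counting helpers.
    from typing import Dict, List, Optional

    def context_length(model: str) -> int:
        # Assumed context sizes; real values would come from a model registry.
        return 8192 if model.startswith("gpt-4") else 4096

    def count_tokens(model: str, messages: List[Dict]) -> int:
        # Rough proxy: a real helper would use tiktoken for the given model.
        return sum(len(m.get("content") or "") // 4 for m in messages)

    def compile_chat_messages(
            model: str,
            msgs: List[Dict],
            max_tokens: int,
            prompt: Optional[str] = None,
            functions: Optional[List[Dict]] = None) -> List[Dict]:
        # Flatten the history (plus an optional trailing user prompt) into
        # OpenAI-style {"role", "content"} dicts.
        compiled = list(msgs)
        if prompt is not None:
            compiled.append({"role": "user", "content": prompt})
        # Trim the oldest messages while the request would overflow the
        # context window once max_tokens is reserved for the completion;
        # this is why every call site now passes the completion budget in.
        while len(compiled) > 1 and \
                count_tokens(model, compiled) + max_tokens > context_length(model):
            compiled.pop(0)
        return compiled
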
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 05ece394..eab6e441 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -38,7 +38,7 @@ class ProxyServer(LLM):
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/complete", json={
-                "messages": compile_chat_messages(args["model"], with_history, prompt, functions=None),
+                "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
                 "unique_id": self.unique_id,
                 **args
             }) as resp:
@@ -50,7 +50,7 @@ class ProxyServer(LLM):
     async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.default_model, messages, None, functions=args.get("functions", None))
+            self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None))
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/stream_chat", json={
@@ -74,7 +74,7 @@ class ProxyServer(LLM):
     async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.default_model, with_history, prompt, functions=args.get("functions", None))
+            self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/stream_complete", json={
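
The proxy_server.py call sites follow the same pattern; note that in stream_chat the prompt argument is None because the chat history already ends with the user's message. A hypothetical invocation mirroring that site, using the sketch above:

    # Hypothetical values; args would normally come from default_args
    # merged with per-call kwargs, as in the patched methods.
    args = {"model": "gpt-4", "max_tokens": 1024}
    history = [{"role": "user", "content": "Summarize this repo."}]
    compiled = compile_chat_messages(
        args["model"], history, args["max_tokens"], None,
        functions=args.get("functions", None))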