diff options
Diffstat (limited to 'continuedev/src/continuedev/libs')
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 6 | +++--- |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/proxy_server.py | 6 | +++--- |
| -rw-r--r-- | continuedev/src/continuedev/libs/util/count_tokens.py | 4 | ++-- |
3 files changed, 8 insertions, 8 deletions
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index c4e4139f..f0877d90 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -37,7 +37,7 @@ class OpenAI(LLM): if args["model"] in CHAT_MODELS: async for chunk in await openai.ChatCompletion.acreate( messages=compile_chat_messages( - args["model"], with_history, prompt, functions=None), + args["model"], with_history, args["max_tokens"], prompt, functions=None), **args, ): if "content" in chunk.choices[0].delta: @@ -58,7 +58,7 @@ class OpenAI(LLM): async for chunk in await openai.ChatCompletion.acreate( messages=compile_chat_messages( - args["model"], messages, functions=args.get("functions", None)), + args["model"], messages, args["max_tokens"], functions=args.get("functions", None)), **args, ): yield chunk.choices[0].delta @@ -69,7 +69,7 @@ class OpenAI(LLM): if args["model"] in CHAT_MODELS: resp = (await openai.ChatCompletion.acreate( messages=compile_chat_messages( - args["model"], with_history, prompt, functions=None), + args["model"], with_history, args["max_tokens"], prompt, functions=None), **args, )).choices[0].message.content else: diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index 05ece394..eab6e441 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -38,7 +38,7 @@ class ProxyServer(LLM): async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: async with session.post(f"{SERVER_URL}/complete", json={ - "messages": compile_chat_messages(args["model"], with_history, prompt, functions=None), + "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None), "unique_id": self.unique_id, **args }) as resp: @@ -50,7 +50,7 @@ class ProxyServer(LLM): async 
def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.default_model, messages, None, functions=args.get("functions", None)) + self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None)) async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: async with session.post(f"{SERVER_URL}/stream_chat", json={ @@ -74,7 +74,7 @@ class ProxyServer(LLM): async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.default_model, with_history, prompt, functions=args.get("functions", None)) + self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None)) async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: async with session.post(f"{SERVER_URL}/stream_complete", json={ diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index 8b06fef9..73be0717 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -76,14 +76,14 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: return chat_history -def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]: +def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]: prompt_tokens = count_tokens(model, prompt) if functions 
is not None: for function in functions: prompt_tokens += count_tokens(model, json.dumps(function)) msgs = prune_chat_history(model, - msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + DEFAULT_MAX_TOKENS + count_tokens(model, system_message)) + msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, system_message)) history = [] if system_message: history.append({ |
