From 38fbb6a508d2df8a0de55d5dc741cc32a703a34a Mon Sep 17 00:00:00 2001
From: Nate Sesti
Date: Thu, 20 Jul 2023 12:19:56 -0700
Subject: fix mutable default arg with_history bug

---
 continuedev/src/continuedev/libs/llm/__init__.py         |  6 +++---
 continuedev/src/continuedev/libs/llm/anthropic.py        |  6 +++---
 continuedev/src/continuedev/libs/llm/ggml.py             |  6 +++---
 continuedev/src/continuedev/libs/llm/hf_inference_api.py |  2 +-
 continuedev/src/continuedev/libs/llm/openai.py           |  6 +++---
 continuedev/src/continuedev/libs/llm/proxy_server.py     |  6 +++---
 continuedev/src/continuedev/libs/util/count_tokens.py    | 14 ++++++++------
 7 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 4c4de213..2766db4b 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -9,15 +9,15 @@ from pydantic import BaseModel
 class LLM(ABC):
     system_message: Union[str, None] = None
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
         """Return the completion of the text with the given temperature."""
         raise NotImplementedError
 
-    def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         """Stream the completion through generator."""
         raise NotImplementedError
 
-    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         """Stream the chat through generator."""
         raise NotImplementedError
 
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index c82895c6..625d4e57 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -54,7 +54,7 @@ class AnthropicLLM(LLM):
         prompt += AI_PROMPT
         return prompt
 
-    async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
@@ -66,7 +66,7 @@ class AnthropicLLM(LLM):
         ):
             yield chunk.completion
 
-    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
@@ -83,7 +83,7 @@ class AnthropicLLM(LLM):
                 "content": chunk.completion
             }
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
         args = self._transform_args(args)
 
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 6007fdb4..4889a556 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -26,7 +26,7 @@ class GGML(LLM):
     def count_tokens(self, text: str):
         return count_tokens(self.name, text)
 
-    async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
@@ -47,7 +47,7 @@ class GGML(LLM):
                         except:
                             raise Exception(str(line))
 
-    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
             self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
@@ -72,7 +72,7 @@ class GGML(LLM):
                     except:
                         raise Exception(str(line[0]))
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
 
         async with aiohttp.ClientSession() as session:
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 7e11fbbe..36f03270 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -16,7 +16,7 @@ class HuggingFaceInferenceAPI(LLM):
         self.model = model
         self.system_message = system_message  # TODO: Nothing being done with this
 
-    def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
+    def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
         """Return the completion of the text with the given temperature."""
         API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
         headers = {
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 64bb39a2..96a4ab71 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -42,7 +42,7 @@ class OpenAI(LLM):
     def count_tokens(self, text: str):
         return count_tokens(self.default_model, text)
 
-    async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
@@ -72,7 +72,7 @@ class OpenAI(LLM):
 
         self.write_log(f"Completion:\n\n{completion}")
 
-    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
@@ -93,7 +93,7 @@ class OpenAI(LLM):
                 completion += chunk.choices[0].delta.content
         self.write_log(f"Completion: \n\n{completion}")
 
\n\n{completion}") - async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]: + async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: args = {**self.default_args, **kwargs} if args["model"] in CHAT_MODELS: diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index bd50fe02..b1bb8f06 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -38,7 +38,7 @@ class ProxyServer(LLM): def count_tokens(self, text: str): return count_tokens(self.default_model, text) - async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]: + async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( @@ -57,7 +57,7 @@ class ProxyServer(LLM): except: raise Exception(await resp.text()) - async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]: + async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) @@ -89,7 +89,7 @@ class ProxyServer(LLM): self.write_log(f"Completion: \n\n{completion}") - async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index 987aa722..6e0a3b88 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -101,13 +101,15 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: TOKEN_BUFFER_FOR_SAFETY = 100 -def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]: +def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]: """ The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message """ + msgs_copy = msgs.copy() if msgs is not None else [] + if prompt is not None: prompt_msg = ChatMessage(role="user", content=prompt, summary=prompt) - msgs += [prompt_msg] + msgs_copy += [prompt_msg] if system_message is not None: # NOTE: System message takes second precedence to user prompt, so it is placed just before @@ -116,7 +118,7 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], 
         system_chat_msg = ChatMessage(
             role="system", content=rendered_system_message, summary=rendered_system_message)
         # insert at second-to-last position
-        msgs.insert(-1, system_chat_msg)
+        msgs_copy.insert(-1, system_chat_msg)
 
     # Add tokens from functions
     function_tokens = 0
@@ -124,11 +126,11 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int,
         for function in functions:
             function_tokens += count_tokens(model, json.dumps(function))
 
-    msgs = prune_chat_history(
-        model, msgs, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
+    msgs_copy = prune_chat_history(
+        model, msgs_copy, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
 
     history = [msg.to_dict(with_functions=functions is not None)
-               for msg in msgs]
+               for msg in msgs_copy]
 
     # Move system message back to start
     if system_message is not None and len(history) >= 2 and history[-2]["role"] == "system":
-- 
cgit v1.2.3-70-g09d2
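
Note on the fix (not part of the patch itself): every `with_history: List[ChatMessage] = []` default above binds a single list object that is created once and shared by every call, so any mutation of the history leaked into later calls. The commit switches to a `None` sentinel and has `compile_chat_messages` work on `msgs_copy`. A minimal, self-contained Python sketch of the pitfall and of the sentinel-plus-copy idiom follows; the function names `broken_complete` and `fixed_complete` are illustrative only and do not exist in the codebase.

from typing import List, Optional


def broken_complete(prompt: str, with_history: List[str] = []) -> List[str]:
    # The default list is evaluated once at definition time, so it is
    # shared across every call that omits `with_history`.
    with_history += [prompt]
    return with_history


def fixed_complete(prompt: str, with_history: Optional[List[str]] = None) -> List[str]:
    # None sentinel plus an explicit copy: neither other calls nor the
    # caller's own list are mutated (mirrors msgs_copy in this commit).
    history_copy = list(with_history) if with_history is not None else []
    history_copy.append(prompt)
    return history_copy


if __name__ == "__main__":
    broken_complete("a")
    print(broken_complete("b"))  # ['a', 'b'] -- history leaked across calls
    fixed_complete("a")
    print(fixed_complete("b"))   # ['b'] -- each call starts clean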