| author | Nate Sesti <sestinj@gmail.com> | 2023-07-20 12:19:56 -0700 | 
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-07-20 12:19:56 -0700 | 
| commit | dc90631c443db710e1c92a556497e403d9f9b8be (patch) | |
| tree | 6800bcd7bdc061f16461da2de7c809ce77d2090f /continuedev/src | |
| parent | 2f777ea933d4a41b600feedeff7d85257c5b136d (diff) | |
| download | sncontinue-dc90631c443db710e1c92a556497e403d9f9b8be.tar.gz sncontinue-dc90631c443db710e1c92a556497e403d9f9b8be.tar.bz2 sncontinue-dc90631c443db710e1c92a556497e403d9f9b8be.zip  | |
fix mutable default arg with_history bug
Diffstat (limited to 'continuedev/src')
7 files changed, 24 insertions(+), 22 deletions(-)
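
For context (not part of the commit): a minimal sketch of the bug being fixed. In Python, a mutable default such as `with_history: List[ChatMessage] = []` is evaluated once, when the function is defined, so every call that omits the argument shares the same list and can accidentally mutate it. The commit changes these defaults to `None`, and `compile_chat_messages` treats `None` as an empty history. The helper names below (`chat_buggy`, `chat_fixed`) are hypothetical.

```python
from typing import List, Optional

# Illustration only (not code from this repo): the default list is created once,
# at function definition time, and shared by every call that omits the argument.
def chat_buggy(prompt: str, with_history: List[str] = []) -> List[str]:
    with_history.append(prompt)
    return with_history

print(chat_buggy("first"))   # ['first']
print(chat_buggy("second"))  # ['first', 'second']  <- state leaked from the previous call

# The pattern this commit switches to: default to None and normalize inside the function.
def chat_fixed(prompt: str, with_history: Optional[List[str]] = None) -> List[str]:
    history = [] if with_history is None else list(with_history)
    history.append(prompt)
    return history

print(chat_fixed("first"))   # ['first']
print(chat_fixed("second"))  # ['second']
```
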
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 4c4de213..2766db4b 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -9,15 +9,15 @@ from pydantic import BaseModel
 
 class LLM(ABC):
     system_message: Union[str, None] = None
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
         """Return the completion of the text with the given temperature."""
         raise NotImplementedError
 
-    def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         """Stream the completion through generator."""
         raise NotImplementedError
 
-    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         """Stream the chat through generator."""
         raise NotImplementedError
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index c82895c6..625d4e57 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -54,7 +54,7 @@ class AnthropicLLM(LLM):
         prompt += AI_PROMPT
         return prompt
 
-    async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
@@ -66,7 +66,7 @@ class AnthropicLLM(LLM):
         ):
             yield chunk.completion
 
-    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
@@ -83,7 +83,7 @@ class AnthropicLLM(LLM):
                 "content": chunk.completion
             }
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
 
         args = self._transform_args(args)
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 6007fdb4..4889a556 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -26,7 +26,7 @@ class GGML(LLM):
     def count_tokens(self, text: str):
         return count_tokens(self.name, text)
 
-    async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
@@ -47,7 +47,7 @@ class GGML(LLM):
                         except:
                             raise Exception(str(line))
 
-    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
             self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
@@ -72,7 +72,7 @@ class GGML(LLM):
                         except:
                             raise Exception(str(line[0]))
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
 
         async with aiohttp.ClientSession() as session:
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 7e11fbbe..36f03270 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -16,7 +16,7 @@ class HuggingFaceInferenceAPI(LLM):
         self.model = model
         self.system_message = system_message  # TODO: Nothing being done with this
 
-    def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
+    def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
         """Return the completion of the text with the given temperature."""
         API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
         headers = {
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 64bb39a2..96a4ab71 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -42,7 +42,7 @@ class OpenAI(LLM):
     def count_tokens(self, text: str):
         return count_tokens(self.default_model, text)
 
-    async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
@@ -72,7 +72,7 @@ class OpenAI(LLM):
 
             self.write_log(f"Completion:\n\n{completion}")
 
-    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
@@ -93,7 +93,7 @@ class OpenAI(LLM):
                 completion += chunk.choices[0].delta.content
         self.write_log(f"Completion: \n\n{completion}")
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
 
         if args["model"] in CHAT_MODELS:
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index bd50fe02..b1bb8f06 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -38,7 +38,7 @@ class ProxyServer(LLM):
     def count_tokens(self, text: str):
         return count_tokens(self.default_model, text)
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
 
         messages = compile_chat_messages(
@@ -57,7 +57,7 @@ class ProxyServer(LLM):
                 except:
                     raise Exception(await resp.text())
 
-    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
             args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
@@ -89,7 +89,7 @@ class ProxyServer(LLM):
 
                 self.write_log(f"Completion: \n\n{completion}")
 
-    async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
             self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 987aa722..6e0a3b88 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -101,13 +101,15 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
 TOKEN_BUFFER_FOR_SAFETY = 100
 
 
-def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
+def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
     """
     The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message
     """
+    msgs_copy = msgs.copy() if msgs is not None else []
+
     if prompt is not None:
         prompt_msg = ChatMessage(role="user", content=prompt, summary=prompt)
-        msgs += [prompt_msg]
+        msgs_copy += [prompt_msg]
 
     if system_message is not None:
         # NOTE: System message takes second precedence to user prompt, so it is placed just before
@@ -116,7 +118,7 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int,
         system_chat_msg = ChatMessage(
             role="system", content=rendered_system_message, summary=rendered_system_message)
         # insert at second-to-last position
-        msgs.insert(-1, system_chat_msg)
+        msgs_copy.insert(-1, system_chat_msg)
 
     # Add tokens from functions
     function_tokens = 0
@@ -124,11 +126,11 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int,
         for function in functions:
             function_tokens += count_tokens(model, json.dumps(function))
 
-    msgs = prune_chat_history(
-        model, msgs, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
+    msgs_copy = prune_chat_history(
+        model, msgs_copy, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
 
     history = [msg.to_dict(with_functions=functions is not None)
-               for msg in msgs]
+               for msg in msgs_copy]
 
     # Move system message back to start
     if system_message is not None and len(history) >= 2 and history[-2]["role"] == "system":
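
The `count_tokens.py` hunk also stops mutating the caller's list: `compile_chat_messages` now appends the prompt and system message to a local `msgs_copy` instead of modifying `msgs` in place. A rough sketch of that pattern, with a hypothetical `build_messages` helper standing in for the real function:

```python
from typing import Dict, List, Optional

# Hypothetical helper mirroring the compile_chat_messages change: copy the input
# (treating None as an empty history) so the caller's list is never modified.
def build_messages(msgs: Optional[List[Dict]], prompt: Optional[str] = None) -> List[Dict]:
    msgs_copy = msgs.copy() if msgs is not None else []
    if prompt is not None:
        msgs_copy.append({"role": "user", "content": prompt})
    return msgs_copy

history = [{"role": "user", "content": "hi"}]
compiled = build_messages(history, prompt="follow-up question")
assert len(history) == 1   # the caller's history is left untouched
assert len(compiled) == 2  # the copy holds the appended prompt
```
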
