 continuedev/src/continuedev/libs/llm/hf_inference_api.py | 8 ++++----
 continuedev/src/continuedev/libs/llm/hf_tgi.py           | 7 ++++---
 continuedev/src/continuedev/libs/llm/queued.py           | 4 ++--
 3 files changed, 10 insertions(+), 9 deletions(-)
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 81c10e8e..a7771018 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List
+from typing import Callable, Dict, List, Union
 
 from huggingface_hub import InferenceClient
 from pydantic import Field
@@ -36,9 +36,9 @@ class HuggingFaceInferenceAPI(LLM):
         None, description="Your Hugging Face Inference API endpoint URL"
     )
 
-    template_messages: Callable[
-        [List[Dict[str, str]]], str
-    ] | None = llama2_template_messages
+    template_messages: Union[
+        Callable[[List[Dict[str, str]]], str], None
+    ] = llama2_template_messages
 
     prompt_templates = {
         "edit": simplified_edit_prompt,
diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py
index 7cd699fa..6b7f21e7 100644
--- a/continuedev/src/continuedev/libs/llm/hf_tgi.py
+++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py
@@ -32,9 +32,9 @@ class HuggingFaceTGI(LLM):
     def collect_args(self, options: CompletionOptions) -> Any:
         args = super().collect_args(options)
         args = {**args, "max_new_tokens": args.get("max_tokens", 1024), "best_of": 1}
-        args.pop("max_tokens")
-        args.pop("model")
-        args.pop("functions")
+        args.pop("max_tokens", None)
+        args.pop("model", None)
+        args.pop("functions", None)
         return args
 
     async def _stream_complete(self, prompt, options):
@@ -47,6 +47,7 @@ class HuggingFaceTGI(LLM):
         async with client_session.post(
             f"{self.server_url}/generate_stream",
             json={"inputs": prompt, "parameters": args},
+            headers={"Content-Type": "application/json"},
         ) as resp:
             async for line in resp.content.iter_any():
                 if line:
diff --git a/continuedev/src/continuedev/libs/llm/queued.py b/continuedev/src/continuedev/libs/llm/queued.py
index 11fd74d6..bbaadde6 100644
--- a/continuedev/src/continuedev/libs/llm/queued.py
+++ b/continuedev/src/continuedev/libs/llm/queued.py
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Any, List
+from typing import Any, List, Union
 
 from pydantic import Field
 
@@ -52,7 +52,7 @@ class QueuedLLM(LLM):
         self,
         options: CompletionOptions,
         msgs: List[ChatMessage],
-        functions: List[Any] | None = None,
+        functions: Union[List[Any], None] = None,
     ):
         return self.llm.compile_chat_messages(options, msgs, functions)
 
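Why swap "X | None" for "Union[X, None]" in hf_inference_api.py and queued.py: PEP 604's "|" operator on type objects only works at runtime on Python 3.10+, and pydantic evaluates field annotations eagerly while the class body executes, so the old spelling fails at import time on 3.8/3.9. A minimal sketch of the failure mode and the portable form; the Demo model here is hypothetical, not part of the Continue codebase:

    from typing import Callable, Dict, List, Union

    from pydantic import BaseModel


    class Demo(BaseModel):
        # On Python <= 3.9, writing "Callable[...] | None" here raises
        # "TypeError: unsupported operand type(s) for |" as soon as this
        # class body runs. typing.Union is evaluated the same way on every
        # supported Python version.
        template_messages: Union[Callable[[List[Dict[str, str]]], str], None] = None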
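Why the args.pop(key, None) change in hf_tgi.py: single-argument dict.pop raises KeyError when the key is absent, so collect_args would crash whenever super().collect_args(options) happened not to include one of those keys. Passing a default makes the removal a no-op instead. A quick self-contained illustration:

    args = {"max_new_tokens": 1024, "best_of": 1}

    # args.pop("functions") would raise KeyError: 'functions' here.
    args.pop("max_tokens", None)   # no-op: key absent, default is returned
    args.pop("model", None)
    args.pop("functions", None)

    print(args)  # {'max_new_tokens': 1024, 'best_of': 1}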
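The added Content-Type header in _stream_complete makes the JSON body's media type explicit. aiohttp already infers application/json when the json= keyword is used, so the explicit header is presumably a guard against strict servers or intermediaries that validate it. A standalone sketch of the same streaming call, assuming a TGI server; server_url, the prompt, and the parameter values are placeholders rather than Continue's actual configuration:

    import asyncio

    import aiohttp


    async def stream_generate(server_url: str, prompt: str) -> None:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{server_url}/generate_stream",
                json={"inputs": prompt, "parameters": {"max_new_tokens": 128, "best_of": 1}},
                # aiohttp sets this automatically for json= payloads; spelling it
                # out removes any ambiguity for strict servers or proxies.
                headers={"Content-Type": "application/json"},
            ) as resp:
                # Consume the response incrementally as chunks arrive.
                async for chunk in resp.content.iter_any():
                    if chunk:
                        print(chunk.decode("utf-8"), end="")


    asyncio.run(stream_generate("http://localhost:8080", "Hello"))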