author     Nate Sesti <sestinj@gmail.com>   2023-09-13 22:14:04 -0700
committer  Nate Sesti <sestinj@gmail.com>   2023-09-13 22:14:04 -0700
commit     275ad6f72dafdfacffd9c9b5cc4847135a30f425 (patch)
tree       6174e1f67ba905fe46444a90af91b0f75f51a725 /continuedev/src
parent     2955a7907dac256b0a108f7c93c16354dfbe8076 (diff)
download   sncontinue-275ad6f72dafdfacffd9c9b5cc4847135a30f425.tar.gz
           sncontinue-275ad6f72dafdfacffd9c9b5cc4847135a30f425.tar.bz2
           sncontinue-275ad6f72dafdfacffd9c9b5cc4847135a30f425.zip
fix: :bug: compatibility with python 3.8
Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_inference_api.py   8
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_tgi.py             7
-rw-r--r--  continuedev/src/continuedev/libs/llm/queued.py             4
3 files changed, 10 insertions, 9 deletions
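
Note on the fix: the "X | None" union syntax in annotations comes from PEP 604 and only works at runtime on Python 3.10+, where type.__or__ was added. Pydantic evaluates field annotations when the model class is defined, so these modules failed at import time on Python 3.8. A minimal sketch of the failure mode and the 3.8-compatible spelling (the alias name is illustrative, not from the repo):

    from typing import Callable, Dict, List, Union

    # On Python 3.8, "Callable[[List[Dict[str, str]]], str] | None"
    # raises TypeError at class-definition time, because PEP 604
    # unions only exist from Python 3.10 onward.
    # The typing.Union spelling below is equivalent and runs on 3.8+:
    TemplateFn = Union[Callable[[List[Dict[str, str]]], str], None]
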
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 81c10e8e..a7771018 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List
+from typing import Callable, Dict, List, Union
from huggingface_hub import InferenceClient
from pydantic import Field
@@ -36,9 +36,9 @@ class HuggingFaceInferenceAPI(LLM):
None, description="Your Hugging Face Inference API endpoint URL"
)
- template_messages: Callable[
- [List[Dict[str, str]]], str
- ] | None = llama2_template_messages
+ template_messages: Union[
+ Callable[[List[Dict[str, str]]], str], None
+ ] = llama2_template_messages
prompt_templates = {
"edit": simplified_edit_prompt,
diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py
index 7cd699fa..6b7f21e7 100644
--- a/continuedev/src/continuedev/libs/llm/hf_tgi.py
+++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py
@@ -32,9 +32,9 @@ class HuggingFaceTGI(LLM):
def collect_args(self, options: CompletionOptions) -> Any:
args = super().collect_args(options)
args = {**args, "max_new_tokens": args.get("max_tokens", 1024), "best_of": 1}
- args.pop("max_tokens")
- args.pop("model")
- args.pop("functions")
+ args.pop("max_tokens", None)
+ args.pop("model", None)
+ args.pop("functions", None)
return args
async def _stream_complete(self, prompt, options):
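
The pop changes in this hunk are a robustness fix bundled into the commit: dict.pop(key) raises KeyError when the key is absent, while dict.pop(key, default) returns the default instead. A quick illustration:

    args = {"max_new_tokens": 1024, "best_of": 1}

    # args.pop("max_tokens")      # KeyError: 'max_tokens'
    args.pop("max_tokens", None)  # returns None; args is unchanged
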
@@ -47,6 +47,7 @@ class HuggingFaceTGI(LLM):
async with client_session.post(
f"{self.server_url}/generate_stream",
json={"inputs": prompt, "parameters": args},
+ headers={"Content-Type": "application/json"},
) as resp:
async for line in resp.content.iter_any():
if line:
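
For context, a self-contained sketch of the streaming call this hunk touches. The "data:" line format follows text-generation-inference's server-sent-events output (an assumption about the server, noted in the comments); aiohttp already sets Content-Type: application/json when json= is passed, so the added header makes the request explicit rather than changing its behavior.

    import json
    import aiohttp

    async def stream_tokens(server_url: str, prompt: str, parameters: dict):
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{server_url}/generate_stream",
                json={"inputs": prompt, "parameters": parameters},
                # aiohttp sends this header automatically for json=;
                # stating it explicitly guards picky servers and proxies.
                headers={"Content-Type": "application/json"},
            ) as resp:
                async for raw in resp.content:  # StreamReader yields lines
                    line = raw.strip()
                    # Assumption: TGI emits SSE lines like data:{"token":...}
                    if line.startswith(b"data:"):
                        payload = json.loads(line[len(b"data:"):])
                        yield payload.get("token", {}).get("text", "")
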
diff --git a/continuedev/src/continuedev/libs/llm/queued.py b/continuedev/src/continuedev/libs/llm/queued.py
index 11fd74d6..bbaadde6 100644
--- a/continuedev/src/continuedev/libs/llm/queued.py
+++ b/continuedev/src/continuedev/libs/llm/queued.py
@@ -1,5 +1,5 @@
import asyncio
-from typing import Any, List
+from typing import Any, List, Union
from pydantic import Field
@@ -52,7 +52,7 @@ class QueuedLLM(LLM):
self,
options: CompletionOptions,
msgs: List[ChatMessage],
- functions: List[Any] | None = None,
+ functions: Union[List[Any], None] = None,
):
return self.llm.compile_chat_messages(options, msgs, functions)
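
Union[List[Any], None] is the same type as Optional[List[Any]]; either spelling runs on Python 3.8, whereas "List[Any] | None" needs 3.10. A quick check:

    from typing import Any, List, Optional, Union

    assert Union[List[Any], None] == Optional[List[Any]]  # holds on 3.8+
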