summaryrefslogtreecommitdiff
path: root/continuedev/src
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-09-14 18:02:07 -0700
committerNate Sesti <sestinj@gmail.com>2023-09-14 18:02:07 -0700
commit5316180394d48d9877cda0cb3d7c3c6de9995d12 (patch)
tree623643d1e0640ab545062b8c2fb7adbf7b80551c /continuedev/src
parent03bb87b565af7b77e888f8232a95bb38ded5bbd2 (diff)
downloadsncontinue-5316180394d48d9877cda0cb3d7c3c6de9995d12.tar.gz
sncontinue-5316180394d48d9877cda0cb3d7c3c6de9995d12.tar.bz2
sncontinue-5316180394d48d9877cda0cb3d7c3c6de9995d12.zip
fix: :bug: fix huggingface tgi
Diffstat (limited to 'continuedev/src')
-rw-r--r--continuedev/src/continuedev/libs/llm/hf_tgi.py31
1 files changed, 20 insertions, 11 deletions
diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py
index 168ef025..8d16198d 100644
--- a/continuedev/src/continuedev/libs/llm/hf_tgi.py
+++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py
@@ -1,5 +1,5 @@
import json
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List
import aiohttp
from pydantic import Field
@@ -47,13 +47,22 @@ class HuggingFaceTGI(LLM):
) as resp:
async for line in resp.content.iter_any():
if line:
- chunk = line.decode("utf-8")
- try:
- json_chunk = json.loads(chunk)
- except Exception as e:
- print(f"Error parsing JSON: {e}")
- continue
- text = json_chunk["details"]["best_of_sequences"][0][
- "generated_text"
- ]
- yield text
+ text = line.decode("utf-8")
+ chunks = text.split("\n")
+
+ for chunk in chunks:
+ if chunk.startswith("data: "):
+ chunk = chunk[len("data: ") :]
+ elif chunk.startswith("data:"):
+ chunk = chunk[len("data:") :]
+
+ if chunk.strip() == "":
+ continue
+
+ try:
+ json_chunk = json.loads(chunk)
+ except Exception as e:
+ print(f"Error parsing JSON: {e}")
+ continue
+
+ yield json_chunk["token"]["text"]