summary | refs | log | tree | commit | diff
path: root/continuedev/src
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev/src')
-rw-r--r-- continuedev/src/continuedev/libs/llm/hf_tgi.py 31
1 files changed, 20 insertions, 11 deletions
diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py
index 168ef025..8d16198d 100644
--- a/continuedev/src/continuedev/libs/llm/hf_tgi.py
+++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py
@@ -1,5 +1,5 @@
import json
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List
import aiohttp
from pydantic import Field
@@ -47,13 +47,22 @@ class HuggingFaceTGI(LLM):
) as resp:
async for line in resp.content.iter_any():
if line:
- chunk = line.decode("utf-8")
- try:
- json_chunk = json.loads(chunk)
- except Exception as e:
- print(f"Error parsing JSON: {e}")
- continue
- text = json_chunk["details"]["best_of_sequences"][0][
- "generated_text"
- ]
- yield text
+ text = line.decode("utf-8")
+ chunks = text.split("\n")
+
+ for chunk in chunks:
+ if chunk.startswith("data: "):
+ chunk = chunk[len("data: ") :]
+ elif chunk.startswith("data:"):
+ chunk = chunk[len("data:") :]
+
+ if chunk.strip() == "":
+ continue
+
+ try:
+ json_chunk = json.loads(chunk)
+ except Exception as e:
+ print(f"Error parsing JSON: {e}")
+ continue
+
+ yield json_chunk["token"]["text"]