summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-09-14 18:02:07 -0700
committerNate Sesti <sestinj@gmail.com>2023-09-14 18:02:07 -0700
commit5316180394d48d9877cda0cb3d7c3c6de9995d12 (patch)
tree623643d1e0640ab545062b8c2fb7adbf7b80551c
parent03bb87b565af7b77e888f8232a95bb38ded5bbd2 (diff)
downloadsncontinue-5316180394d48d9877cda0cb3d7c3c6de9995d12.tar.gz
sncontinue-5316180394d48d9877cda0cb3d7c3c6de9995d12.tar.bz2
sncontinue-5316180394d48d9877cda0cb3d7c3c6de9995d12.zip
fix: :bug: fix huggingface tgi
-rw-r--r--.github/workflows/main.yaml4
-rw-r--r--continuedev/src/continuedev/libs/llm/hf_tgi.py31
2 files changed, 22 insertions, 13 deletions
diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml
index fe433ecb..de17dd8d 100644
--- a/.github/workflows/main.yaml
+++ b/.github/workflows/main.yaml
@@ -17,7 +17,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
- python-version: "3.10"
+ python-version: "3.10.8"
- name: Install dependencies
run: |
@@ -64,7 +64,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
- python-version: "3.10"
+ python-version: "3.10.8"
- name: Install Pyinstaller
run: |
diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py
index 168ef025..8d16198d 100644
--- a/continuedev/src/continuedev/libs/llm/hf_tgi.py
+++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py
@@ -1,5 +1,5 @@
import json
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List
import aiohttp
from pydantic import Field
@@ -47,13 +47,22 @@ class HuggingFaceTGI(LLM):
) as resp:
async for line in resp.content.iter_any():
if line:
- chunk = line.decode("utf-8")
- try:
- json_chunk = json.loads(chunk)
- except Exception as e:
- print(f"Error parsing JSON: {e}")
- continue
- text = json_chunk["details"]["best_of_sequences"][0][
- "generated_text"
- ]
- yield text
+ text = line.decode("utf-8")
+ chunks = text.split("\n")
+
+ for chunk in chunks:
+ if chunk.startswith("data: "):
+ chunk = chunk[len("data: ") :]
+ elif chunk.startswith("data:"):
+ chunk = chunk[len("data:") :]
+
+ if chunk.strip() == "":
+ continue
+
+ try:
+ json_chunk = json.loads(chunk)
+ except Exception as e:
+ print(f"Error parsing JSON: {e}")
+ continue
+
+ yield json_chunk["token"]["text"]