From 5316180394d48d9877cda0cb3d7c3c6de9995d12 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Thu, 14 Sep 2023 18:02:07 -0700 Subject: fix: :bug: fix huggingface tgi --- .github/workflows/main.yaml | 4 ++-- continuedev/src/continuedev/libs/llm/hf_tgi.py | 31 +++++++++++++++++--------- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index fe433ecb..de17dd8d 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -17,7 +17,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: "3.10" + python-version: "3.10.8" - name: Install dependencies run: | @@ -64,7 +64,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: "3.10" + python-version: "3.10.8" - name: Install Pyinstaller run: | diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py index 168ef025..8d16198d 100644 --- a/continuedev/src/continuedev/libs/llm/hf_tgi.py +++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py @@ -1,5 +1,5 @@ import json -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List import aiohttp from pydantic import Field @@ -47,13 +47,22 @@ class HuggingFaceTGI(LLM): ) as resp: async for line in resp.content.iter_any(): if line: - chunk = line.decode("utf-8") - try: - json_chunk = json.loads(chunk) - except Exception as e: - print(f"Error parsing JSON: {e}") - continue - text = json_chunk["details"]["best_of_sequences"][0][ - "generated_text" - ] - yield text + text = line.decode("utf-8") + chunks = text.split("\n") + + for chunk in chunks: + if chunk.startswith("data: "): + chunk = chunk[len("data: ") :] + elif chunk.startswith("data:"): + chunk = chunk[len("data:") :] + + if chunk.strip() == "": + continue + + try: + json_chunk = json.loads(chunk) + except Exception as e: + print(f"Error parsing JSON: {e}") + continue + + yield json_chunk["token"]["text"] -- cgit v1.2.3-70-g09d2