diff options
| -rw-r--r-- | .github/workflows/main.yaml | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/hf_tgi.py | 31 | 
2 files changed, 22 insertions, 13 deletions
diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index fe433ecb..de17dd8d 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -17,7 +17,7 @@ jobs:        - name: Set up Python          uses: actions/setup-python@v4          with: -          python-version: "3.10" +          python-version: "3.10.8"        - name: Install dependencies          run: | @@ -64,7 +64,7 @@ jobs:        - name: Set up Python          uses: actions/setup-python@v4          with: -          python-version: "3.10" +          python-version: "3.10.8"        - name: Install Pyinstaller          run: | diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py index 168ef025..8d16198d 100644 --- a/continuedev/src/continuedev/libs/llm/hf_tgi.py +++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py @@ -1,5 +1,5 @@  import json -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List  import aiohttp  from pydantic import Field @@ -47,13 +47,22 @@ class HuggingFaceTGI(LLM):              ) as resp:                  async for line in resp.content.iter_any():                      if line: -                        chunk = line.decode("utf-8") -                        try: -                            json_chunk = json.loads(chunk) -                        except Exception as e: -                            print(f"Error parsing JSON: {e}") -                            continue -                        text = json_chunk["details"]["best_of_sequences"][0][ -                            "generated_text" -                        ] -                        yield text +                        text = line.decode("utf-8") +                        chunks = text.split("\n") + +                        for chunk in chunks: +                            if chunk.startswith("data: "): +                                chunk = chunk[len("data: ") :] +                            elif chunk.startswith("data:"): +                                chunk = chunk[len("data:") :] + +                            if chunk.strip() == "": +                                continue + +                            try: +                                json_chunk = json.loads(chunk) +                            except Exception as e: +                                print(f"Error parsing JSON: {e}") +                                continue + +                            yield json_chunk["token"]["text"]  | 
