path: root/server/continuedev/libs/llm/hf_tgi.py
blob: 62458db494e39517e038adde7a17888a6b8e4566

import json
from typing import Any, Callable, List

from pydantic import Field

from ...core.main import ChatMessage
from .base import LLM, CompletionOptions
from .prompts.chat import llama2_template_messages
from .prompts.edit import simplified_edit_prompt


class HuggingFaceTGI(LLM):
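    """LLM provider that streams completions from a Hugging Face Text
    Generation Inference (TGI) server via its `/generate_stream` endpoint."""
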
    model: str = "huggingface-tgi"
    server_url: str = Field(
        "http://localhost:8080", description="URL of your TGI server"
    )

    template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages

    prompt_templates = {
        "edit": simplified_edit_prompt,
    }

    class Config:
        arbitrary_types_allowed = True

    def collect_args(self, options: CompletionOptions) -> Any:
        args = super().collect_args(options)
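        # Map the generic completion options onto TGI's generation parameters:
        # TGI expects `max_new_tokens` rather than `max_tokens`, and its
        # `parameters` payload does not accept `model` or `functions`.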
        args = {**args, "max_new_tokens": args.get("max_tokens", 1024), "best_of": 1}
        args.pop("max_tokens", None)
        args.pop("model", None)
        args.pop("functions", None)
        return args

    async def _stream_complete(self, prompt, options):
        args = self.collect_args(options)

        async with self.create_client_session() as client_session:
            async with client_session.post(
                f"{self.server_url}/generate_stream",
                json={"inputs": prompt, "parameters": args},
                headers={"Content-Type": "application/json"},
                proxy=self.proxy,
            ) as resp:
                async for line in resp.content.iter_any():
                    if line:
                        text = line.decode("utf-8")
                        chunks = text.split("\n")

                        # Each event line in TGI's SSE stream has the form
                        # `data: {...}`; strip the SSE prefix before parsing.
                        for chunk in chunks:
                            if chunk.startswith("data: "):
                                chunk = chunk[len("data: ") :]
                            elif chunk.startswith("data:"):
                                chunk = chunk[len("data:") :]

                            if chunk.strip() == "":
                                continue

                            # Skip malformed or partial events rather than aborting the stream.
                            try:
                                json_chunk = json.loads(chunk)
                            except json.JSONDecodeError as e:
                                print(f"Error parsing JSON: {e}")
                                continue

                            # Each event carries the newly generated text under `token.text`.
                            yield json_chunk["token"]["text"]