import json
from typing import Any, Callable, List

from pydantic import Field

from ...core.main import ChatMessage
from .base import LLM, CompletionOptions
from .prompts.chat import llama2_template_messages
from .prompts.edit import simplified_edit_prompt


class HuggingFaceTGI(LLM):
    model: str = "huggingface-tgi"
    server_url: str = Field(
        "http://localhost:8080", description="URL of your TGI server"
    )

    template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages

    prompt_templates = {
        "edit": simplified_edit_prompt,
    }

    class Config:
        arbitrary_types_allowed = True

    def collect_args(self, options: CompletionOptions) -> Any:
        # Translate generic completion options into TGI's parameter names:
        # TGI expects `max_new_tokens` rather than `max_tokens`, and its
        # parameters object does not accept `model` or `functions`.
        args = super().collect_args(options)
        args = {**args, "max_new_tokens": args.get("max_tokens", 1024), "best_of": 1}
        args.pop("max_tokens", None)
        args.pop("model", None)
        args.pop("functions", None)
        return args

    async def _stream_complete(self, prompt, options):
        args = self.collect_args(options)

        async with self.create_client_session() as client_session:
            async with client_session.post(
                f"{self.server_url}/generate_stream",
                json={"inputs": prompt, "parameters": args},
                headers={"Content-Type": "application/json"},
                proxy=self.proxy,
            ) as resp:
                # TGI streams server-sent events: each event line starts with
                # a "data:" prefix followed by a JSON payload for one token.
                async for line in resp.content.iter_any():
                    if line:
                        text = line.decode("utf-8")
                        chunks = text.split("\n")
                        for chunk in chunks:
                            # Strip the SSE "data:" prefix, with or without a
                            # trailing space, before parsing the JSON payload.
                            if chunk.startswith("data: "):
                                chunk = chunk[len("data: "):]
                            elif chunk.startswith("data:"):
                                chunk = chunk[len("data:"):]

                            if chunk.strip() == "":
                                continue

                            try:
                                json_chunk = json.loads(chunk)
                            except Exception as e:
                                print(f"Error parsing JSON: {e}")
                                continue

                            yield json_chunk["token"]["text"]
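

# Illustrative usage (a minimal sketch, not part of the original module).
# It assumes the base LLM class exposes a public async `stream_complete`
# wrapper around `_stream_complete`; adapt the call to whatever public API
# your LLM base class actually provides. Kept commented out because this
# module uses relative imports and is not meant to run as a script.
#
# import asyncio
#
# async def _demo():
#     llm = HuggingFaceTGI(server_url="http://localhost:8080")
#     async for token in llm.stream_complete("Write a haiku about autumn:"):
#         print(token, end="", flush=True)
#
# asyncio.run(_demo())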