Diffstat (limited to 'server/continuedev/libs/llm/hf_tgi.py')
-rw-r--r--  server/continuedev/libs/llm/hf_tgi.py | 76
1 file changed, 76 insertions(+), 0 deletions(-)
diff --git a/server/continuedev/libs/llm/hf_tgi.py b/server/continuedev/libs/llm/hf_tgi.py
new file mode 100644
index 00000000..62458db4
--- /dev/null
+++ b/server/continuedev/libs/llm/hf_tgi.py
@@ -0,0 +1,76 @@
+import json
+from typing import Any, Callable, List
+
+from pydantic import Field
+
+from ...core.main import ChatMessage
+from .base import LLM, CompletionOptions
+from .prompts.chat import llama2_template_messages
+from .prompts.edit import simplified_edit_prompt
+
+
+class HuggingFaceTGI(LLM):
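+    """Streams completions from a Hugging Face Text Generation Inference (TGI) server."""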
+ model: str = "huggingface-tgi"
+ server_url: str = Field(
+ "http://localhost:8080", description="URL of your TGI server"
+ )
+
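+    # Renders the chat history into Llama 2's [INST] prompt format.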
+ template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages
+
+ prompt_templates = {
+ "edit": simplified_edit_prompt,
+ }
+
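+    # Allow field types that pydantic cannot validate natively.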
+ class Config:
+ arbitrary_types_allowed = True
+
+    def collect_args(self, options: CompletionOptions) -> Any:
+        args = super().collect_args(options)
+        # TGI expects `max_new_tokens` rather than `max_tokens`, and its
+        # generate endpoints accept neither `model` nor `functions`.
+        args = {**args, "max_new_tokens": args.get("max_tokens", 1024), "best_of": 1}
+        args.pop("max_tokens", None)
+        args.pop("model", None)
+        args.pop("functions", None)
+        return args
+
+    async def _stream_complete(self, prompt: str, options: CompletionOptions):
+ args = self.collect_args(options)
+
+ async with self.create_client_session() as client_session:
+ async with client_session.post(
+ f"{self.server_url}/generate_stream",
+ json={"inputs": prompt, "parameters": args},
+ headers={"Content-Type": "application/json"},
+ proxy=self.proxy,
+ ) as resp:
+                buffer = ""
+                async for data in resp.content.iter_any():
+                    # iter_any() can split a server-sent event across reads,
+                    # so buffer any trailing partial line and parse only
+                    # complete, newline-terminated events.
+                    buffer += data.decode("utf-8")
+                    *complete_lines, buffer = buffer.split("\n")
+
+                    for chunk in complete_lines:
+                        # Each event is prefixed with "data:", optionally
+                        # followed by a space.
+                        if chunk.startswith("data:"):
+                            chunk = chunk[len("data:") :].lstrip()
+
+                        if chunk.strip() == "":
+                            continue
+
+                        try:
+                            json_chunk = json.loads(chunk)
+                        except Exception as e:
+                            print(f"Error parsing JSON: {e}")
+                            continue
+
+                        # TGI streams one event per generated token; error
+                        # events carry no "token" key, so guard before indexing.
+                        if "token" in json_chunk:
+                            yield json_chunk["token"]["text"]
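
For reference, the streaming parser above can be exercised in isolation. Below is a minimal sketch, assuming a TGI-style /generate_stream response in which each server-sent event is a "data:"-prefixed JSON object carrying a token.text field; the sample payloads are illustrative, not captured from a real server:

import json

# Simulated reads from a /generate_stream response. The second event is
# deliberately split across two reads, the case the buffering handles.
reads = [
    b'data:{"token":{"id":1,"text":"Hello","special":false}}\n',
    b'data:{"token":{"id":2,"te',
    b'xt":" world","special":false}}\n',
]

buffer = ""
for data in reads:
    buffer += data.decode("utf-8")
    *complete_lines, buffer = buffer.split("\n")
    for chunk in complete_lines:
        if chunk.startswith("data:"):
            chunk = chunk[len("data:") :].lstrip()
        if chunk.strip() == "":
            continue
        event = json.loads(chunk)
        if "token" in event:
            print(event["token"]["text"], end="")  # prints "Hello world"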