summaryrefslogtreecommitdiff
path: root/server/continuedev/libs/llm/hf_tgi.py
diff options
context:
space:
mode:
Diffstat (limited to 'server/continuedev/libs/llm/hf_tgi.py')
-rw-r--r--server/continuedev/libs/llm/hf_tgi.py65
1 files changed, 0 insertions, 65 deletions
diff --git a/server/continuedev/libs/llm/hf_tgi.py b/server/continuedev/libs/llm/hf_tgi.py
deleted file mode 100644
index 62458db4..00000000
--- a/server/continuedev/libs/llm/hf_tgi.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import json
-from typing import Any, Callable, List
-
-from pydantic import Field
-
-from ...core.main import ChatMessage
-from .base import LLM, CompletionOptions
-from .prompts.chat import llama2_template_messages
-from .prompts.edit import simplified_edit_prompt
-
-
-class HuggingFaceTGI(LLM):
- model: str = "huggingface-tgi"
- server_url: str = Field(
- "http://localhost:8080", description="URL of your TGI server"
- )
-
- template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages
-
- prompt_templates = {
- "edit": simplified_edit_prompt,
- }
-
- class Config:
- arbitrary_types_allowed = True
-
- def collect_args(self, options: CompletionOptions) -> Any:
- args = super().collect_args(options)
- args = {**args, "max_new_tokens": args.get("max_tokens", 1024), "best_of": 1}
- args.pop("max_tokens", None)
- args.pop("model", None)
- args.pop("functions", None)
- return args
-
- async def _stream_complete(self, prompt, options):
- args = self.collect_args(options)
-
- async with self.create_client_session() as client_session:
- async with client_session.post(
- f"{self.server_url}/generate_stream",
- json={"inputs": prompt, "parameters": args},
- headers={"Content-Type": "application/json"},
- proxy=self.proxy,
- ) as resp:
- async for line in resp.content.iter_any():
- if line:
- text = line.decode("utf-8")
- chunks = text.split("\n")
-
- for chunk in chunks:
- if chunk.startswith("data: "):
- chunk = chunk[len("data: ") :]
- elif chunk.startswith("data:"):
- chunk = chunk[len("data:") :]
-
- if chunk.strip() == "":
- continue
-
- try:
- json_chunk = json.loads(chunk)
- except Exception as e:
- print(f"Error parsing JSON: {e}")
- continue
-
- yield json_chunk["token"]["text"]