diff options
author | Tuowen Zhao <ztuowen@gmail.com> | 2023-10-19 00:04:44 -0700 |
---|---|---|
committer | Tuowen Zhao <ztuowen@gmail.com> | 2023-10-19 00:04:44 -0700 |
commit | 2128f5fe9386dcf2f0597c8035f951c5b60d7562 (patch) | |
tree | ac3ab65a87bd4971275ae91d7b61176eced13774 /server/continuedev/libs/llm/hf_tgi.py | |
parent | 08f38574fa2633bbf709d24e1c79417d4285ba61 (diff) | |
download | sncontinue-2128f5fe9386dcf2f0597c8035f951c5b60d7562.tar.gz sncontinue-2128f5fe9386dcf2f0597c8035f951c5b60d7562.tar.bz2 sncontinue-2128f5fe9386dcf2f0597c8035f951c5b60d7562.zip |
cleanup server
Diffstat (limited to 'server/continuedev/libs/llm/hf_tgi.py')
-rw-r--r-- | server/continuedev/libs/llm/hf_tgi.py | 65 |
1 files changed, 0 insertions, 65 deletions
class HuggingFaceTGI(LLM):
    """LLM provider backed by a Hugging Face Text Generation Inference (TGI) server.

    Sends prompts to the TGI ``/generate_stream`` endpoint and yields the
    generated text token-by-token as it arrives over Server-Sent Events.
    """

    # Identifier reported for this provider.
    model: str = "huggingface-tgi"
    server_url: str = Field(
        "http://localhost:8080", description="URL of your TGI server"
    )

    # TGI consumes a single raw prompt string, so chat messages are
    # flattened with the Llama 2 chat template before sending.
    template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages

    prompt_templates = {
        "edit": simplified_edit_prompt,
    }

    class Config:
        arbitrary_types_allowed = True

    def collect_args(self, options: CompletionOptions) -> Any:
        """Translate generic completion options into TGI request parameters.

        TGI expects ``max_new_tokens`` rather than ``max_tokens`` and does not
        accept ``model`` or ``functions``, so those keys are renamed/removed.
        """
        args = super().collect_args(options)
        args = {**args, "max_new_tokens": args.get("max_tokens", 1024), "best_of": 1}
        args.pop("max_tokens", None)
        args.pop("model", None)
        args.pop("functions", None)
        return args

    @staticmethod
    def _parse_sse_chunk(chunk: str) -> Any:
        """Strip the SSE ``data:`` prefix from one line and parse its JSON payload.

        Returns the decoded object, or ``None`` for blank or unparseable lines
        (a parse failure is logged and skipped rather than aborting the stream).
        """
        if chunk.startswith("data: "):
            chunk = chunk[len("data: ") :]
        elif chunk.startswith("data:"):
            chunk = chunk[len("data:") :]

        if chunk.strip() == "":
            return None

        try:
            return json.loads(chunk)
        except json.JSONDecodeError as e:
            print(f"Error parsing JSON: {e}")
            return None

    async def _stream_complete(self, prompt, options):
        """Stream generated text from the TGI server, yielding token text.

        Bytes from ``iter_any()`` arrive in arbitrary-sized chunks, so they are
        accumulated in a buffer and only complete (newline-terminated) SSE
        lines are parsed; the trailing partial fragment is carried over to the
        next chunk. (The previous implementation split each raw chunk on
        newlines directly, so an event straddling two chunks produced invalid
        JSON and the token was silently dropped.)
        """
        args = self.collect_args(options)

        async with self.create_client_session() as client_session:
            async with client_session.post(
                f"{self.server_url}/generate_stream",
                json={"inputs": prompt, "parameters": args},
                headers={"Content-Type": "application/json"},
                proxy=self.proxy,
            ) as resp:
                buffer = ""
                async for line in resp.content.iter_any():
                    if not line:
                        continue
                    buffer += line.decode("utf-8")
                    # Split off complete lines; keep the last (possibly
                    # partial) fragment until more bytes arrive.
                    *complete_lines, buffer = buffer.split("\n")

                    for chunk in complete_lines:
                        json_chunk = self._parse_sse_chunk(chunk)
                        if json_chunk is not None:
                            yield json_chunk["token"]["text"]

                # Flush any final line that was not newline-terminated.
                json_chunk = self._parse_sse_chunk(buffer)
                if json_chunk is not None:
                    yield json_chunk["token"]["text"]