author    Tuowen Zhao <ztuowen@gmail.com>    2023-10-19 00:04:44 -0700
committer Tuowen Zhao <ztuowen@gmail.com>    2023-10-19 00:04:44 -0700
commit    2128f5fe9386dcf2f0597c8035f951c5b60d7562 (patch)
tree      ac3ab65a87bd4971275ae91d7b61176eced13774 /server/continuedev/libs/llm/hf_inference_api.py
parent    08f38574fa2633bbf709d24e1c79417d4285ba61 (diff)
cleanup server
Diffstat (limited to 'server/continuedev/libs/llm/hf_inference_api.py')
-rw-r--r--  server/continuedev/libs/llm/hf_inference_api.py | 78
1 file changed, 0 insertions(+), 78 deletions(-)
diff --git a/server/continuedev/libs/llm/hf_inference_api.py b/server/continuedev/libs/llm/hf_inference_api.py
deleted file mode 100644
index 990ec7c8..00000000
--- a/server/continuedev/libs/llm/hf_inference_api.py
+++ /dev/null
@@ -1,78 +0,0 @@
-from typing import Callable, Dict, List, Union
-
-from huggingface_hub import InferenceClient
-from pydantic import Field
-
-from .base import LLM, CompletionOptions
-from .prompts.chat import llama2_template_messages
-from .prompts.edit import simplified_edit_prompt
-
-
-class HuggingFaceInferenceAPI(LLM):
- """
- Hugging Face Inference API is a great option for newly released language models. Sign up for an account and add billing [here](https://huggingface.co/settings/billing), access the Inference Endpoints [here](https://ui.endpoints.huggingface.co), click on “New endpoint”, and fill out the form (e.g. select a model like [WizardCoder-Python-34B-V1.0](https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0)), and then deploy your model by clicking “Create Endpoint”. Change `~/.continue/config.py` to look like this:
-
- ```python title="~/.continue/config.py"
- from continuedev.core.models import Models
- from continuedev.libs.llm.hf_inference_api import HuggingFaceInferenceAPI
-
- config = ContinueConfig(
- ...
- models=Models(
- default=HuggingFaceInferenceAPI(
- endpoint_url="<INFERENCE_API_ENDPOINT_URL>",
- hf_token="<HUGGING_FACE_TOKEN>",
- )
- )
- ```
- """
-
-    model: str = Field(
-        "Hugging Face Inference API",
-        description="The name of the model to use (optional for the HuggingFaceInferenceAPI class)",
-    )
-    hf_token: str = Field(..., description="Your Hugging Face API token")
-    endpoint_url: str = Field(
-        None, description="Your Hugging Face Inference API endpoint URL"
-    )
-
-    template_messages: Union[
-        Callable[[List[Dict[str, str]]], str], None
-    ] = llama2_template_messages
-
-    prompt_templates = {
-        "edit": simplified_edit_prompt,
-    }
-
-    class Config:
-        arbitrary_types_allowed = True
-
-    def collect_args(self, options: CompletionOptions):
-        # Stop sequences are checked client-side in _stream_complete,
-        # so clear them before the base class collects its arguments.
-        options.stop = None
-        args = super().collect_args(options)
-
-        # Rename the generic options to the names the HF
-        # text-generation API expects.
-        if "max_tokens" in args:
-            args["max_new_tokens"] = args["max_tokens"]
-            del args["max_tokens"]
-        if "stop" in args:
-            args["stop_sequences"] = args["stop"]
-            del args["stop"]
-
-        return args
-
-    async def _stream_complete(self, prompt, options):
-        # collect_args() clears options.stop, so capture the stop
-        # sequences first for the client-side check below.
-        stop = options.stop
-        args = self.collect_args(options)
-
-        client = InferenceClient(self.endpoint_url, token=self.hf_token)
-
-        stream = client.text_generation(prompt, stream=True, details=True, **args)
-
-        for r in stream:
-            # skip special tokens
-            if r.token.special:
-                continue
-            # stop if we encounter a stop sequence
-            if stop is not None and r.token.text in stop:
-                break
-            yield r.token.text
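
For context on what was removed: the class was a thin streaming wrapper over `huggingface_hub.InferenceClient`. Below is a minimal, self-contained sketch of the same flow outside the Continue class hierarchy; the endpoint URL, token, prompt, and stop sequence are placeholders, not values from this repository.

```python
# Minimal sketch of the streaming flow the deleted class implemented.
# Assumes a deployed HF Inference Endpoint; URL, token, and prompt are placeholders.
from huggingface_hub import InferenceClient

client = InferenceClient("<INFERENCE_API_ENDPOINT_URL>", token="<HUGGING_FACE_TOKEN>")

stop = ["</s>"]  # illustrative stop sequence; the class read this from CompletionOptions

# details=True yields per-token metadata (r.token.special, r.token.text);
# max_new_tokens and stop_sequences are the HF parameter names that
# collect_args() mapped from the generic max_tokens / stop options.
for r in client.text_generation(
    "def fibonacci(n):",
    stream=True,
    details=True,
    max_new_tokens=128,
    stop_sequences=stop,
):
    if r.token.special:       # skip special tokens such as </s>
        continue
    if r.token.text in stop:  # client-side stop check, as in _stream_complete
        break
    print(r.token.text, end="", flush=True)
```

Note that the client-side check only matches a stop sequence that arrives as a single generated token; the `stop_sequences` parameter sent to the endpoint handles truncation server-side.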