path: root/server/continuedev/libs/llm/hf_inference_api.py
from typing import Callable, Dict, List, Union

from huggingface_hub import InferenceClient
from pydantic import Field

from .base import LLM, CompletionOptions
from .prompts.chat import llama2_template_messages
from .prompts.edit import simplified_edit_prompt


class HuggingFaceInferenceAPI(LLM):
    """
    Hugging Face Inference API is a great option for newly released language models. Sign up for an account and add billing [here](https://huggingface.co/settings/billing), then open the Inference Endpoints page [here](https://ui.endpoints.huggingface.co), click "New endpoint", and fill out the form (e.g. select a model like [WizardCoder-Python-34B-V1.0](https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0)). Deploy your model by clicking "Create Endpoint", then change `~/.continue/config.py` to look like this:

    ```python title="~/.continue/config.py"
    from continuedev.core.models import Models
    from continuedev.libs.llm.hf_inference_api import HuggingFaceInferenceAPI

    config = ContinueConfig(
        ...
        models=Models(
            default=HuggingFaceInferenceAPI(
                endpoint_url="<INFERENCE_API_ENDPOINT_URL>",
                hf_token="<HUGGING_FACE_TOKEN>",
            )
        ),
    )
    ```
    """

    model: str = Field(
        "Hugging Face Inference API",
        description="The name of the model to use (optional for the HuggingFaceInferenceAPI class)",
    )
    hf_token: str = Field(..., description="Your Hugging Face API token")
    endpoint_url: str = Field(
        None, description="Your Hugging Face Inference API endpoint URL"
    )

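    # Formats chat messages into a single Llama 2-style prompt string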
    template_messages: Union[
        Callable[[List[Dict[str, str]]], str], None
    ] = llama2_template_messages

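    # Prompt template used for edit requests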
    prompt_templates = {
        "edit": simplified_edit_prompt,
    }

    class Config:
        arbitrary_types_allowed = True

    def collect_args(self, options: CompletionOptions):
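        # Stop sequences are applied client-side while streaming (see _stream_complete),
        # so don't send them to the endpoint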
        options.stop = None
        args = super().collect_args(options)

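        # text_generation expects `max_new_tokens` and `stop_sequences`
        # rather than `max_tokens` and `stop`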
        if "max_tokens" in args:
            args["max_new_tokens"] = args["max_tokens"]
            del args["max_tokens"]
        if "stop" in args:
            args["stop_sequences"] = args["stop"]
            del args["stop"]

        return args

    async def _stream_complete(self, prompt, options):
        # Remember the stop sequences before collect_args clears them from options
        stop_sequences = options.stop
        args = self.collect_args(options)

        client = InferenceClient(self.endpoint_url, token=self.hf_token)

        stream = client.text_generation(prompt, stream=True, details=True, **args)

        for r in stream:
            # skip special tokens
            if r.token.special:
                continue
            # stop if we encounter a stop sequence
            if stop_sequences is not None and r.token.text in stop_sequences:
                break
            yield r.token.text