# continuedev/src/continuedev/libs/llm/hf_inference_api.py
from ..llm import LLM
import requests

DEFAULT_MAX_TOKENS = 2048
DEFAULT_MAX_TIME = 120.0  # seconds; upper bound on how long the API may spend generating


class HuggingFaceInferenceAPI(LLM):
    api_key: str
    model: str = "bigcode/starcoder"

    def complete(self, prompt: str, **kwargs):
        """Return the model's completion of the given prompt."""
        API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
        headers = {
            "Authorization": f"Bearer {self.api_key}"}

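        # Call the hosted Inference API for this model. "max_new_tokens" caps
        # the length of the generated continuation, "max_time" bounds how long
        # the API may spend on the request, and "return_full_text": False asks
        # for only the completion rather than the prompt echoed back with it.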
        response = requests.post(API_URL, headers=headers, json={
            "inputs": prompt, "parameters": {
                "max_new_tokens": DEFAULT_MAX_TOKENS,
                "max_time": DEFAULT_MAX_TIME,
                "return_full_text": False,
            }
        })
        data = response.json()

        # A successful call returns a list of generations; an error comes back
        # as a dict (typically with an "error" key), so fail loudly in that case.
        if not isinstance(data, list):
            raise Exception(
                f"Hugging Face returned an error response:\n\n{data}")

        return data[0]["generated_text"]
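

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of how this class might be driven directly, assuming the
# LLM base class accepts these fields as keyword arguments (e.g. a pydantic
# model); in the real application the instance is constructed and managed by
# continuedev itself. The API token below is a placeholder.
#
#     llm = HuggingFaceInferenceAPI(api_key="hf_...", model="bigcode/starcoder")
#     print(llm.complete("def fibonacci(n):"))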