blob: 734da160f6c5da2afebf1d1a658f6cbbfef675e4 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
|
from ..llm import LLM
import requests
# Default cap on the number of tokens generated per completion request.
DEFAULT_MAX_TOKENS = 2048
# Default server-side time budget, in seconds, for a single generation
# (passed as the Inference API's "max_time" parameter).
DEFAULT_MAX_TIME = 120.
class HuggingFaceInferenceAPI(LLM):
    """LLM implementation backed by the Hugging Face Inference API."""

    # API token for the Hugging Face Inference API ("Bearer" auth).
    api_key: str
    # Model repo id on the Hugging Face Hub.
    model: str = "bigcode/starcoder"

    def complete(self, prompt: str, **kwargs):
        """Return the completion of `prompt` from the Inference API.

        Extra keyword arguments (e.g. ``temperature``) are forwarded as
        generation parameters and may override the defaults below.

        Raises:
            Exception: if the API returns an error payload (anything
                other than a list of generation results).
        """
        API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
        headers = {
            "Authorization": f"Bearer {self.api_key}"}
        # Defaults first so caller-supplied kwargs win on key collisions.
        parameters = {
            "max_new_tokens": DEFAULT_MAX_TOKENS,
            "max_time": DEFAULT_MAX_TIME,
            "return_full_text": False,
            **kwargs,
        }
        # Client-side timeout slightly above the server's max_time budget,
        # so a stuck connection cannot hang the caller indefinitely.
        response = requests.post(
            API_URL,
            headers=headers,
            json={"inputs": prompt, "parameters": parameters},
            timeout=DEFAULT_MAX_TIME + 10,
        )
        data = response.json()
        # A successful generation comes back as a list of result dicts;
        # an error from the API arrives as a dict payload instead.
        if not isinstance(data, list):
            raise Exception(
                f"Hugging Face returned an error response: \n\n{data}")
        return data[0]["generated_text"]
|