diff options
| author | Nate Sesti <sestinj@gmail.com> | 2023-05-29 18:31:25 -0400 |
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-05-29 18:31:25 -0400 |
| commit | 8d59100b3194cc8d122708523226968899efb5e1 (patch) | |
| tree | 88fe742114c87d6df0424f46dfc86077d716a074 /continuedev/src/continuedev/libs/llm | |
| parent | 8c00cddb9345daaf2052d3b2650fa136f39813be (diff) | |
| download | sncontinue-8d59100b3194cc8d122708523226968899efb5e1.tar.gz sncontinue-8d59100b3194cc8d122708523226968899efb5e1.tar.bz2 sncontinue-8d59100b3194cc8d122708523226968899efb5e1.zip | |
(much!) faster inference with starcoder
Diffstat (limited to 'continuedev/src/continuedev/libs/llm')
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/hf_inference_api.py | 25 |
1 file changed, 25 insertions, 0 deletions
from ..llm import LLM
import requests

# Default generation parameters for the Hugging Face Inference API.
DEFAULT_MAX_TOKENS = 2048
DEFAULT_MAX_TIME = 120.0  # seconds the API may spend before returning


class HuggingFaceInferenceAPI(LLM):
    """LLM implementation backed by the hosted Hugging Face Inference API."""

    # API token sent as a Bearer header on every request.
    api_key: str
    model: str = "bigcode/starcoder"

    def complete(self, prompt: str, **kwargs):
        """Return the model's completion of `prompt`.

        Keyword arguments (e.g. `temperature`, `max_new_tokens`, `max_time`)
        are forwarded as generation parameters, overriding the defaults.
        Raises `requests.HTTPError` on a non-2xx response and `RuntimeError`
        when the API returns an error payload.
        """
        api_url = f"https://api-inference.huggingface.co/models/{self.model}"
        headers = {"Authorization": f"Bearer {self.api_key}"}

        parameters = {
            "max_new_tokens": DEFAULT_MAX_TOKENS,
            "max_time": DEFAULT_MAX_TIME,
            "return_full_text": False,
        }
        # Forward caller-supplied options; previously **kwargs was silently
        # dropped even though the docstring promised temperature support.
        parameters.update(kwargs)

        response = requests.post(
            api_url,
            headers=headers,
            json={"inputs": prompt, "parameters": parameters},
        )
        # Fail loudly on HTTP errors instead of an opaque KeyError below.
        response.raise_for_status()

        data = response.json()
        # The Inference API can report failures (e.g. model still loading)
        # as an {"error": ...} object rather than the usual list payload.
        if isinstance(data, dict) and "error" in data:
            raise RuntimeError(
                f"Hugging Face Inference API error: {data['error']}")
        return data[0]["generated_text"]
