summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/libs/llm
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-05-29 18:31:25 -0400
committerNate Sesti <sestinj@gmail.com>2023-05-29 18:31:25 -0400
commit8d59100b3194cc8d122708523226968899efb5e1 (patch)
tree88fe742114c87d6df0424f46dfc86077d716a074 /continuedev/src/continuedev/libs/llm
parent8c00cddb9345daaf2052d3b2650fa136f39813be (diff)
downloadsncontinue-8d59100b3194cc8d122708523226968899efb5e1.tar.gz
sncontinue-8d59100b3194cc8d122708523226968899efb5e1.tar.bz2
sncontinue-8d59100b3194cc8d122708523226968899efb5e1.zip
(much!) faster inference with starcoder
Diffstat (limited to 'continuedev/src/continuedev/libs/llm')
-rw-r--r--continuedev/src/continuedev/libs/llm/hf_inference_api.py25
1 file changed, 25 insertions, 0 deletions
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
new file mode 100644
index 00000000..83852d27
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -0,0 +1,25 @@
+from ..llm import LLM
+import requests
+
# Default generation limits sent to the Hugging Face Inference API.
DEFAULT_MAX_TOKENS = 2048  # upper bound on newly generated tokens per request
DEFAULT_MAX_TIME = 120.  # max seconds the API may spend on one request
+
+
class HuggingFaceInferenceAPI(LLM):
    """LLM backed by the hosted Hugging Face Inference API."""

    # api_key: a Hugging Face API token ("hf_..."), sent as a Bearer token.
    api_key: str
    # model: repo id of the model to query on the Inference API.
    model: str = "bigcode/starcoder"

    def complete(self, prompt: str, **kwargs):
        """Return the completion of `prompt` from the Inference API.

        Extra keyword arguments (e.g. `temperature`, `top_p`,
        `max_new_tokens`) are forwarded as generation parameters and
        override the defaults below.

        Raises:
            RuntimeError: if the API responds with an error payload
                (e.g. "Model ... is currently loading").
        """
        api_url = f"https://api-inference.huggingface.co/models/{self.model}"
        headers = {"Authorization": f"Bearer {self.api_key}"}

        # Defaults first so callers can override any of them via kwargs.
        parameters = {
            "max_new_tokens": DEFAULT_MAX_TOKENS,
            "max_time": DEFAULT_MAX_TIME,
            "return_full_text": False,
            **kwargs,
        }

        response = requests.post(
            api_url,
            headers=headers,
            json={"inputs": prompt, "parameters": parameters},
        )
        data = response.json()

        # On failure the API returns a dict, not a list of generations,
        # e.g. {"error": "Model bigcode/starcoder is currently loading"}.
        if isinstance(data, dict) and "error" in data:
            raise RuntimeError(
                f"Hugging Face Inference API error: {data['error']}")

        return data[0]["generated_text"]