summaryrefslogtreecommitdiff
path: root/server/continuedev/libs/llm/hugging_face.py
diff options
context:
space:
mode:
Diffstat (limited to 'server/continuedev/libs/llm/hugging_face.py')
-rw-r--r--server/continuedev/libs/llm/hugging_face.py19
1 files changed, 19 insertions, 0 deletions
diff --git a/server/continuedev/libs/llm/hugging_face.py b/server/continuedev/libs/llm/hugging_face.py
new file mode 100644
index 00000000..c2e934c0
--- /dev/null
+++ b/server/continuedev/libs/llm/hugging_face.py
@@ -0,0 +1,19 @@
+# TODO: This class is far out of date
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from .llm import LLM
+
+
+class HuggingFace(LLM):
+ def __init__(self, model_path: str = "Salesforce/codegen-2B-mono"):
+ self.model_path = model_path
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+ self.model = AutoModelForCausalLM.from_pretrained(model_path)
+
+ def complete(self, prompt: str, **kwargs):
+ args = {"max_tokens": 100}
+ args.update(kwargs)
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+ generated_ids = self.model.generate(input_ids, max_length=args["max_tokens"])
+ return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)