diff options
Diffstat (limited to 'server/continuedev/libs/llm/hugging_face.py')
-rw-r--r-- | server/continuedev/libs/llm/hugging_face.py | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/server/continuedev/libs/llm/hugging_face.py b/server/continuedev/libs/llm/hugging_face.py new file mode 100644 index 00000000..c2e934c0 --- /dev/null +++ b/server/continuedev/libs/llm/hugging_face.py @@ -0,0 +1,19 @@ +# TODO: This class is far out of date + +from transformers import AutoModelForCausalLM, AutoTokenizer + +from .llm import LLM + + +class HuggingFace(LLM): + def __init__(self, model_path: str = "Salesforce/codegen-2B-mono"): + self.model_path = model_path + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.model = AutoModelForCausalLM.from_pretrained(model_path) + + def complete(self, prompt: str, **kwargs): + args = {"max_tokens": 100} + args.update(kwargs) + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids + generated_ids = self.model.generate(input_ids, max_length=args["max_tokens"]) + return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) |