summaryrefslogtreecommitdiff
path: root/server/continuedev/libs/llm/hugging_face.py
blob: c2e934c090efadada729b558a3f0f824d199ab0e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# TODO: This class is far out of date

from transformers import AutoModelForCausalLM, AutoTokenizer

from .llm import LLM


class HuggingFace(LLM):
    """Completion backend that runs a Hugging Face causal language model locally.

    The tokenizer and model weights are loaded eagerly in ``__init__`` through
    ``transformers.AutoTokenizer`` and ``AutoModelForCausalLM``.
    """

    def __init__(self, model_path: str = "Salesforce/codegen-2B-mono"):
        """Load the tokenizer and model identified by ``model_path``.

        Args:
            model_path: Identifier passed to ``from_pretrained`` (a Hugging
                Face Hub model id or a local directory).
        """
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(model_path)

    def complete(self, prompt: str, **kwargs) -> str:
        """Generate a continuation of ``prompt`` with the loaded model.

        Args:
            prompt: Text to continue.
            **kwargs: Generation options. Only ``max_tokens`` (default 100) is
                read; it is the maximum number of *new* tokens to generate.
                Any other keys are accepted but currently ignored.

        Returns:
            The decoded continuation (prompt excluded), with special tokens
            removed.
        """
        args = {"max_tokens": 100}
        args.update(kwargs)

        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs.input_ids
        prompt_length = input_ids.shape[-1]

        generated_ids = self.model.generate(
            input_ids,
            # Pass the mask explicitly; some tokenizers do not produce one,
            # in which case None lets generate() fall back to its default.
            attention_mask=inputs.get("attention_mask"),
            # ``max_length`` would count the prompt tokens too, so a long
            # prompt could leave no room for generation at all.
            max_new_tokens=args["max_tokens"],
        )

        # For causal LMs, generate() returns prompt + continuation; slice off
        # the prompt so only the newly generated text is returned.
        completion_ids = generated_ids[0][prompt_length:]
        return self.tokenizer.decode(completion_ids, skip_special_tokens=True)