| author | Nate Sesti <sestinj@gmail.com> | 2023-08-08 13:44:04 -0700 | 
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-08-08 13:44:04 -0700 | 
| commit | bbf7973ec091823c4197d59daaf151b748ee52fc (patch) | |
| tree | 2347dd306d614534f10cfe48c7ca0f8db2bb61ce /continuedev/src | |
| parent | 816005ffa02636e1183d6715a4aad94835430405 (diff) | |
| download | sncontinue-bbf7973ec091823c4197d59daaf151b748ee52fc.tar.gz<br>sncontinue-bbf7973ec091823c4197d59daaf151b748ee52fc.tar.bz2<br>sncontinue-bbf7973ec091823c4197d59daaf151b748ee52fc.zip | |
feat: :sparkles: huggingface inference api llm update
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/hf_inference_api.py | 61 |
|---|---|---|

1 file changed, 48 insertions(+), 13 deletions(-)
```diff
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 49f593d8..9664fec2 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,34 +1,58 @@
-from typing import List, Optional
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
+import aiohttp
+import requests
+
+from ...core.main import ChatMessage
+from ..util.count_tokens import DEFAULT_ARGS, count_tokens
 from ...core.main import ChatMessage
 from ..llm import LLM
-import requests
 
-DEFAULT_MAX_TOKENS = 2048
 DEFAULT_MAX_TIME = 120.
 
 
 class HuggingFaceInferenceAPI(LLM):
     model: str
+    hf_token: str
+
+    max_context_length: int = 2048
+    verify_ssl: bool = True
+
+    _client_session: aiohttp.ClientSession = None
 
-    requires_api_key: str = "HUGGING_FACE_TOKEN"
-    api_key: str = None
+    class Config:
+        arbitrary_types_allowed = True
 
-    def __init__(self, model: str, system_message: str = None):
-        self.model = model
-        self.system_message = system_message  # TODO: Nothing being done with this
+    async def start(self, **kwargs):
+        self._client_session = aiohttp.ClientSession(
+            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl))
 
-    async def start(self, *, api_key: Optional[str] = None, **kwargs):
-        self.api_key = api_key
+    async def stop(self):
+        await self._client_session.close()
 
-    def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
+    @property
+    def name(self):
+        return self.model
+
+    @property
+    def context_length(self):
+        return self.max_context_length
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.name, text)
+
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
         """Return the completion of the text with the given temperature."""
         API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
         headers = {
-            "Authorization": f"Bearer {self.api_key}"}
+            "Authorization": f"Bearer {self.hf_token}"}
 
         response = requests.post(API_URL, headers=headers, json={
             "inputs": prompt, "parameters": {
-                "max_new_tokens": DEFAULT_MAX_TOKENS,
+                "max_new_tokens": self.max_context_length - self.count_tokens(prompt),
                 "max_time": DEFAULT_MAX_TIME,
                 "return_full_text": False,
             }
@@ -41,3 +65,14 @@ class HuggingFaceInferenceAPI(LLM):
                 "Hugging Face returned an error response: \n\n", data)
 
         return data[0]["generated_text"]
+
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Any | List | Dict, None, None]]:
+        response = await self.complete(messages[-1].content, messages[:-1])
+        yield {
+            "content": response,
+            "role": "assistant"
+        }
+
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Any | List | Dict, None, None]:
+        response = await self.complete(prompt, with_history)
+        yield response
```
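For context, a minimal usage sketch of the reworked class, inferred only from the methods visible in the diff (`start`, `complete`, `stream_complete`, `stop`); the model name, token value, and absolute import path are assumptions, not part of the commit:

```python
import asyncio

# Import path assumed from the repository layout (continuedev/src/continuedev/...).
from continuedev.src.continuedev.libs.llm.hf_inference_api import HuggingFaceInferenceAPI


async def main():
    # Placeholder model name and token; substitute a real Inference API model and HF token.
    llm = HuggingFaceInferenceAPI(model="bigcode/starcoder", hf_token="hf_xxx")

    # start() now opens an aiohttp ClientSession (honoring verify_ssl)
    # instead of accepting an api_key argument.
    await llm.start()
    try:
        # complete() is now a coroutine; max_new_tokens is computed as
        # max_context_length - count_tokens(prompt) rather than a fixed 2048.
        print(await llm.complete("def fibonacci(n):"))

        # stream_complete() wraps complete() and yields the whole response as one chunk.
        async for chunk in llm.stream_complete("def fibonacci(n):"):
            print(chunk)
    finally:
        await llm.stop()


asyncio.run(main())
```

Compared with the previous version, the generation budget is `max_context_length - count_tokens(prompt)` instead of the fixed `DEFAULT_MAX_TOKENS = 2048`, and the Hugging Face token is a required `hf_token` field on the class rather than an `api_key` passed to `start()`.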
