diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-08-08 13:44:04 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-08-08 13:44:04 -0700 |
commit | bbf7973ec091823c4197d59daaf151b748ee52fc (patch) | |
tree | 2347dd306d614534f10cfe48c7ca0f8db2bb61ce | |
parent | 816005ffa02636e1183d6715a4aad94835430405 (diff) | |
download | sncontinue-bbf7973ec091823c4197d59daaf151b748ee52fc.tar.gz sncontinue-bbf7973ec091823c4197d59daaf151b748ee52fc.tar.bz2 sncontinue-bbf7973ec091823c4197d59daaf151b748ee52fc.zip |
feat: :sparkles: huggingface inference api llm update
-rw-r--r-- | continuedev/src/continuedev/libs/llm/hf_inference_api.py | 61 |
1 file changed, 48 insertions, 13 deletions
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 49f593d8..9664fec2 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -1,34 +1,58 @@ -from typing import List, Optional +from typing import Any, Coroutine, Dict, Generator, List, Optional, Union +import aiohttp +import requests + +from ...core.main import ChatMessage +from ..util.count_tokens import DEFAULT_ARGS, count_tokens from ...core.main import ChatMessage from ..llm import LLM -import requests -DEFAULT_MAX_TOKENS = 2048 DEFAULT_MAX_TIME = 120. class HuggingFaceInferenceAPI(LLM): model: str + hf_token: str + + max_context_length: int = 2048 + verify_ssl: bool = True + + _client_session: aiohttp.ClientSession = None - requires_api_key: str = "HUGGING_FACE_TOKEN" - api_key: str = None + class Config: + arbitrary_types_allowed = True - def __init__(self, model: str, system_message: str = None): - self.model = model - self.system_message = system_message # TODO: Nothing being done with this + async def start(self, **kwargs): + self._client_session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)) - async def start(self, *, api_key: Optional[str] = None, **kwargs): - self.api_key = api_key + async def stop(self): + await self._client_session.close() - def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs): + @property + def name(self): + return self.model + + @property + def context_length(self): + return self.max_context_length + + @property + def default_args(self): + return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} + + def count_tokens(self, text: str): + return count_tokens(self.name, text) + + async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs): """Return the completion of the text with the given temperature.""" API_URL = 
f"https://api-inference.huggingface.co/models/{self.model}" headers = { - "Authorization": f"Bearer {self.api_key}"} + "Authorization": f"Bearer {self.hf_token}"} response = requests.post(API_URL, headers=headers, json={ "inputs": prompt, "parameters": { - "max_new_tokens": DEFAULT_MAX_TOKENS, + "max_new_tokens": self.max_context_length - self.count_tokens(prompt), "max_time": DEFAULT_MAX_TIME, "return_full_text": False, } @@ -41,3 +65,14 @@ class HuggingFaceInferenceAPI(LLM): "Hugging Face returned an error response: \n\n", data) return data[0]["generated_text"] + + async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Any | List | Dict, None, None]]: + response = await self.complete(messages[-1].content, messages[:-1]) + yield { + "content": response, + "role": "assistant" + } + + async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Any | List | Dict, None, None]: + response = await self.complete(prompt, with_history) + yield response |