author     Nate Sesti <sestinj@gmail.com>   2023-08-08 13:44:04 -0700
committer  Nate Sesti <sestinj@gmail.com>   2023-08-08 13:44:04 -0700
commit     bbf7973ec091823c4197d59daaf151b748ee52fc (patch)
tree       2347dd306d614534f10cfe48c7ca0f8db2bb61ce
parent     816005ffa02636e1183d6715a4aad94835430405 (diff)
download   sncontinue-bbf7973ec091823c4197d59daaf151b748ee52fc.tar.gz
           sncontinue-bbf7973ec091823c4197d59daaf151b748ee52fc.tar.bz2
           sncontinue-bbf7973ec091823c4197d59daaf151b748ee52fc.zip
feat: :sparkles: huggingface inference api llm update
-rw-r--r--   continuedev/src/continuedev/libs/llm/hf_inference_api.py   61
1 file changed, 48 insertions(+), 13 deletions(-)
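
The diff below replaces the old constructor/api_key flow with declarative pydantic-style fields (hf_token, max_context_length, verify_ssl), opens an aiohttp.ClientSession in start() and closes it in stop(), budgets max_new_tokens from the remaining context window instead of a fixed constant, and layers stream_chat/stream_complete shims on top of complete(). For orientation, here is a minimal standalone sketch of the raw Inference API request that the updated complete() wraps; the model ID and token are placeholder values, not part of the commit:

    import requests

    MODEL = "bigcode/starcoder"   # placeholder model ID
    HF_TOKEN = "hf_..."           # placeholder token

    API_URL = f"https://api-inference.huggingface.co/models/{MODEL}"
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}

    # Same request shape as the committed complete(); there, max_new_tokens
    # is computed as max_context_length - count_tokens(prompt).
    response = requests.post(API_URL, headers=headers, json={
        "inputs": "def fibonacci(n):",
        "parameters": {
            "max_new_tokens": 256,
            "max_time": 120.0,
            "return_full_text": False,
        },
    })
    data = response.json()
    if not isinstance(data, list):
        raise Exception("Hugging Face returned an error response: \n\n", data)
    print(data[0]["generated_text"])
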
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 49f593d8..9664fec2 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,34 +1,58 @@
-from typing import List, Optional
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
+import aiohttp
+import requests
+
+from ...core.main import ChatMessage
+from ..util.count_tokens import DEFAULT_ARGS, count_tokens
 from ...core.main import ChatMessage
 from ..llm import LLM
-import requests
 
-DEFAULT_MAX_TOKENS = 2048
 DEFAULT_MAX_TIME = 120.
 
 
 class HuggingFaceInferenceAPI(LLM):
     model: str
+    hf_token: str
+
+    max_context_length: int = 2048
+    verify_ssl: bool = True
+
+    _client_session: aiohttp.ClientSession = None
 
-    requires_api_key: str = "HUGGING_FACE_TOKEN"
-    api_key: str = None
+    class Config:
+        arbitrary_types_allowed = True
 
-    def __init__(self, model: str, system_message: str = None):
-        self.model = model
-        self.system_message = system_message  # TODO: Nothing being done with this
+    async def start(self, **kwargs):
+        self._client_session = aiohttp.ClientSession(
+            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl))
 
-    async def start(self, *, api_key: Optional[str] = None, **kwargs):
-        self.api_key = api_key
+    async def stop(self):
+        await self._client_session.close()
 
-    def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
+    @property
+    def name(self):
+        return self.model
+
+    @property
+    def context_length(self):
+        return self.max_context_length
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.name, text)
+
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
         """Return the completion of the text with the given temperature."""
         API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
         headers = {
-            "Authorization": f"Bearer {self.api_key}"}
+            "Authorization": f"Bearer {self.hf_token}"}
 
         response = requests.post(API_URL, headers=headers, json={
             "inputs": prompt, "parameters": {
-                "max_new_tokens": DEFAULT_MAX_TOKENS,
+                "max_new_tokens": self.max_context_length - self.count_tokens(prompt),
                 "max_time": DEFAULT_MAX_TIME,
                 "return_full_text": False,
             }
@@ -41,3 +65,14 @@ class HuggingFaceInferenceAPI(LLM):
                 "Hugging Face returned an error response: \n\n", data)
 
         return data[0]["generated_text"]
+
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Any | List | Dict, None, None]]:
+        response = await self.complete(messages[-1].content, messages[:-1])
+        yield {
+            "content": response,
+            "role": "assistant"
+        }
+
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Any | List | Dict, None, None]:
+        response = await self.complete(prompt, with_history)
+        yield response
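
Taken together, the new surface is: construct with model and hf_token, await start() before use, and await stop() when done. A hypothetical usage sketch follows; the import path is inferred from the tree layout above, and the model and token values are placeholders:

    import asyncio

    from continuedev.src.continuedev.libs.llm.hf_inference_api import \
        HuggingFaceInferenceAPI

    async def main():
        # hf_token is now a required field rather than an api_key kwarg to start().
        llm = HuggingFaceInferenceAPI(model="bigcode/starcoder", hf_token="hf_...")
        await llm.start()
        try:
            # complete() is now a coroutine; stream_complete() is an async
            # generator that yields the full completion in one chunk.
            print(await llm.complete("def fibonacci(n):"))
            async for chunk in llm.stream_complete("def fibonacci(n):"):
                print(chunk)
        finally:
            await llm.stop()  # closes the aiohttp ClientSession

    asyncio.run(main())
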