Diffstat (limited to 'continuedev/src')
-rw-r--r--   continuedev/src/continuedev/core/models.py                |  3
-rw-r--r--   continuedev/src/continuedev/libs/llm/hf_inference_api.py  | 86
-rw-r--r--   continuedev/src/continuedev/plugins/steps/setup_model.py  |  1
3 files changed, 37 insertions, 53 deletions
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index 9816d5d9..f24c81ca 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -11,6 +11,7 @@ from ..libs.llm.ollama import Ollama
 from ..libs.llm.openai import OpenAI
 from ..libs.llm.replicate import ReplicateLLM
 from ..libs.llm.together import TogetherLLM
+from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
 
 
 class ContinueSDK(BaseModel):
@@ -37,6 +38,7 @@ MODEL_CLASSES = {
         ReplicateLLM,
         Ollama,
         LlamaCpp,
+        HuggingFaceInferenceAPI,
     ]
 }
 
@@ -49,6 +51,7 @@ MODEL_MODULE_NAMES = {
     "ReplicateLLM": "replicate",
     "Ollama": "ollama",
     "LlamaCpp": "llamacpp",
+    "HuggingFaceInferenceAPI": "hf_inference_api",
 }
 
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 43aac148..b8fc49a9 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,73 +1,53 @@
-from typing import List, Optional
+from typing import Callable, Dict, List
+from ..llm import LLM, CompletionOptions
-
-import aiohttp
-import requests
-
-from ...core.main import ChatMessage
-from ..llm import LLM
+from huggingface_hub import InferenceClient
+
+from .prompts.chat import llama2_template_messages
 from .prompts.edit import simplified_edit_prompt
 
-DEFAULT_MAX_TIME = 120.0
-
 
 class HuggingFaceInferenceAPI(LLM):
+    model: str = "Hugging Face Inference API"
     hf_token: str
-    self_hosted_url: str = None
-
-    verify_ssl: Optional[bool] = None
+    endpoint_url: str = None
 
-    _client_session: aiohttp.ClientSession = None
+    template_messages: Callable[[List[Dict[str, str]]], str] | None = llama2_template_messages
 
     prompt_templates = {
-        "edit": simplified_edit_prompt,
+        "edit": simplified_edit_prompt,
     }
 
     class Config:
         arbitrary_types_allowed = True
 
-    async def start(self, **kwargs):
-        await super().start(**kwargs)
-        self._client_session = aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
-            timeout=aiohttp.ClientTimeout(total=self.timeout),
-        )
+    def collect_args(self, options: CompletionOptions):
+        options.stop = None
+        args = super().collect_args(options)
 
-    async def stop(self):
-        await self._client_session.close()
+        if "max_tokens" in args:
+            args["max_new_tokens"] = args["max_tokens"]
+            del args["max_tokens"]
+        if "stop" in args:
+            args["stop_sequences"] = args["stop"]
+            del args["stop"]
+        if "model" in args:
+            del args["model"]
+        return args
 
-    async def _complete(self, prompt: str, options):
-        """Return the completion of the text with the given temperature."""
-        API_URL = (
-            self.base_url or f"https://api-inference.huggingface.co/models/{self.model}"
-        )
-        headers = {"Authorization": f"Bearer {self.hf_token}"}
-        response = requests.post(
-            API_URL,
-            headers=headers,
-            json={
-                "inputs": prompt,
-                "parameters": {
-                    "max_new_tokens": min(
-                        250, self.max_context_length - self.count_tokens(prompt)
-                    ),
-                    "max_time": DEFAULT_MAX_TIME,
-                    "return_full_text": False,
-                },
-            },
-        )
-        data = response.json()
-
-        # Error if the response is not a list
-        if not isinstance(data, list):
-            raise Exception("Hugging Face returned an error response: \n\n", data)
+    async def _stream_complete(self, prompt, options):
+        args = self.collect_args(options)
 
-        return data[0]["generated_text"]
+        client = InferenceClient(self.endpoint_url, token=self.hf_token)
 
-    async def _stream_chat(self, messages: List[ChatMessage], options):
-        response = await self._complete(messages[-1].content, messages[:-1])
-        yield {"content": response, "role": "assistant"}
+        stream = client.text_generation(prompt, stream=True, details=True)
 
-    async def _stream_complete(self, prompt, options):
-        response = await self._complete(prompt, options)
-        yield response
+        for r in stream:
+            # skip special tokens
+            if r.token.special:
+                continue
+            # stop if we encounter a stop sequence
+            if options.stop is not None:
+                if r.token.text in options.stop:
+                    break
+            yield r.token.text
\ No newline at end of file
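For context, the rewritten `_stream_complete` above leans on `huggingface_hub`'s streaming text-generation API. Below is a minimal standalone sketch of that call; the endpoint URL, token, and prompt are placeholders rather than values from this repository.

```python
# Minimal sketch of the InferenceClient streaming call used by _stream_complete above.
# ENDPOINT_URL and HF_TOKEN are hypothetical placeholders.
from huggingface_hub import InferenceClient

ENDPOINT_URL = "https://my-endpoint.endpoints.huggingface.cloud"  # or a model id
HF_TOKEN = "hf_..."  # Hugging Face access token

client = InferenceClient(ENDPOINT_URL, token=HF_TOKEN)

# stream=True yields one token at a time; details=True exposes per-token metadata
# such as `special`, which the class above uses to skip special tokens.
for r in client.text_generation("def fibonacci(n):", stream=True, details=True, max_new_tokens=64):
    if r.token.special:
        continue
    print(r.token.text, end="", flush=True)
```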
diff --git a/continuedev/src/continuedev/plugins/steps/setup_model.py b/continuedev/src/continuedev/plugins/steps/setup_model.py
index 3dae39c2..7fa34907 100644
--- a/continuedev/src/continuedev/plugins/steps/setup_model.py
+++ b/continuedev/src/continuedev/plugins/steps/setup_model.py
@@ -13,6 +13,7 @@ MODEL_CLASS_TO_MESSAGE = {
     "GGML": "GGML models can be run locally using the `llama-cpp-python` library. To learn how to set up a local llama-cpp-python server, read [here](https://github.com/continuedev/ggml-server-example). Once it is started on port 8000, you're all set!",
     "TogetherLLM": "To get started using models from Together, first obtain your Together API key from [here](https://together.ai). Paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then, on their models page, press 'start' on the model of your choice and make sure the `model=` parameter in the config file for the `TogetherLLM` class reflects the name of this model. Finally, reload the VS Code window for changes to take effect.",
     "LlamaCpp": "To get started with this model, clone the [`llama.cpp` repo](https://github.com/ggerganov/llama.cpp) and follow the instructions to set up the server [here](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#build). Any of the parameters described in the README can be passed to the `llama_cpp_args` field in the `LlamaCpp` class in `config.py`.",
+    "HuggingFaceInferenceAPI": "To get started with the HuggingFace Inference API, first deploy a model and obtain your API key from [here](https://huggingface.co/inference-api). Paste it into the `hf_token` field at config.models.default.hf_token in `config.py`. Finally, reload the VS Code window for changes to take effect."
 }
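For reference, wiring the new class into `config.py` might look roughly like the sketch below, based on the `model`, `hf_token`, and `endpoint_url` fields this diff introduces; the import paths, model name, and endpoint URL are assumptions, not something this commit specifies.

```python
# Hypothetical config.py snippet; the paths and values below are placeholders.
from continuedev.src.continuedev.core.config import ContinueConfig
from continuedev.src.continuedev.core.models import Models
from continuedev.src.continuedev.libs.llm.hf_inference_api import HuggingFaceInferenceAPI

config = ContinueConfig(
    models=Models(
        default=HuggingFaceInferenceAPI(
            model="codellama/CodeLlama-7b-Instruct-hf",  # placeholder model name
            hf_token="hf_...",  # your Hugging Face API key
            endpoint_url="https://my-endpoint.endpoints.huggingface.cloud",  # your deployed endpoint
        )
    )
)
```

As the setup message above notes, reload the VS Code window after editing `config.py` for the change to take effect.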