Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/core/models.py               |  3
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_inference_api.py | 86
-rw-r--r--  continuedev/src/continuedev/plugins/steps/setup_model.py |  1
3 files changed, 37 insertions, 53 deletions
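The rewrite of `hf_inference_api.py` below drops the hand-rolled `requests`/`aiohttp` calls in favor of the `huggingface_hub` client. As a reference point, here is a minimal standalone sketch of the `InferenceClient` streaming call the new `_stream_complete` is built around; the endpoint URL, token, and prompt are placeholders, not values from this commit:

```python
from huggingface_hub import InferenceClient

# Placeholders -- substitute your own Inference Endpoint URL (or model id) and HF token.
client = InferenceClient("https://<your-endpoint>.endpoints.huggingface.cloud", token="hf_...")

# stream=True with details=True yields token-level events instead of one final string.
for event in client.text_generation("Write a haiku about diffs.", stream=True, details=True):
    if event.token.special:  # skip special tokens such as </s>
        continue
    print(event.token.text, end="", flush=True)
```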
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index 9816d5d9..f24c81ca 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -11,6 +11,7 @@ from ..libs.llm.ollama import Ollama
 from ..libs.llm.openai import OpenAI
 from ..libs.llm.replicate import ReplicateLLM
 from ..libs.llm.together import TogetherLLM
+from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
 
 
 class ContinueSDK(BaseModel):
@@ -37,6 +38,7 @@ MODEL_CLASSES = {
         ReplicateLLM,
         Ollama,
         LlamaCpp,
+        HuggingFaceInferenceAPI,
     ]
 }
@@ -49,6 +51,7 @@ MODEL_MODULE_NAMES = {
     "ReplicateLLM": "replicate",
     "Ollama": "ollama",
     "LlamaCpp": "llamacpp",
+    "HuggingFaceInferenceAPI": "hf_inference_api",
 }
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 43aac148..b8fc49a9 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,73 +1,53 @@
-from typing import List, Optional
+from typing import Callable, Dict, List
+from ..llm import LLM, CompletionOptions
-import aiohttp
-import requests
-
-from ...core.main import ChatMessage
-from ..llm import LLM
+from huggingface_hub import InferenceClient
+from .prompts.chat import llama2_template_messages
 from .prompts.edit import simplified_edit_prompt
 
-DEFAULT_MAX_TIME = 120.0
-
 
 class HuggingFaceInferenceAPI(LLM):
+    model: str = "Hugging Face Inference API"
     hf_token: str
-    self_hosted_url: str = None
-
-    verify_ssl: Optional[bool] = None
+    endpoint_url: str = None
 
-    _client_session: aiohttp.ClientSession = None
+    template_messages: Callable[[List[Dict[str, str]]], str] | None = llama2_template_messages
 
     prompt_templates = {
-        "edit": simplified_edit_prompt,
+       "edit": simplified_edit_prompt,
     }
 
     class Config:
         arbitrary_types_allowed = True
 
-    async def start(self, **kwargs):
-        await super().start(**kwargs)
-        self._client_session = aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
-            timeout=aiohttp.ClientTimeout(total=self.timeout),
-        )
+    def collect_args(self, options: CompletionOptions):
+        options.stop = None
+        args = super().collect_args(options)
 
-    async def stop(self):
-        await self._client_session.close()
+        if "max_tokens" in args:
+            args["max_new_tokens"] = args["max_tokens"]
+            del args["max_tokens"]
+        if "stop" in args:
+            args["stop_sequences"] = args["stop"]
+            del args["stop"]
+        if "model" in args:
+            del args["model"]
+        return args
 
-    async def _complete(self, prompt: str, options):
-        """Return the completion of the text with the given temperature."""
-        API_URL = (
-            self.base_url or f"https://api-inference.huggingface.co/models/{self.model}"
-        )
-        headers = {"Authorization": f"Bearer {self.hf_token}"}
-        response = requests.post(
-            API_URL,
-            headers=headers,
-            json={
-                "inputs": prompt,
-                "parameters": {
-                    "max_new_tokens": min(
-                        250, self.max_context_length - self.count_tokens(prompt)
-                    ),
-                    "max_time": DEFAULT_MAX_TIME,
               "return_full_text": False, -                }, -            }, -        ) -        data = response.json() - -        # Error if the response is not a list -        if not isinstance(data, list): -            raise Exception("Hugging Face returned an error response: \n\n", data) +    async def _stream_complete(self, prompt, options): +        args = self.collect_args(options) -        return data[0]["generated_text"] +        client = InferenceClient(self.endpoint_url, token=self.hf_token) -    async def _stream_chat(self, messages: List[ChatMessage], options): -        response = await self._complete(messages[-1].content, messages[:-1]) -        yield {"content": response, "role": "assistant"} +        stream = client.text_generation(prompt, stream=True, details=True) -    async def _stream_complete(self, prompt, options): -        response = await self._complete(prompt, options) -        yield response +        for r in stream: +            # skip special tokens +            if r.token.special: +                continue +            # stop if we encounter a stop sequence +            if options.stop is not None: +                if r.token.text in options.stop: +                    break +            yield r.token.text
\ No newline at end of file
diff --git a/continuedev/src/continuedev/plugins/steps/setup_model.py b/continuedev/src/continuedev/plugins/steps/setup_model.py
index 3dae39c2..7fa34907 100644
--- a/continuedev/src/continuedev/plugins/steps/setup_model.py
+++ b/continuedev/src/continuedev/plugins/steps/setup_model.py
@@ -13,6 +13,7 @@ MODEL_CLASS_TO_MESSAGE = {
     "GGML": "GGML models can be run locally using the `llama-cpp-python` library. To learn how to set up a local llama-cpp-python server, read [here](https://github.com/continuedev/ggml-server-example). Once it is started on port 8000, you're all set!",
     "TogetherLLM": "To get started using models from Together, first obtain your Together API key from [here](https://together.ai). Paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then, on their models page, press 'start' on the model of your choice and make sure the `model=` parameter in the config file for the `TogetherLLM` class reflects the name of this model. Finally, reload the VS Code window for changes to take effect.",
     "LlamaCpp": "To get started with this model, clone the [`llama.cpp` repo](https://github.com/ggerganov/llama.cpp) and follow the instructions to set up the server [here](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#build). Any of the parameters described in the README can be passed to the `llama_cpp_args` field in the `LlamaCpp` class in `config.py`.",
+    "HuggingFaceInferenceAPI": "To get started with the HuggingFace Inference API, first deploy a model and obtain your API key from [here](https://huggingface.co/inference-api). Paste it into the `hf_token` field at config.models.default.hf_token in `config.py`. Finally, reload the VS Code window for changes to take effect."
 }
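The setup message added above points users at `config.models.default.hf_token` in `config.py`. A minimal sketch of what that entry might look like, assuming the `ContinueConfig`/`Models` layout the message refers to; the import paths follow this repo's usual `config.py` style, and the token and endpoint URL are illustrative placeholders:

```python
from continuedev.src.continuedev.core.config import ContinueConfig
from continuedev.src.continuedev.core.models import Models
from continuedev.src.continuedev.libs.llm.hf_inference_api import HuggingFaceInferenceAPI

config = ContinueConfig(
    models=Models(
        default=HuggingFaceInferenceAPI(
            hf_token="hf_...",  # your Hugging Face access token (placeholder)
            # URL of the Inference Endpoint serving your model (placeholder)
            endpoint_url="https://<your-endpoint>.endpoints.huggingface.cloud",
        )
    )
)
```

After editing `config.py`, reload the VS Code window so the new model class is picked up, as the setup message describes.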

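The new class also sets `template_messages = llama2_template_messages`, so chat-style message lists are flattened into a single Llama 2 prompt string before being passed to `text_generation`. A rough illustration of that kind of template function (this is not the repo's actual `llama2_template_messages` implementation; the function name and formatting details here are assumptions):

```python
from typing import Dict, List


def llama2_style_template(messages: List[Dict[str, str]]) -> str:
    """Illustrative only: flatten chat messages into a Llama 2 [INST] prompt."""
    system = ""
    prompt = ""
    for msg in messages:
        if msg["role"] == "system":
            system = msg["content"]
        elif msg["role"] == "user":
            content = msg["content"]
            if system:
                # Fold the pending system message into the first user turn.
                content = f"<<SYS>>\n{system}\n<</SYS>>\n\n{content}"
                system = ""
            prompt += f"<s>[INST] {content} [/INST]"
        else:
            # Assistant turns are appended as already-completed responses.
            prompt += f" {msg['content']} </s>"
    return prompt


print(llama2_style_template([
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Summarize this diff."},
]))
```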