Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/core/models.py                |  3
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_inference_api.py  | 86
-rw-r--r--  continuedev/src/continuedev/plugins/steps/setup_model.py  |  1
3 files changed, 37 insertions(+), 53 deletions(-)
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index 9816d5d9..f24c81ca 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -11,6 +11,7 @@ from ..libs.llm.ollama import Ollama
from ..libs.llm.openai import OpenAI
from ..libs.llm.replicate import ReplicateLLM
from ..libs.llm.together import TogetherLLM
+from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
class ContinueSDK(BaseModel):
@@ -37,6 +38,7 @@ MODEL_CLASSES = {
ReplicateLLM,
Ollama,
LlamaCpp,
+ HuggingFaceInferenceAPI,
]
}
@@ -49,6 +51,7 @@ MODEL_MODULE_NAMES = {
"ReplicateLLM": "replicate",
"Ollama": "ollama",
"LlamaCpp": "llamacpp",
+ "HuggingFaceInferenceAPI": "hf_inference_api",
}
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 43aac148..b8fc49a9 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,73 +1,53 @@
-from typing import List, Optional
+from typing import Callable, Dict, List
+from ..llm import LLM, CompletionOptions
-import aiohttp
-import requests
-
-from ...core.main import ChatMessage
-from ..llm import LLM
+from huggingface_hub import InferenceClient
+from .prompts.chat import llama2_template_messages
from .prompts.edit import simplified_edit_prompt
-DEFAULT_MAX_TIME = 120.0
-
class HuggingFaceInferenceAPI(LLM):
+ model: str = "Hugging Face Inference API"
hf_token: str
- self_hosted_url: str = None
-
- verify_ssl: Optional[bool] = None
+ endpoint_url: str = None
- _client_session: aiohttp.ClientSession = None
+ template_messages: Callable[[List[Dict[str, str]]], str] | None = llama2_template_messages
prompt_templates = {
- "edit": simplified_edit_prompt,
+ "edit": simplified_edit_prompt,
}
class Config:
arbitrary_types_allowed = True
- async def start(self, **kwargs):
- await super().start(**kwargs)
- self._client_session = aiohttp.ClientSession(
- connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
- timeout=aiohttp.ClientTimeout(total=self.timeout),
- )
+ def collect_args(self, options: CompletionOptions):
+ options.stop = None
+ args = super().collect_args(options)
- async def stop(self):
- await self._client_session.close()
+ if "max_tokens" in args:
+ args["max_new_tokens"] = args["max_tokens"]
+ del args["max_tokens"]
+ if "stop" in args:
+ args["stop_sequences"] = args["stop"]
+ del args["stop"]
+ if "model" in args:
+ del args["model"]
+ return args
- async def _complete(self, prompt: str, options):
- """Return the completion of the text with the given temperature."""
- API_URL = (
- self.base_url or f"https://api-inference.huggingface.co/models/{self.model}"
- )
- headers = {"Authorization": f"Bearer {self.hf_token}"}
- response = requests.post(
- API_URL,
- headers=headers,
- json={
- "inputs": prompt,
- "parameters": {
- "max_new_tokens": min(
- 250, self.max_context_length - self.count_tokens(prompt)
- ),
- "max_time": DEFAULT_MAX_TIME,
- "return_full_text": False,
- },
- },
- )
- data = response.json()
-
- # Error if the response is not a list
- if not isinstance(data, list):
- raise Exception("Hugging Face returned an error response: \n\n", data)
+ async def _stream_complete(self, prompt, options):
+ args = self.collect_args(options)
- return data[0]["generated_text"]
+ client = InferenceClient(self.endpoint_url, token=self.hf_token)
- async def _stream_chat(self, messages: List[ChatMessage], options):
- response = await self._complete(messages[-1].content, messages[:-1])
- yield {"content": response, "role": "assistant"}
+ stream = client.text_generation(prompt, stream=True, details=True)
- async def _stream_complete(self, prompt, options):
- response = await self._complete(prompt, options)
- yield response
+ for r in stream:
+ # skip special tokens
+ if r.token.special:
+ continue
+ # stop if we encounter a stop sequence
+ if options.stop is not None:
+ if r.token.text in options.stop:
+ break
+ yield r.token.text
\ No newline at end of file
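
The rewritten `_stream_complete` above streams tokens through `huggingface_hub`'s `InferenceClient.text_generation` with `stream=True` and `details=True`. The following is a minimal standalone sketch of that streaming pattern, assuming the `huggingface_hub` client API; the endpoint URL, token, prompt, and stop sequences are placeholders, not values from this commit.

```python
# Minimal sketch of the streaming pattern used in _stream_complete above.
# Assumes huggingface_hub's InferenceClient; the endpoint URL, token, prompt,
# and stop sequences below are placeholders.
from huggingface_hub import InferenceClient

client = InferenceClient(
    "https://<your-endpoint>.endpoints.huggingface.cloud",  # or a hosted model id
    token="<YOUR_HF_TOKEN>",
)

stop_sequences = ["</s>"]
for chunk in client.text_generation(
    "Write a docstring for a function that reverses a string.",
    stream=True,
    details=True,        # expose per-token metadata on each chunk
    max_new_tokens=128,
):
    token = chunk.token
    if token.special:                 # skip BOS/EOS and other special tokens
        continue
    if token.text in stop_sequences:  # stop client-side on a stop sequence
        break
    print(token.text, end="", flush=True)
```

Note that `collect_args` maps `max_tokens` to `max_new_tokens` and `stop` to `stop_sequences`, which match the parameter names that `text_generation` accepts.
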
diff --git a/continuedev/src/continuedev/plugins/steps/setup_model.py b/continuedev/src/continuedev/plugins/steps/setup_model.py
index 3dae39c2..7fa34907 100644
--- a/continuedev/src/continuedev/plugins/steps/setup_model.py
+++ b/continuedev/src/continuedev/plugins/steps/setup_model.py
@@ -13,6 +13,7 @@ MODEL_CLASS_TO_MESSAGE = {
"GGML": "GGML models can be run locally using the `llama-cpp-python` library. To learn how to set up a local llama-cpp-python server, read [here](https://github.com/continuedev/ggml-server-example). Once it is started on port 8000, you're all set!",
"TogetherLLM": "To get started using models from Together, first obtain your Together API key from [here](https://together.ai). Paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then, on their models page, press 'start' on the model of your choice and make sure the `model=` parameter in the config file for the `TogetherLLM` class reflects the name of this model. Finally, reload the VS Code window for changes to take effect.",
"LlamaCpp": "To get started with this model, clone the [`llama.cpp` repo](https://github.com/ggerganov/llama.cpp) and follow the instructions to set up the server [here](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#build). Any of the parameters described in the README can be passed to the `llama_cpp_args` field in the `LlamaCpp` class in `config.py`.",
+ "HuggingFaceInferenceAPI": "To get started with the HuggingFace Inference API, first deploy a model and obtain your API key from [here](https://huggingface.co/inference-api). Paste it into the `hf_token` field at config.models.default.hf_token in `config.py`. Finally, reload the VS Code window for changes to take effect."
}
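
The message added above directs users to paste their token into `config.models.default.hf_token` in `config.py`. A minimal configuration sketch follows; the `ContinueConfig`/`Models` import paths are assumptions inferred from that `config.models.default.*` path rather than shown in this diff, while `hf_token` and `endpoint_url` are the fields defined on `HuggingFaceInferenceAPI` above.

```python
# Hypothetical config.py excerpt. The ContinueConfig/Models import paths are
# assumptions inferred from the config.models.default.* path in the setup
# message; the token and endpoint URL are placeholders.
from continuedev.src.continuedev.core.config import ContinueConfig
from continuedev.src.continuedev.core.models import Models
from continuedev.src.continuedev.libs.llm.hf_inference_api import HuggingFaceInferenceAPI

config = ContinueConfig(
    models=Models(
        default=HuggingFaceInferenceAPI(
            hf_token="<YOUR_HF_TOKEN>",
            endpoint_url="https://<your-endpoint>.endpoints.huggingface.cloud",
        )
    )
)
```

As the setup message notes, reload the VS Code window after saving `config.py` for the change to take effect.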