Diffstat (limited to 'server/continuedev/libs/llm')
-rw-r--r--  server/continuedev/libs/llm/__init__.py            |   10
-rw-r--r--  server/continuedev/libs/llm/anthropic.py           |   74
-rw-r--r--  server/continuedev/libs/llm/ggml.py                |  226
-rw-r--r--  server/continuedev/libs/llm/google_palm_api.py     |   50
-rw-r--r--  server/continuedev/libs/llm/hf_inference_api.py    |   78
-rw-r--r--  server/continuedev/libs/llm/hf_tgi.py              |   65
-rw-r--r--  server/continuedev/libs/llm/hugging_face.py        |   19
-rw-r--r--  server/continuedev/libs/llm/openai.py              |  156
-rw-r--r--  server/continuedev/libs/llm/openai_free_trial.py   |   83
-rw-r--r--  server/continuedev/libs/llm/replicate.py           |   78
-rw-r--r--  server/continuedev/libs/llm/text_gen_interface.py  |  114
-rw-r--r--  server/continuedev/libs/llm/together.py            |  125
12 files changed, 0 insertions, 1078 deletions
diff --git a/server/continuedev/libs/llm/__init__.py b/server/continuedev/libs/llm/__init__.py
index 829ffede..7ac92059 100644
--- a/server/continuedev/libs/llm/__init__.py
+++ b/server/continuedev/libs/llm/__init__.py
@@ -1,14 +1,4 @@
-from .anthropic import AnthropicLLM # noqa: F401
-from .ggml import GGML # noqa: F401
-from .google_palm_api import GooglePaLMAPI # noqa: F401
-from .hf_inference_api import HuggingFaceInferenceAPI # noqa: F401
-from .hf_tgi import HuggingFaceTGI # noqa: F401
 from .llamacpp import LlamaCpp # noqa: F401
 from .ollama import Ollama # noqa: F401
-from .openai import OpenAI # noqa: F401
-from .openai_free_trial import OpenAIFreeTrial # noqa: F401
 from .proxy_server import ProxyServer # noqa: F401
 from .queued import QueuedLLM # noqa: F401
-from .replicate import ReplicateLLM # noqa: F401
-from .text_gen_interface import TextGenUI # noqa: F401
-from .together import TogetherLLM # noqa: F401
diff --git a/server/continuedev/libs/llm/anthropic.py b/server/continuedev/libs/llm/anthropic.py
deleted file mode 100644
index 7d0708f1..00000000
--- a/server/continuedev/libs/llm/anthropic.py
+++ /dev/null
@@ -1,74 +0,0 @@
-from typing import Any, Callable, Coroutine
-
-from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic
-
-from .base import LLM, CompletionOptions
-from .prompts.chat import anthropic_template_messages
-
-
-class AnthropicLLM(LLM):
-    """
-    Import the `AnthropicLLM` class and set it as the default model:
-
-    ```python title="~/.continue/config.py"
-    from continuedev.libs.llm.anthropic import AnthropicLLM
-
-    config = ContinueConfig(
-        ...
-        models=Models(
-            default=AnthropicLLM(api_key="<API_KEY>", model="claude-2")
-        )
-    )
-    ```
-
-    Claude 2 is not yet publicly released. You can request early access [here](https://www.anthropic.com/earlyaccess).
- - """ - - api_key: str - "Anthropic API key" - - model: str = "claude-2" - - _async_client: AsyncAnthropic = None - - template_messages: Callable = anthropic_template_messages - - class Config: - arbitrary_types_allowed = True - - async def start(self, **kwargs): - await super().start(**kwargs) - self._async_client = AsyncAnthropic(api_key=self.api_key) - - if self.model == "claude-2": - self.context_length = 100_000 - - def collect_args(self, options: CompletionOptions): - options.stop = None - args = super().collect_args(options) - - if "max_tokens" in args: - args["max_tokens_to_sample"] = args["max_tokens"] - del args["max_tokens"] - if "frequency_penalty" in args: - del args["frequency_penalty"] - if "presence_penalty" in args: - del args["presence_penalty"] - return args - - async def _stream_complete(self, prompt: str, options): - args = self.collect_args(options) - prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}" - - async for chunk in await self._async_client.completions.create( - prompt=prompt, stream=True, **args - ): - yield chunk.completion - - async def _complete(self, prompt: str, options) -> Coroutine[Any, Any, str]: - args = self.collect_args(options) - prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}" - return ( - await self._async_client.completions.create(prompt=prompt, **args) - ).completion diff --git a/server/continuedev/libs/llm/ggml.py b/server/continuedev/libs/llm/ggml.py deleted file mode 100644 index 55d580a8..00000000 --- a/server/continuedev/libs/llm/ggml.py +++ /dev/null @@ -1,226 +0,0 @@ -import json -from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional - -from pydantic import Field - -from ...core.main import ChatMessage -from ..util.logging import logger -from .base import LLM, CompletionOptions -from .openai import CHAT_MODELS -from .prompts.chat import llama2_template_messages -from .prompts.edit import simplified_edit_prompt - - -class GGML(LLM): - """ - See our [5 minute quickstart](https://github.com/continuedev/ggml-server-example) to run any model locally with ggml. While these models don't yet perform as well, they are free, entirely private, and run offline. - - Once the model is running on localhost:8000, change `~/.continue/config.py` to look like this: - - ```python title="~/.continue/config.py" - from continuedev.libs.llm.ggml import GGML - - config = ContinueConfig( - ... - models=Models( - default=GGML( - max_context_length=2048, - server_url="http://localhost:8000") - ) - ) - ``` - """ - - server_url: str = Field( - "http://localhost:8000", - description="URL of the OpenAI-compatible server where the model is being served", - ) - model: str = Field( - "ggml", description="The name of the model to use (optional for the GGML class)" - ) - - api_base: Optional[str] = Field(None, description="OpenAI API base URL.") - - api_type: Optional[Literal["azure", "openai"]] = Field( - None, description="OpenAI API type." - ) - - api_version: Optional[str] = Field( - None, description="OpenAI API version. For use with Azure OpenAI Service." - ) - - engine: Optional[str] = Field( - None, description="OpenAI engine. For use with Azure OpenAI Service." 
- ) - - template_messages: Optional[ - Callable[[List[Dict[str, str]]], str] - ] = llama2_template_messages - - prompt_templates = { - "edit": simplified_edit_prompt, - } - - class Config: - arbitrary_types_allowed = True - - def get_headers(self): - headers = { - "Content-Type": "application/json", - } - if self.api_key is not None: - if self.api_type == "azure": - headers["api-key"] = self.api_key - else: - headers["Authorization"] = f"Bearer {self.api_key}" - - return headers - - def get_full_server_url(self, endpoint: str): - endpoint = endpoint.lstrip("/").rstrip("/") - - if self.api_type == "azure": - if self.engine is None or self.api_version is None or self.api_base is None: - raise Exception( - "For Azure OpenAI Service, you must specify engine, api_version, and api_base." - ) - - return f"{self.api_base}/openai/deployments/{self.engine}/{endpoint}?api-version={self.api_version}" - else: - return f"{self.server_url}/v1/{endpoint}" - - async def _raw_stream_complete(self, prompt, options): - args = self.collect_args(options) - - async with self.create_client_session() as client_session: - async with client_session.post( - self.get_full_server_url(endpoint="completions"), - json={ - "prompt": prompt, - "stream": True, - **args, - }, - headers=self.get_headers(), - proxy=self.proxy, - ) as resp: - if resp.status != 200: - raise Exception( - f"Error calling /chat/completions endpoint: {resp.status}" - ) - - async for line in resp.content.iter_any(): - if line: - chunks = line.decode("utf-8") - for chunk in chunks.split("\n"): - if ( - chunk.startswith(": ping - ") - or chunk.startswith("data: [DONE]") - or chunk.strip() == "" - ): - continue - elif chunk.startswith("data: "): - chunk = chunk[6:] - try: - j = json.loads(chunk) - except Exception: - continue - if ( - "choices" in j - and len(j["choices"]) > 0 - and "text" in j["choices"][0] - ): - yield j["choices"][0]["text"] - - async def _stream_chat(self, messages: List[ChatMessage], options): - args = self.collect_args(options) - - async def generator(): - async with self.create_client_session() as client_session: - async with client_session.post( - self.get_full_server_url(endpoint="chat/completions"), - json={"messages": messages, "stream": True, **args}, - headers=self.get_headers(), - proxy=self.proxy, - ) as resp: - if resp.status != 200: - raise Exception( - f"Error calling /chat/completions endpoint: {resp.status}" - ) - - async for line, end in resp.content.iter_chunks(): - json_chunk = line.decode("utf-8") - chunks = json_chunk.split("\n") - for chunk in chunks: - if ( - chunk.strip() == "" - or json_chunk.startswith(": ping - ") - or json_chunk.startswith("data: [DONE]") - ): - continue - try: - yield json.loads(chunk[6:])["choices"][0]["delta"] - except: - pass - - # Because quite often the first attempt fails, and it works thereafter - try: - async for chunk in generator(): - yield chunk - except Exception as e: - logger.warning(f"Error calling /chat/completions endpoint: {e}") - async for chunk in generator(): - yield chunk - - async def _raw_complete(self, prompt: str, options) -> Coroutine[Any, Any, str]: - args = self.collect_args(options) - - async with self.create_client_session() as client_session: - async with client_session.post( - self.get_full_server_url(endpoint="completions"), - json={ - "prompt": prompt, - **args, - }, - headers=self.get_headers(), - proxy=self.proxy, - ) as resp: - if resp.status != 200: - raise Exception( - f"Error calling /chat/completions endpoint: {resp.status}" - ) - - text = await 
resp.text() - try: - completion = json.loads(text)["choices"][0]["text"] - return completion - except Exception as e: - raise Exception( - f"Error calling /completion endpoint: {e}\n\nResponse text: {text}" - ) - - async def _complete(self, prompt: str, options: CompletionOptions): - completion = "" - if self.model in CHAT_MODELS: - async for chunk in self._stream_chat( - [{"role": "user", "content": prompt}], options - ): - if "content" in chunk: - completion += chunk["content"] - - else: - async for chunk in self._raw_stream_complete(prompt, options): - completion += chunk - - return completion - - async def _stream_complete(self, prompt, options: CompletionOptions): - if self.model in CHAT_MODELS: - async for chunk in self._stream_chat( - [{"role": "user", "content": prompt}], options - ): - if "content" in chunk: - yield chunk["content"] - - else: - async for chunk in self._raw_stream_complete(prompt, options): - yield chunk diff --git a/server/continuedev/libs/llm/google_palm_api.py b/server/continuedev/libs/llm/google_palm_api.py deleted file mode 100644 index 3379fefe..00000000 --- a/server/continuedev/libs/llm/google_palm_api.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import List - -import requests -from pydantic import Field - -from ...core.main import ChatMessage -from .base import LLM - - -class GooglePaLMAPI(LLM): - """ - The Google PaLM API is currently in public preview, so production applications are not supported yet. However, you can [create an API key in Google MakerSuite](https://makersuite.google.com/u/2/app/apikey) and begin trying out the `chat-bison-001` model. Change `~/.continue/config.py` to look like this: - - ```python title="~/.continue/config.py" - from continuedev.core.models import Models - from continuedev.libs.llm.hf_inference_api import GooglePaLMAPI - - config = ContinueConfig( - ... - models=Models( - default=GooglePaLMAPI( - model="chat-bison-001" - api_key="<MAKERSUITE_API_KEY>", - ) - ) - ``` - """ - - api_key: str = Field(..., description="Google PaLM API key") - - model: str = "chat-bison-001" - - async def _stream_complete(self, prompt, options): - api_url = f"https://generativelanguage.googleapis.com/v1beta2/models/{self.model}:generateMessage?key={self.api_key}" - body = {"prompt": {"messages": [{"content": prompt}]}} - response = requests.post(api_url, json=body) - yield response.json()["candidates"][0]["content"] - - async def _stream_chat(self, messages: List[ChatMessage], options): - msg_lst = [] - for message in messages: - msg_lst.append({"content": message["content"]}) - - api_url = f"https://generativelanguage.googleapis.com/v1beta2/models/{self.model}:generateMessage?key={self.api_key}" - body = {"prompt": {"messages": msg_lst}} - response = requests.post(api_url, json=body) - yield { - "content": response.json()["candidates"][0]["content"], - "role": "assistant", - } diff --git a/server/continuedev/libs/llm/hf_inference_api.py b/server/continuedev/libs/llm/hf_inference_api.py deleted file mode 100644 index 990ec7c8..00000000 --- a/server/continuedev/libs/llm/hf_inference_api.py +++ /dev/null @@ -1,78 +0,0 @@ -from typing import Callable, Dict, List, Union - -from huggingface_hub import InferenceClient -from pydantic import Field - -from .base import LLM, CompletionOptions -from .prompts.chat import llama2_template_messages -from .prompts.edit import simplified_edit_prompt - - -class HuggingFaceInferenceAPI(LLM): - """ - Hugging Face Inference API is a great option for newly released language models. 
Sign up for an account and add billing [here](https://huggingface.co/settings/billing), access the Inference Endpoints [here](https://ui.endpoints.huggingface.co), click on “New endpoint”, and fill out the form (e.g. select a model like [WizardCoder-Python-34B-V1.0](https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0)), and then deploy your model by clicking “Create Endpoint”. Change `~/.continue/config.py` to look like this: - - ```python title="~/.continue/config.py" - from continuedev.core.models import Models - from continuedev.libs.llm.hf_inference_api import HuggingFaceInferenceAPI - - config = ContinueConfig( - ... - models=Models( - default=HuggingFaceInferenceAPI( - endpoint_url="<INFERENCE_API_ENDPOINT_URL>", - hf_token="<HUGGING_FACE_TOKEN>", - ) - ) - ``` - """ - - model: str = Field( - "Hugging Face Inference API", - description="The name of the model to use (optional for the HuggingFaceInferenceAPI class)", - ) - hf_token: str = Field(..., description="Your Hugging Face API token") - endpoint_url: str = Field( - None, description="Your Hugging Face Inference API endpoint URL" - ) - - template_messages: Union[ - Callable[[List[Dict[str, str]]], str], None - ] = llama2_template_messages - - prompt_templates = { - "edit": simplified_edit_prompt, - } - - class Config: - arbitrary_types_allowed = True - - def collect_args(self, options: CompletionOptions): - options.stop = None - args = super().collect_args(options) - - if "max_tokens" in args: - args["max_new_tokens"] = args["max_tokens"] - del args["max_tokens"] - if "stop" in args: - args["stop_sequences"] = args["stop"] - del args["stop"] - - return args - - async def _stream_complete(self, prompt, options): - args = self.collect_args(options) - - client = InferenceClient(self.endpoint_url, token=self.hf_token) - - stream = client.text_generation(prompt, stream=True, details=True, **args) - - for r in stream: - # skip special tokens - if r.token.special: - continue - # stop if we encounter a stop sequence - if options.stop is not None: - if r.token.text in options.stop: - break - yield r.token.text diff --git a/server/continuedev/libs/llm/hf_tgi.py b/server/continuedev/libs/llm/hf_tgi.py deleted file mode 100644 index 62458db4..00000000 --- a/server/continuedev/libs/llm/hf_tgi.py +++ /dev/null @@ -1,65 +0,0 @@ -import json -from typing import Any, Callable, List - -from pydantic import Field - -from ...core.main import ChatMessage -from .base import LLM, CompletionOptions -from .prompts.chat import llama2_template_messages -from .prompts.edit import simplified_edit_prompt - - -class HuggingFaceTGI(LLM): - model: str = "huggingface-tgi" - server_url: str = Field( - "http://localhost:8080", description="URL of your TGI server" - ) - - template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages - - prompt_templates = { - "edit": simplified_edit_prompt, - } - - class Config: - arbitrary_types_allowed = True - - def collect_args(self, options: CompletionOptions) -> Any: - args = super().collect_args(options) - args = {**args, "max_new_tokens": args.get("max_tokens", 1024), "best_of": 1} - args.pop("max_tokens", None) - args.pop("model", None) - args.pop("functions", None) - return args - - async def _stream_complete(self, prompt, options): - args = self.collect_args(options) - - async with self.create_client_session() as client_session: - async with client_session.post( - f"{self.server_url}/generate_stream", - json={"inputs": prompt, "parameters": args}, - headers={"Content-Type": "application/json"}, 
- proxy=self.proxy, - ) as resp: - async for line in resp.content.iter_any(): - if line: - text = line.decode("utf-8") - chunks = text.split("\n") - - for chunk in chunks: - if chunk.startswith("data: "): - chunk = chunk[len("data: ") :] - elif chunk.startswith("data:"): - chunk = chunk[len("data:") :] - - if chunk.strip() == "": - continue - - try: - json_chunk = json.loads(chunk) - except Exception as e: - print(f"Error parsing JSON: {e}") - continue - - yield json_chunk["token"]["text"] diff --git a/server/continuedev/libs/llm/hugging_face.py b/server/continuedev/libs/llm/hugging_face.py deleted file mode 100644 index c2e934c0..00000000 --- a/server/continuedev/libs/llm/hugging_face.py +++ /dev/null @@ -1,19 +0,0 @@ -# TODO: This class is far out of date - -from transformers import AutoModelForCausalLM, AutoTokenizer - -from .llm import LLM - - -class HuggingFace(LLM): - def __init__(self, model_path: str = "Salesforce/codegen-2B-mono"): - self.model_path = model_path - self.tokenizer = AutoTokenizer.from_pretrained(model_path) - self.model = AutoModelForCausalLM.from_pretrained(model_path) - - def complete(self, prompt: str, **kwargs): - args = {"max_tokens": 100} - args.update(kwargs) - input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids - generated_ids = self.model.generate(input_ids, max_length=args["max_tokens"]) - return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) diff --git a/server/continuedev/libs/llm/openai.py b/server/continuedev/libs/llm/openai.py deleted file mode 100644 index ba29279b..00000000 --- a/server/continuedev/libs/llm/openai.py +++ /dev/null @@ -1,156 +0,0 @@ -from typing import Callable, List, Literal, Optional - -import certifi -import openai -from pydantic import Field - -from ...core.main import ChatMessage -from .base import LLM - -CHAT_MODELS = { - "gpt-3.5-turbo", - "gpt-3.5-turbo-16k", - "gpt-4", - "gpt-3.5-turbo-0613", - "gpt-4-32k", -} -MAX_TOKENS_FOR_MODEL = { - "gpt-3.5-turbo": 4096, - "gpt-3.5-turbo-0613": 4096, - "gpt-3.5-turbo-16k": 16_384, - "gpt-4": 8192, - "gpt-35-turbo-16k": 16_384, - "gpt-35-turbo-0613": 4096, - "gpt-35-turbo": 4096, - "gpt-4-32k": 32_768, -} - - -class OpenAI(LLM): - """ - The OpenAI class can be used to access OpenAI models like gpt-4 and gpt-3.5-turbo. - - If you are locally serving a model that uses an OpenAI-compatible server, you can simply change the `api_base` in the `OpenAI` class like this: - - ```python title="~/.continue/config.py" - from continuedev.libs.llm.openai import OpenAI - - config = ContinueConfig( - ... - models=Models( - default=OpenAI( - api_key="EMPTY", - model="<MODEL_NAME>", - api_base="http://localhost:8000", # change to your server - ) - ) - ) - ``` - - Options for serving models locally with an OpenAI-compatible server include: - - - [text-gen-webui](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai#setup--installation) - - [FastChat](https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md) - - [LocalAI](https://localai.io/basics/getting_started/) - - [llama-cpp-python](https://github.com/abetlen/llama-cpp-python#web-server) - """ - - api_key: str = Field( - ..., - description="OpenAI API key", - ) - - proxy: Optional[str] = Field(None, description="Proxy URL to use for requests.") - - api_base: Optional[str] = Field(None, description="OpenAI API base URL.") - - api_type: Optional[Literal["azure", "openai"]] = Field( - None, description="OpenAI API type." 
- ) - - api_version: Optional[str] = Field( - None, description="OpenAI API version. For use with Azure OpenAI Service." - ) - - engine: Optional[str] = Field( - None, description="OpenAI engine. For use with Azure OpenAI Service." - ) - - async def start( - self, unique_id: Optional[str] = None, write_log: Callable[[str], None] = None - ): - await super().start(write_log=write_log, unique_id=unique_id) - - if self.context_length is None: - self.context_length = MAX_TOKENS_FOR_MODEL.get(self.model, 4096) - - openai.api_key = self.api_key - if self.api_type is not None: - openai.api_type = self.api_type - if self.api_base is not None: - openai.api_base = self.api_base - if self.api_version is not None: - openai.api_version = self.api_version - - if self.verify_ssl is not None and self.verify_ssl is False: - openai.verify_ssl_certs = False - - if self.proxy is not None: - openai.proxy = self.proxy - - openai.ca_bundle_path = self.ca_bundle_path or certifi.where() - - def collect_args(self, options): - args = super().collect_args(options) - if self.engine is not None: - args["engine"] = self.engine - - if not args["model"].endswith("0613") and "functions" in args: - del args["functions"] - - return args - - async def _stream_complete(self, prompt, options): - args = self.collect_args(options) - args["stream"] = True - - if args["model"] in CHAT_MODELS: - async for chunk in await openai.ChatCompletion.acreate( - messages=[{"role": "user", "content": prompt}], - **args, - headers=self.headers, - ): - if len(chunk.choices) > 0 and "content" in chunk.choices[0].delta: - yield chunk.choices[0].delta.content - else: - async for chunk in await openai.Completion.acreate(prompt=prompt, **args, headers=self.headers): - if len(chunk.choices) > 0: - yield chunk.choices[0].text - - async def _stream_chat(self, messages: List[ChatMessage], options): - args = self.collect_args(options) - - async for chunk in await openai.ChatCompletion.acreate( - messages=messages, - stream=True, - **args, - headers=self.headers, - ): - if not hasattr(chunk, "choices") or len(chunk.choices) == 0: - continue - yield chunk.choices[0].delta - - async def _complete(self, prompt: str, options): - args = self.collect_args(options) - - if args["model"] in CHAT_MODELS: - resp = await openai.ChatCompletion.acreate( - messages=[{"role": "user", "content": prompt}], - **args, - headers=self.headers, - ) - return resp.choices[0].message.content - else: - return ( - (await openai.Completion.acreate(prompt=prompt, **args, headers=self.headers)).choices[0].text - ) diff --git a/server/continuedev/libs/llm/openai_free_trial.py b/server/continuedev/libs/llm/openai_free_trial.py deleted file mode 100644 index b6e707f9..00000000 --- a/server/continuedev/libs/llm/openai_free_trial.py +++ /dev/null @@ -1,83 +0,0 @@ -from typing import Callable, List, Optional - -from ...core.main import ChatMessage -from .base import LLM -from .openai import OpenAI -from .proxy_server import ProxyServer - - -class OpenAIFreeTrial(LLM): - """ - With the `OpenAIFreeTrial` `LLM`, new users can try out Continue with GPT-4 using a proxy server that securely makes calls to OpenAI using our API key. Continue should just work the first time you install the extension in VS Code. - - Once you are using Continue regularly though, you will need to add an OpenAI API key that has access to GPT-4 by following these steps: - - 1. Copy your API key from https://platform.openai.com/account/api-keys - 2. Open `~/.continue/config.py`. 
You can do this by using the '/config' command in Continue - 3. Change the default LLMs to look like this: - - ```python title="~/.continue/config.py" - API_KEY = "<API_KEY>" - config = ContinueConfig( - ... - models=Models( - default=OpenAIFreeTrial(model="gpt-4", api_key=API_KEY), - summarize=OpenAIFreeTrial(model="gpt-3.5-turbo", api_key=API_KEY) - ) - ) - ``` - - The `OpenAIFreeTrial` class will automatically switch to using your API key instead of ours. If you'd like to explicitly use one or the other, you can use the `ProxyServer` or `OpenAI` classes instead. - - These classes support any models available through the OpenAI API, assuming your API key has access, including "gpt-4", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", and "gpt-4-32k". - """ - - api_key: Optional[str] = None - - llm: Optional[LLM] = None - - def update_llm_properties(self): - if self.llm is not None: - self.llm.system_message = self.system_message - - async def start( - self, write_log: Callable[[str], None] = None, unique_id: Optional[str] = None - ): - await super().start(write_log=write_log, unique_id=unique_id) - if self.api_key is None or self.api_key.strip() == "": - self.llm = ProxyServer( - model=self.model, - verify_ssl=self.verify_ssl, - ca_bundle_path=self.ca_bundle_path, - ) - else: - self.llm = OpenAI( - api_key=self.api_key, - model=self.model, - verify_ssl=self.verify_ssl, - ca_bundle_path=self.ca_bundle_path, - ) - - await self.llm.start(write_log=write_log, unique_id=unique_id) - - async def stop(self): - await self.llm.stop() - - async def _complete(self, prompt: str, options): - self.update_llm_properties() - return await self.llm._complete(prompt, options) - - async def _stream_complete(self, prompt, options): - self.update_llm_properties() - resp = self.llm._stream_complete(prompt, options) - async for item in resp: - yield item - - async def _stream_chat(self, messages: List[ChatMessage], options): - self.update_llm_properties() - resp = self.llm._stream_chat(messages=messages, options=options) - async for item in resp: - yield item - - def count_tokens(self, text: str): - return self.llm.count_tokens(text) diff --git a/server/continuedev/libs/llm/replicate.py b/server/continuedev/libs/llm/replicate.py deleted file mode 100644 index 3423193b..00000000 --- a/server/continuedev/libs/llm/replicate.py +++ /dev/null @@ -1,78 +0,0 @@ -import concurrent.futures -from typing import List - -import replicate -from pydantic import Field - -from ...core.main import ChatMessage -from .base import LLM -from .prompts.edit import simplified_edit_prompt - - -class ReplicateLLM(LLM): - """ - Replicate is a great option for newly released language models or models that you've deployed through their platform. Sign up for an account [here](https://replicate.ai/), copy your API key, and then select any model from the [Replicate Streaming List](https://replicate.com/collections/streaming-language-models). Change `~/.continue/config.py` to look like this: - - ```python title="~/.continue/config.py" - from continuedev.core.models import Models - from continuedev.libs.llm.replicate import ReplicateLLM - - config = ContinueConfig( - ... - models=Models( - default=ReplicateLLM( - model="replicate/codellama-13b-instruct:da5676342de1a5a335b848383af297f592b816b950a43d251a0a9edd0113604b", - api_key="my-replicate-api-key") - ) - ) - ``` - - If you don't specify the `model` parameter, it will default to `replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781`. 
- """ - - api_key: str = Field(..., description="Replicate API key") - - model: str = "replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781" - - _client: replicate.Client = None - - prompt_templates = { - "edit": simplified_edit_prompt, - } - - async def start(self, **kwargs): - await super().start(**kwargs) - self._client = replicate.Client(api_token=self.api_key) - - async def _complete(self, prompt: str, options): - def helper(): - output = self._client.run( - self.model, input={"message": prompt, "prompt": prompt} - ) - completion = "" - for item in output: - completion += item - - return completion - - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(helper) - completion = future.result() - - return completion - - async def _stream_complete(self, prompt, options): - for item in self._client.run( - self.model, input={"message": prompt, "prompt": prompt} - ): - yield item - - async def _stream_chat(self, messages: List[ChatMessage], options): - for item in self._client.run( - self.model, - input={ - "message": messages[-1]["content"], - "prompt": messages[-1]["content"], - }, - ): - yield {"content": item, "role": "assistant"} diff --git a/server/continuedev/libs/llm/text_gen_interface.py b/server/continuedev/libs/llm/text_gen_interface.py deleted file mode 100644 index 225fd3b6..00000000 --- a/server/continuedev/libs/llm/text_gen_interface.py +++ /dev/null @@ -1,114 +0,0 @@ -import json -from typing import Any, Callable, Dict, List, Union - -import websockets -from pydantic import Field - -from ...core.main import ChatMessage -from .base import LLM -from .prompts.chat import llama2_template_messages -from .prompts.edit import simplest_edit_prompt - - -class TextGenUI(LLM): - """ - TextGenUI is a comprehensive, open-source language model UI and local server. You can set it up with an OpenAI-compatible server plugin, but if for some reason that doesn't work, you can use this class like so: - - ```python title="~/.continue/config.py" - from continuedev.libs.llm.text_gen_interface import TextGenUI - - config = ContinueConfig( - ... 
- models=Models( - default=TextGenUI( - model="<MODEL_NAME>", - ) - ) - ) - ``` - """ - - model: str = "text-gen-ui" - server_url: str = Field( - "http://localhost:5000", description="URL of your TextGenUI server" - ) - streaming_url: str = Field( - "http://localhost:5005", - description="URL of your TextGenUI streaming server (separate from main server URL)", - ) - - prompt_templates = { - "edit": simplest_edit_prompt, - } - - template_messages: Union[ - Callable[[List[Dict[str, str]]], str], None - ] = llama2_template_messages - - class Config: - arbitrary_types_allowed = True - - def collect_args(self, options) -> Any: - args = super().collect_args(options) - args = {**args, "max_new_tokens": options.max_tokens} - args.pop("max_tokens", None) - return args - - async def _stream_complete(self, prompt, options): - args = self.collect_args(options) - - ws_url = f"{self.streaming_url.replace('http://', 'ws://').replace('https://', 'wss://')}" - payload = json.dumps({"prompt": prompt, "stream": True, **args}) - async with websockets.connect( - f"{ws_url}/api/v1/stream", ping_interval=None - ) as websocket: - await websocket.send(payload) - - while True: - incoming_data = await websocket.recv() - incoming_data = json.loads(incoming_data) - - match incoming_data["event"]: - case "text_stream": - yield incoming_data["text"] - case "stream_end": - break - - async def _stream_chat(self, messages: List[ChatMessage], options): - args = self.collect_args(options) - - async def generator(): - ws_url = f"{self.streaming_url.replace('http://', 'ws://').replace('https://', 'wss://')}" - history = list(map(lambda x: x["content"], messages)) - payload = json.dumps( - { - "user_input": messages[-1]["content"], - "history": {"internal": [history], "visible": [history]}, - "stream": True, - **args, - } - ) - async with websockets.connect( - f"{ws_url}/api/v1/chat-stream", ping_interval=None - ) as websocket: - await websocket.send(payload) - - prev = "" - while True: - incoming_data = await websocket.recv() - incoming_data = json.loads(incoming_data) - - match incoming_data["event"]: - case "text_stream": - visible = incoming_data["history"]["visible"][-1] - if len(visible) > 0: - yield { - "role": "assistant", - "content": visible[-1].replace(prev, ""), - } - prev = visible[-1] - case "stream_end": - break - - async for chunk in generator(): - yield chunk diff --git a/server/continuedev/libs/llm/together.py b/server/continuedev/libs/llm/together.py deleted file mode 100644 index 35b3a424..00000000 --- a/server/continuedev/libs/llm/together.py +++ /dev/null @@ -1,125 +0,0 @@ -import json -from typing import Callable - -import aiohttp -from pydantic import Field - -from ...core.main import ContinueCustomException -from ..util.logging import logger -from .base import LLM -from .prompts.chat import llama2_template_messages -from .prompts.edit import simplified_edit_prompt - - -class TogetherLLM(LLM): - """ - The Together API is a cloud platform for running large AI models. You can sign up [here](https://api.together.xyz/signup), copy your API key on the initial welcome screen, and then hit the play button on any model from the [Together Models list](https://docs.together.ai/docs/models-inference). Change `~/.continue/config.py` to look like this: - - ```python title="~/.continue/config.py" - from continuedev.core.models import Models - from continuedev.libs.llm.together import TogetherLLM - - config = ContinueConfig( - ... 
- models=Models( - default=TogetherLLM( - api_key="<API_KEY>", - model="togethercomputer/llama-2-13b-chat" - ) - ) - ) - ``` - """ - - api_key: str = Field(..., description="Together API key") - - model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct" - base_url: str = Field( - "https://api.together.xyz", - description="The base URL for your Together API instance", - ) - - _client_session: aiohttp.ClientSession = None - - template_messages: Callable = llama2_template_messages - - prompt_templates = { - "edit": simplified_edit_prompt, - } - - async def start(self, **kwargs): - await super().start(**kwargs) - self._client_session = aiohttp.ClientSession( - connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl), - timeout=aiohttp.ClientTimeout(total=self.timeout), - ) - - async def stop(self): - await self._client_session.close() - - async def _stream_complete(self, prompt, options): - args = self.collect_args(options) - - async with self._client_session.post( - f"{self.base_url}/inference", - json={ - "prompt": prompt, - "stream_tokens": True, - **args, - }, - headers={"Authorization": f"Bearer {self.api_key}"}, - proxy=self.proxy, - ) as resp: - async for line in resp.content.iter_chunks(): - if line[1]: - json_chunk = line[0].decode("utf-8") - if json_chunk.startswith(": ping - ") or json_chunk.startswith( - "data: [DONE]" - ): - continue - - chunks = json_chunk.split("\n") - for chunk in chunks: - if chunk.strip() != "": - if chunk.startswith("data: "): - chunk = chunk[6:] - if chunk == "[DONE]": - break - try: - json_chunk = json.loads(chunk) - except Exception as e: - logger.warning(f"Invalid JSON chunk: {chunk}\n\n{e}") - continue - if "choices" in json_chunk: - yield json_chunk["choices"][0]["text"] - - async def _complete(self, prompt: str, options): - args = self.collect_args(options) - - async with self._client_session.post( - f"{self.base_url}/inference", - json={"prompt": prompt, **args}, - headers={"Authorization": f"Bearer {self.api_key}"}, - proxy=self.proxy, - ) as resp: - text = await resp.text() - j = json.loads(text) - try: - if "choices" not in j["output"]: - raise Exception(text) - if "output" in j: - return j["output"]["choices"][0]["text"] - except Exception as e: - j = await resp.json() - if "error" in j: - if j["error"].startswith("invalid hexlify value"): - raise ContinueCustomException( - message=f"Invalid Together API key:\n\n{j['error']}", - title="Together API Error", - ) - else: - raise ContinueCustomException( - message=j["error"], title="Together API Error" - ) - - raise e |