| author | Luna <git@l4.pm> | 2023-07-28 00:41:05 -0300 |
|---|---|---|
| committer | Luna <git@l4.pm> | 2023-07-28 00:41:05 -0300 |
| commit | 902ac7f67c2bb7b9e7c17ed15582414db42b5c52 (patch) | |
| tree | 3985781f18e49566c2fb0a214cc31d386843d8c9 /continuedev/src | |
| parent | d9026e72caa8ff94aa066f16cef677e5de76af07 (diff) | |
make Config receive LLM objects
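This replaces the old `default_model` string field (a `Literal` of hard-coded model names) with an `llm` field holding an `LLM` instance, so providers are configured as objects. A minimal sketch of a user config under the new scheme; `ContinueConfig`, `OpenAI`, and the field values come from the diff below, while the import paths are an assumption based on the tree layout in the diffstat:

```python
# Hypothetical usage of the new config field. Import paths are assumed
# from the repository layout shown in the diffstat below.
from continuedev.src.continuedev.core.config import ContinueConfig
from continuedev.src.continuedev.libs.llm.openai import OpenAI

config = ContinueConfig(
    # Previously: default_model="gpt-4" (a plain string from a fixed Literal).
    # Now any LLM subclass can be passed directly.
    llm=OpenAI(default_model="gpt-4"),
    temperature=0.5,
)
```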
Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/core/config.py                |  5
-rw-r--r--  continuedev/src/continuedev/core/sdk.py                   | 66
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py          |  8
-rw-r--r--  continuedev/src/continuedev/libs/llm/anthropic.py         | 10
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py              |  1
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_inference_api.py  |  7
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py            | 13
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py      |  6
8 files changed, 43 insertions, 73 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index cb9c8977..23f4fe65 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -2,6 +2,7 @@ import json
 import os
 from .main import Step
 from .context import ContextProvider
+from ..libs.llm.openai import OpenAI
 from pydantic import BaseModel, validator
 from typing import List, Literal, Optional, Dict, Type, Union
 import yaml
@@ -38,8 +39,7 @@ class ContinueConfig(BaseModel):
     steps_on_startup: List[Step] = []
     disallowed_steps: Optional[List[str]] = []
     allow_anonymous_telemetry: Optional[bool] = True
-    default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
-                           "gpt-4", "claude-2", "ggml"] = 'gpt-4'
+    llm: LLM = OpenAI(default_model="gpt-4")
     temperature: Optional[float] = 0.5
     custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
         name="test",
@@ -49,7 +49,6 @@ class ContinueConfig(BaseModel):
     slash_commands: Optional[List[SlashCommand]] = []
     on_traceback: Optional[List[OnTracebackSteps]] = []
     system_message: Optional[str] = None
-    azure_openai_info: Optional[AzureInfo] = None
 
     context_providers: List[ContextProvider] = []
 
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 992bc1cf..92a72b23 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -63,18 +63,8 @@ class Models:
         self.system_message = sdk.config.system_message
 
     @classmethod
-    async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
-        if sdk.config.default_model == "claude-2":
-            with_providers.append("anthropic")
-
-        models = Models(sdk, with_providers)
-        for provider in with_providers:
-            if provider in MODEL_PROVIDER_TO_ENV_VAR:
-                env_var = MODEL_PROVIDER_TO_ENV_VAR[provider]
-                models.provider_keys[provider] = await sdk.get_user_secret(
-                    env_var, f'Please add your {env_var} to the .env file')
-
-        return models
+    async def create(cls, sdk: "ContinueSDK") -> "Models":
+        return self.default
 
     def __load_openai_model(self, model: str) -> OpenAI:
         api_key = self.provider_keys["openai"]
@@ -90,54 +80,9 @@ class Models:
         api_key = self.provider_keys["anthropic"]
         return AnthropicLLM(api_key, model, self.system_message)
 
-    @cached_property
-    def claude2(self):
-        return self.__load_anthropic_model("claude-2")
-
-    @cached_property
-    def starcoder(self):
-        return self.__load_hf_inference_api_model("bigcode/starcoder")
-
-    @cached_property
-    def gpt35(self):
-        return self.__load_openai_model("gpt-3.5-turbo")
-
-    @cached_property
-    def gpt350613(self):
-        return self.__load_openai_model("gpt-3.5-turbo-0613")
-
-    @cached_property
-    def gpt3516k(self):
-        return self.__load_openai_model("gpt-3.5-turbo-16k")
-
-    @cached_property
-    def gpt4(self):
-        return self.__load_openai_model("gpt-4")
-
-    @cached_property
-    def ggml(self):
-        return GGML(system_message=self.system_message)
-
-    def __model_from_name(self, model_name: str):
-        if model_name == "starcoder":
-            return self.starcoder
-        elif model_name == "gpt-3.5-turbo":
-            return self.gpt35
-        elif model_name == "gpt-3.5-turbo-16k":
-            return self.gpt3516k
-        elif model_name == "gpt-4":
-            return self.gpt4
-        elif model_name == "claude-2":
-            return self.claude2
-        elif model_name == "ggml":
-            return self.ggml
-        else:
-            raise Exception(f"Unknown model {model_name}")
-
     @property
     def default(self):
-        default_model = self.sdk.config.default_model
-        return self.__model_from_name(default_model) if default_model is not None else self.gpt4
+        return sdk.config.llm if sdk.config.llm is not None else ProxyServer(default_model="gpt-4")
 
 
 class ContinueSDK(AbstractContinueSDK):
@@ -174,6 +119,7 @@ class ContinueSDK(AbstractContinueSDK):
         ))
 
         sdk.models = await Models.create(sdk)
+        await sdk.models.start()
         return sdk
 
     @property
@@ -252,6 +198,10 @@ class ContinueSDK(AbstractContinueSDK):
         path = await self._ensure_absolute_path(path)
         return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path)))
 
+    async def get_api_key(self, env_var: str) -> str:
+        # TODO support error prompt dynamically set on env_var
+        return await self.ide.getUserSecret(env_var)
+
     async def get_user_secret(self, env_var: str, prompt: str) -> str:
         return await self.ide.getUserSecret(env_var)
 
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 2766db4b..5641d8a9 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -9,6 +9,14 @@ from pydantic import BaseModel
 class LLM(ABC):
     system_message: Union[str, None] = None
 
+    async def start(self):
+        """Start the connection to the LLM."""
+        raise NotImplementedError
+
+    async def stop(self):
+        """Stop the connection to the LLM."""
+        raise NotImplementedError
+
     async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
         """Return the completion of the text with the given temperature."""
         raise NotImplementedError
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 625d4e57..0067ce3a 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -9,16 +9,18 @@ from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_
 
 
 class AnthropicLLM(LLM):
-    api_key: str
     default_model: str
     async_client: AsyncAnthropic
 
-    def __init__(self, api_key: str, default_model: str, system_message: str = None):
-        self.api_key = api_key
+    def __init__(self, default_model: str, system_message: str = None):
         self.default_model = default_model
         self.system_message = system_message
 
-        self.async_client = AsyncAnthropic(api_key=api_key)
+    async def start(self):
+        self.async_client = AsyncAnthropic(api_key=await self.sdk.get_api_key("ANTHROPIC_API_KEY"))
+
+    async def stop(self):
+        pass
 
     @cached_property
     def name(self):
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 4889a556..ef771a2e 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -35,6 +35,7 @@ class GGML(LLM):
         messages = compile_chat_messages(
             self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
 
+        # TODO move to single self.session variable (proxy setting etc)
         async with aiohttp.ClientSession() as session:
             async with session.post(f"{SERVER_URL}/v1/completions", json={
                 "messages": messages,
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 36f03270..39b54f0f 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -8,14 +8,15 @@ DEFAULT_MAX_TIME = 120.
 
 
 class HuggingFaceInferenceAPI(LLM):
-    api_key: str
     model: str
 
-    def __init__(self, api_key: str, model: str, system_message: str = None):
-        self.api_key = api_key
+    def __init__(self, model: str, system_message: str = None):
         self.model = model
         self.system_message = system_message  # TODO: Nothing being done with this
 
+    async def start(self):
+        self.api_key = await self.sdk.get_api_key("HUGGING_FACE_TOKEN")
+
     def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
         """Return the completion of the text with the given temperature."""
         API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index a0773c1d..1a48fa86 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -10,24 +10,27 @@ from ...core.config import AzureInfo
 
 
 class OpenAI(LLM):
-    api_key: str
     default_model: str
 
-    def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None):
-        self.api_key = api_key
+    def __init__(self, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None):
         self.default_model = default_model
         self.system_message = system_message
         self.azure_info = azure_info
         self.write_log = write_log
 
-        openai.api_key = api_key
+    async def start(self):
+        self.api_key = await sdk.get_api_key("OPENAI_API_KEY")
+        openai.api_key = self.api_key
 
         # Using an Azure OpenAI deployment
-        if azure_info is not None:
+        if self.azure_info is not None:
             openai.api_type = "azure"
             openai.api_base = azure_info.endpoint
             openai.api_version = azure_info.api_version
 
+    async def stop(self):
+        pass
+
     @cached_property
     def name(self):
         return self.default_model
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index f9e3fa01..c0e2a403 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -30,6 +30,12 @@ class ProxyServer(LLM):
         self.name = default_model
         self.write_log = write_log
 
+    async def start(self):
+        pass
+
+    async def stop(self):
+        pass
+
     @property
     def default_args(self):
         return {**DEFAULT_ARGS, "model": self.default_model}
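The commit also moves credential loading out of the constructors and into the new `start()`/`stop()` lifecycle on `LLM`: `ContinueSDK.create` now calls `await sdk.models.start()` after building the models, and each provider fetches its key via `get_api_key` when started. A minimal sketch of how a provider might implement this contract, assuming the `self.sdk` handle that `AnthropicLLM` and `HuggingFaceInferenceAPI` reference (how that handle gets attached is not shown in this diff); `MyProviderLLM` and `MY_PROVIDER_KEY` are invented for illustration:

```python
# Path assumed from the diffstat above.
from continuedev.src.continuedev.libs.llm import LLM


class MyProviderLLM(LLM):
    """Hypothetical provider following the new start()/stop() contract."""

    def __init__(self, default_model: str, system_message: str = None):
        # No api_key parameter anymore; secrets are resolved in start().
        self.default_model = default_model
        self.system_message = system_message
        self.api_key = None

    async def start(self):
        # Fetch the secret lazily, mirroring AnthropicLLM and
        # HuggingFaceInferenceAPI in this commit. MY_PROVIDER_KEY is
        # illustrative, not a real env var in this repository.
        self.api_key = await self.sdk.get_api_key("MY_PROVIDER_KEY")

    async def stop(self):
        # Stateless HTTP usage: nothing to tear down.
        pass
```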