From 2b651d2504638ea9db97ba612f702356e38a805e Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 28 Jul 2023 23:50:47 -0300 Subject: make Models fetch LLM secret field declaratively --- continuedev/src/continuedev/core/sdk.py | 14 ++++++++++---- continuedev/src/continuedev/libs/llm/__init__.py | 3 ++- continuedev/src/continuedev/libs/llm/anthropic.py | 5 +++-- continuedev/src/continuedev/libs/llm/hf_inference_api.py | 5 +++-- continuedev/src/continuedev/libs/llm/openai.py | 5 +++-- continuedev/src/continuedev/libs/llm/utils.py | 1 + continuedev/src/continuedev/libs/util/count_tokens.py | 4 ++++ 7 files changed, 26 insertions(+), 11 deletions(-) (limited to 'continuedev') diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 183518ac..784f8ed1 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -54,21 +54,27 @@ class Models: self.large = large self.system_message = sdk.config.system_message + async def _start(llm: LLM): + kwargs = {} + if llm.required_api_key: + kwargs["api_key"] = await self.sdk.get_api_secret(llm.required_api_key) + await llm.start(**kwargs) + async def start(sdk: "ContinueSDK"): self.sdk = sdk - await self.default.start() + await self._start(self.default) if self.small: - await self.small.start() + await self._start(self.small) else: self.small = self.default if self.medium: - await self.medium.start() + await self._start(self.medium) else: self.medium = self.default if self.large: - await self.large.start() + await self._start(self.large) else: self.large = self.default diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 5641d8a9..6ae3dd46 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -7,9 +7,10 @@ from pydantic import BaseModel class LLM(ABC): + required_api_key: Optional[str] = None system_message: Union[str, 
None] = None - async def start(self): + async def start(self, *, api_key: Optional[str] = None): """Start the connection to the LLM.""" raise NotImplementedError diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index 0067ce3a..846a2450 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -9,6 +9,7 @@ from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_ class AnthropicLLM(LLM): + required_api_key: str = "ANTHROPIC_API_KEY" default_model: str async_client: AsyncAnthropic @@ -16,8 +17,8 @@ class AnthropicLLM(LLM): self.default_model = default_model self.system_message = system_message - async def start(self): - self.async_client = AsyncAnthropic(api_key=await self.sdk.get_api_key("ANTHROPIC_API_KEY")) + async def start(self, *, api_key): + self.async_client = AsyncAnthropic(api_key=api_key) async def stop(self): pass diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 39b54f0f..06d37596 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -8,14 +8,15 @@ DEFAULT_MAX_TIME = 120. 
class HuggingFaceInferenceAPI(LLM): + required_api_key: str = "HUGGING_FACE_TOKEN" model: str def __init__(self, model: str, system_message: str = None): self.model = model self.system_message = system_message # TODO: Nothing being done with this - async def start(self): - self.api_key = await self.sdk.get_api_key("HUGGING_FACE_TOKEN")) + async def start(self, *, api_key): + self.api_key = api_key def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs): """Return the completion of the text with the given temperature.""" diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 1a48fa86..c8de90a8 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -10,6 +10,7 @@ from ...core.config import AzureInfo class OpenAI(LLM): + required_api_key = "OPENAI_API_KEY" default_model: str def __init__(self, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None): @@ -18,8 +19,8 @@ class OpenAI(LLM): self.azure_info = azure_info self.write_log = write_log - async def start(self): - self.api_key = await sdk.get_api_key("OPENAI_API_KEY") + async def start(self, *, api_key): + self.api_key = api_key openai.api_key = self.api_key # Using an Azure OpenAI deployment diff --git a/continuedev/src/continuedev/libs/llm/utils.py b/continuedev/src/continuedev/libs/llm/utils.py index 76240d4e..4ea45b7b 100644 --- a/continuedev/src/continuedev/libs/llm/utils.py +++ b/continuedev/src/continuedev/libs/llm/utils.py @@ -5,6 +5,7 @@ gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") def count_tokens(text: str) -> int: return len(gpt2_tokenizer.encode(text)) +# TODO move this to LLM class itself (especially as prices may change in the future) prices = { # All prices are per 1k tokens "fine-tune-train": { diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py 
b/continuedev/src/continuedev/libs/util/count_tokens.py index c58ae499..f6c7cb00 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -4,6 +4,10 @@ from ...core.main import ChatMessage from .templating import render_templated_string import tiktoken +# TODO move many of these into specific LLM.properties() function that +# contains max tokens, if it's a chat model or not, default args (not all models +# want to be run at 0.5 temp). also lets custom models made for long contexts +# exist here (like LLongMA) aliases = { "ggml": "gpt-3.5-turbo", "claude-2": "gpt-3.5-turbo", -- cgit v1.2.3-70-g09d2