author     Luna <git@l4.pm>  2023-07-28 23:50:47 -0300
committer  Luna <git@l4.pm>  2023-07-28 23:50:47 -0300
commit     2b651d2504638ea9db97ba612f702356e38a805e (patch)
tree       b2564466ddcfb1496e92a6af7329d491f74292f8 /continuedev
parent     cde2cc05a75f1ae98d0ef95f8495e52ee3c6f163 (diff)
make Models fetch LLM secret field declaratively
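
In short: instead of each provider pulling its own secret from the SDK inside start(), a provider now declares which secret it needs via a required_api_key class field, and Models resolves that secret once, centrally, and passes it in. A minimal sketch of the pattern, with a hypothetical get_secret resolver standing in for the SDK call:

    from typing import Awaitable, Callable, Optional

    class LLM:
        # Declarative: a subclass names the secret it needs; None means "no key needed".
        required_api_key: Optional[str] = None

        async def start(self, *, api_key: Optional[str] = None):
            raise NotImplementedError

    async def start_llm(llm: LLM, get_secret: Callable[[str], Awaitable[str]]) -> None:
        # Central resolution: fetch a secret only when the model declares one.
        kwargs = {}
        if llm.required_api_key:
            kwargs["api_key"] = await get_secret(llm.required_api_key)
        await llm.start(**kwargs)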
Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/src/continuedev/core/sdk.py                   14
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py           3
-rw-r--r--  continuedev/src/continuedev/libs/llm/anthropic.py          5
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_inference_api.py   5
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py             5
-rw-r--r--  continuedev/src/continuedev/libs/llm/utils.py              1
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py      4
7 files changed, 26 insertions, 11 deletions
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 183518ac..784f8ed1 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -54,21 +54,27 @@ class Models:
         self.large = large
         self.system_message = sdk.config.system_message
 
+    async def _start(self, llm: LLM):
+        kwargs = {}
+        if llm.required_api_key:
+            kwargs["api_key"] = await self.sdk.get_api_secret(llm.required_api_key)
+        await llm.start(**kwargs)
+
     async def start(self, sdk: "ContinueSDK"):
         self.sdk = sdk
-        await self.default.start()
+        await self._start(self.default)
         if self.small:
-            await self.small.start()
+            await self._start(self.small)
         else:
             self.small = self.default
 
         if self.medium:
-            await self.medium.start()
+            await self._start(self.medium)
         else:
             self.medium = self.default
 
         if self.large:
-            await self.large.start()
+            await self._start(self.large)
         else:
             self.large = self.default
 
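A note on the kwargs indirection in _start above: building kwargs conditionally means a model that declares no secret is started as plain llm.start(), so no secret lookup happens for key-less providers. Reusing the LLM base from the earlier sketch, a hypothetical key-less provider behaves like this:

    from typing import Optional

    class LocalModel(LLM):  # hypothetical: required_api_key stays None
        async def start(self, *, api_key: Optional[str] = None):
            # _start passes no kwargs here, so api_key remains None
            assert api_key is None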
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 5641d8a9..6ae3dd46 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -7,9 +7,10 @@ from pydantic import BaseModel
 
 
 class LLM(ABC):
+    required_api_key: Optional[str] = None
     system_message: Union[str, None] = None
 
-    async def start(self):
+    async def start(self, *, api_key: Optional[str] = None):
         """Start the connection to the LLM."""
         raise NotImplementedError
 
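Making api_key keyword-only keeps the new contract explicit at every call site, and the Optional default lets key-less providers keep the same signature. A conforming provider then needs only two pieces, as in this hypothetical subclass (illustrative names, not from the commit):

    from typing import Optional

    class HostedModel(LLM):
        required_api_key = "HOSTED_MODEL_KEY"  # hypothetical secret name

        async def start(self, *, api_key: Optional[str] = None):
            if api_key is None:
                raise ValueError("expected HOSTED_MODEL_KEY to be resolved by Models")
            self.api_key = api_key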
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 0067ce3a..846a2450 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -9,6 +9,7 @@ from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_
 
 
 class AnthropicLLM(LLM):
+    required_api_key: str = "ANTHROPIC_API_KEY"
     default_model: str
     async_client: AsyncAnthropic
 
@@ -16,8 +17,8 @@ class AnthropicLLM(LLM):
         self.default_model = default_model
         self.system_message = system_message
 
-    async def start(self):
-        self.async_client = AsyncAnthropic(api_key=await self.sdk.get_api_key("ANTHROPIC_API_KEY"))
+    async def start(self, *, api_key):
+        self.async_client = AsyncAnthropic(api_key=api_key)
 
     async def stop(self):
         pass
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 39b54f0f..06d37596 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -8,14 +8,15 @@ DEFAULT_MAX_TIME = 120.
 
 
 class HuggingFaceInferenceAPI(LLM):
+    required_api_key: str = "HUGGING_FACE_TOKEN"
     model: str
 
     def __init__(self, model: str, system_message: str = None):
         self.model = model
         self.system_message = system_message  # TODO: Nothing being done with this
 
-    async def start(self):
-        self.api_key = await self.sdk.get_api_key("HUGGING_FACE_TOKEN")
+    async def start(self, *, api_key):
+        self.api_key = api_key
 
     def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
         """Return the completion of the text with the given temperature."""
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 1a48fa86..c8de90a8 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -10,6 +10,7 @@ from ...core.config import AzureInfo
 
 
 class OpenAI(LLM):
+    required_api_key = "OPENAI_API_KEY"
     default_model: str
 
     def __init__(self, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None):
@@ -18,8 +19,8 @@ class OpenAI(LLM):
         self.azure_info = azure_info
         self.write_log = write_log
 
-    async def start(self):
-        self.api_key = await sdk.get_api_key("OPENAI_API_KEY")
+    async def start(self, *, api_key):
+        self.api_key = api_key
         openai.api_key = self.api_key
 
     # Using an Azure OpenAI deployment
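One caveat that predates this commit: in the pre-1.0 openai-python client used here, openai.api_key is module-level state, so whichever model starts last sets the key for the whole process:

    import openai  # pre-1.0 client: the key is a module-level global

    openai.api_key = "key-A"  # set by one model's start()
    openai.api_key = "key-B"  # a later start() silently overrides it
    # Every subsequent openai.* request in this process now uses key-B.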
diff --git a/continuedev/src/continuedev/libs/llm/utils.py b/continuedev/src/continuedev/libs/llm/utils.py
index 76240d4e..4ea45b7b 100644
--- a/continuedev/src/continuedev/libs/llm/utils.py
+++ b/continuedev/src/continuedev/libs/llm/utils.py
@@ -5,6 +5,7 @@ gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
 def count_tokens(text: str) -> int:
     return len(gpt2_tokenizer.encode(text))
 
+# TODO: move this to the LLM class itself (especially as prices may change in the future)
 prices = {
     # All prices are per 1k tokens
     "fine-tune-train": {
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index c58ae499..f6c7cb00 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -4,6 +4,10 @@ from ...core.main import ChatMessage
 from .templating import render_templated_string
 import tiktoken
 
+# TODO: move many of these into a per-model LLM.properties() function that
+# holds the max token count, whether it is a chat model or not, and default
+# args (not all models want to be run at 0.5 temp). That would also let custom
+# models built for long contexts (like LLongMA) live here.
 aliases = {
     "ggml": "gpt-3.5-turbo",
     "claude-2": "gpt-3.5-turbo",