author    Luna <git@l4.pm>  2023-07-28 00:41:05 -0300
committer Luna <git@l4.pm>  2023-07-28 00:41:05 -0300
commit    902ac7f67c2bb7b9e7c17ed15582414db42b5c52 (patch)
tree      3985781f18e49566c2fb0a214cc31d386843d8c9 /continuedev
parent    d9026e72caa8ff94aa066f16cef677e5de76af07 (diff)
make Config receive LLM objects
Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/src/continuedev/core/config.py               |  5
-rw-r--r--  continuedev/src/continuedev/core/sdk.py                  | 66
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py         |  8
-rw-r--r--  continuedev/src/continuedev/libs/llm/anthropic.py        | 10
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py             |  1
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_inference_api.py |  7
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py           | 13
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py     |  6
8 files changed, 43 insertions(+), 73 deletions(-)
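In short, this commit replaces the string-valued default_model setting with an LLM object on the config, and moves credential setup into a new start()/stop() lifecycle on each LLM. A minimal sketch of a user config after this change, using only constructors visible in the diff below (anything beyond them is an assumption):

    from continuedev.src.continuedev.libs.llm.openai import OpenAI
    from continuedev.src.continuedev.core.config import ContinueConfig

    config = ContinueConfig(
        # previously: default_model="gpt-4"
        llm=OpenAI(default_model="gpt-4"),
        temperature=0.5,
    )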
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index cb9c8977..23f4fe65 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -2,6 +2,8 @@ import json
import os
from .main import Step
from .context import ContextProvider
+from ..libs.llm import LLM
+from ..libs.llm.openai import OpenAI
from pydantic import BaseModel, validator
from typing import List, Literal, Optional, Dict, Type, Union
import yaml
@@ -38,8 +40,7 @@ class ContinueConfig(BaseModel):
steps_on_startup: List[Step] = []
disallowed_steps: Optional[List[str]] = []
allow_anonymous_telemetry: Optional[bool] = True
- default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
- "gpt-4", "claude-2", "ggml"] = 'gpt-4'
+ llm: LLM = OpenAI(default_model="gpt-4")
temperature: Optional[float] = 0.5
custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
name="test",
@@ -49,7 +50,6 @@ class ContinueConfig(BaseModel):
slash_commands: Optional[List[SlashCommand]] = []
on_traceback: Optional[List[OnTracebackSteps]] = []
system_message: Optional[str] = None
- azure_openai_info: Optional[AzureInfo] = None
context_providers: List[ContextProvider] = []
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 992bc1cf..92a72b23 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -63,18 +63,8 @@ class Models:
self.system_message = sdk.config.system_message
@classmethod
- async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
- if sdk.config.default_model == "claude-2":
- with_providers.append("anthropic")
-
- models = Models(sdk, with_providers)
- for provider in with_providers:
- if provider in MODEL_PROVIDER_TO_ENV_VAR:
- env_var = MODEL_PROVIDER_TO_ENV_VAR[provider]
- models.provider_keys[provider] = await sdk.get_user_secret(
- env_var, f'Please add your {env_var} to the .env file')
-
- return models
+ async def create(cls, sdk: "ContinueSDK") -> "LLM":
+ return cls(sdk).default
def __load_openai_model(self, model: str) -> OpenAI:
api_key = self.provider_keys["openai"]
@@ -90,54 +80,9 @@ class Models:
api_key = self.provider_keys["anthropic"]
return AnthropicLLM(api_key, model, self.system_message)
- @cached_property
- def claude2(self):
- return self.__load_anthropic_model("claude-2")
-
- @cached_property
- def starcoder(self):
- return self.__load_hf_inference_api_model("bigcode/starcoder")
-
- @cached_property
- def gpt35(self):
- return self.__load_openai_model("gpt-3.5-turbo")
-
- @cached_property
- def gpt350613(self):
- return self.__load_openai_model("gpt-3.5-turbo-0613")
-
- @cached_property
- def gpt3516k(self):
- return self.__load_openai_model("gpt-3.5-turbo-16k")
-
- @cached_property
- def gpt4(self):
- return self.__load_openai_model("gpt-4")
-
- @cached_property
- def ggml(self):
- return GGML(system_message=self.system_message)
-
- def __model_from_name(self, model_name: str):
- if model_name == "starcoder":
- return self.starcoder
- elif model_name == "gpt-3.5-turbo":
- return self.gpt35
- elif model_name == "gpt-3.5-turbo-16k":
- return self.gpt3516k
- elif model_name == "gpt-4":
- return self.gpt4
- elif model_name == "claude-2":
- return self.claude2
- elif model_name == "ggml":
- return self.ggml
- else:
- raise Exception(f"Unknown model {model_name}")
-
@property
def default(self):
- default_model = self.sdk.config.default_model
- return self.__model_from_name(default_model) if default_model is not None else self.gpt4
+ return self.sdk.config.llm if self.sdk.config.llm is not None else ProxyServer(default_model="gpt-4")
class ContinueSDK(AbstractContinueSDK):
@@ -174,6 +119,7 @@ class ContinueSDK(AbstractContinueSDK):
))
sdk.models = await Models.create(sdk)
+ await sdk.models.start()
return sdk
@property
@@ -252,6 +198,10 @@ class ContinueSDK(AbstractContinueSDK):
path = await self._ensure_absolute_path(path)
return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path)))
+ async def get_api_key(self, env_var: str) -> str:
+ # TODO support error prompt dynamically set on env_var
+ return await self.ide.getUserSecret(env_var)
+
async def get_user_secret(self, env_var: str, prompt: str) -> str:
return await self.ide.getUserSecret(env_var)
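Taken together, the sdk.py changes hand LLM lifecycle management to the SDK: Models.create resolves the configured LLM (falling back to ProxyServer), the SDK starts it, and each LLM pulls its own key through get_api_key instead of Models prefetching provider keys. A rough sketch of the intended flow, assuming the names above (note that as the commit stands, create returns the default LLM itself):

    models = await Models.create(sdk)   # resolves sdk.config.llm, else ProxyServer(default_model="gpt-4")
    await models.start()                # the LLM fetches its own key via sdk.get_api_key(...)
    completion = await models.complete("def fib(n):")
    await models.stop()                 # release connections when the session ends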
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 2766db4b..5641d8a9 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -9,6 +9,14 @@ from pydantic import BaseModel
class LLM(ABC):
system_message: Union[str, None] = None
+ async def start(self):
+ """Start the connection to the LLM."""
+ raise NotImplementedError
+
+ async def stop(self):
+ """Stop the connection to the LLM."""
+ raise NotImplementedError
+
async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
"""Return the completion of the text with the given temperature."""
raise NotImplementedError
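start() and stop() make LLM an async-lifecycle object: subclasses acquire keys and clients in start() rather than in __init__, so constructing an LLM stays side-effect free. A minimal hypothetical subclass illustrating the contract (EchoLLM is not part of the codebase):

    class EchoLLM(LLM):
        async def start(self):
            self.ready = True    # acquire clients/keys here

        async def stop(self):
            self.ready = False   # release connections here

        async def complete(self, prompt: str, with_history=None, **kwargs) -> str:
            return prompt        # trivial stand-in for a real backend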
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 625d4e57..0067ce3a 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -9,16 +9,18 @@ from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_
class AnthropicLLM(LLM):
- api_key: str
default_model: str
async_client: AsyncAnthropic
- def __init__(self, api_key: str, default_model: str, system_message: str = None):
- self.api_key = api_key
+ def __init__(self, default_model: str, system_message: str = None):
self.default_model = default_model
self.system_message = system_message
- self.async_client = AsyncAnthropic(api_key=api_key)
+ async def start(self):
+ self.async_client = AsyncAnthropic(api_key=await self.sdk.get_api_key("ANTHROPIC_API_KEY"))
+
+ async def stop(self):
+ pass
@cached_property
def name(self):
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 4889a556..ef771a2e 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -35,6 +35,7 @@ class GGML(LLM):
messages = compile_chat_messages(
self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+ # TODO move to single self.session variable (proxy setting etc)
async with aiohttp.ClientSession() as session:
async with session.post(f"{SERVER_URL}/v1/completions", json={
"messages": messages,
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 36f03270..39b54f0f 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -8,14 +8,15 @@ DEFAULT_MAX_TIME = 120.
class HuggingFaceInferenceAPI(LLM):
- api_key: str
model: str
- def __init__(self, api_key: str, model: str, system_message: str = None):
- self.api_key = api_key
+ def __init__(self, model: str, system_message: str = None):
self.model = model
self.system_message = system_message # TODO: Nothing being done with this
+ async def start(self):
+ self.api_key = await self.sdk.get_api_key("HUGGING_FACE_TOKEN")
+
def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
"""Return the completion of the text with the given temperature."""
API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index a0773c1d..1a48fa86 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -10,24 +10,27 @@ from ...core.config import AzureInfo
class OpenAI(LLM):
- api_key: str
default_model: str
- def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None):
- self.api_key = api_key
+ def __init__(self, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None):
self.default_model = default_model
self.system_message = system_message
self.azure_info = azure_info
self.write_log = write_log
- openai.api_key = api_key
+ async def start(self):
+ self.api_key = await self.sdk.get_api_key("OPENAI_API_KEY")
+ openai.api_key = self.api_key
# Using an Azure OpenAI deployment
- if azure_info is not None:
+ if self.azure_info is not None:
openai.api_type = "azure"
- openai.api_base = azure_info.endpoint
- openai.api_version = azure_info.api_version
+ openai.api_base = self.azure_info.endpoint
+ openai.api_version = self.azure_info.api_version
+ async def stop(self):
+ pass
+
@cached_property
def name(self):
return self.default_model
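With key handling and Azure wiring moved into start(), an Azure deployment is configured up front but applied lazily. A hedged sketch using only the AzureInfo fields visible in this diff (endpoint and api_version; the real class may require more):

    llm = OpenAI(
        default_model="gpt-4",
        azure_info=AzureInfo(endpoint="https://example.openai.azure.com",
                             api_version="2023-05-15"),
    )
    await llm.start()   # sets openai.api_type/api_base/api_version from azure_info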
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index f9e3fa01..c0e2a403 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -30,6 +30,12 @@ class ProxyServer(LLM):
self.name = default_model
self.write_log = write_log
+ async def start(self):
+ pass
+
+ async def stop(self):
+ pass
+
@property
def default_args(self):
return {**DEFAULT_ARGS, "model": self.default_model}
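ProxyServer needs no per-session setup, so its start() and stop() are deliberate no-ops; it also serves as the zero-configuration fallback in Models.default. A one-line sketch mirroring the sdk.py change above:

    llm = config.llm if config.llm is not None else ProxyServer(default_model="gpt-4")
    await llm.start()   # no-op for ProxyServer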