diff options
| author | Luna <git@l4.pm> | 2023-07-28 00:41:05 -0300 |
|---|---|---|
| committer | Luna <git@l4.pm> | 2023-07-28 00:41:05 -0300 |
| commit | 902ac7f67c2bb7b9e7c17ed15582414db42b5c52 (patch) | |
| tree | 3985781f18e49566c2fb0a214cc31d386843d8c9 /continuedev/src/continuedev/core | |
| parent | d9026e72caa8ff94aa066f16cef677e5de76af07 (diff) | |
| download | sncontinue-902ac7f67c2bb7b9e7c17ed15582414db42b5c52.tar.gz sncontinue-902ac7f67c2bb7b9e7c17ed15582414db42b5c52.tar.bz2 sncontinue-902ac7f67c2bb7b9e7c17ed15582414db42b5c52.zip | |
make Config receive LLM objects
Diffstat (limited to 'continuedev/src/continuedev/core')
| -rw-r--r-- | continuedev/src/continuedev/core/config.py | 5 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 66 |
2 files changed, 10 insertions, 61 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index cb9c8977..23f4fe65 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -2,6 +2,7 @@ import json import os from .main import Step from .context import ContextProvider +from ..libs.llm.openai import OpenAI from pydantic import BaseModel, validator from typing import List, Literal, Optional, Dict, Type, Union import yaml @@ -38,8 +39,7 @@ class ContinueConfig(BaseModel): steps_on_startup: List[Step] = [] disallowed_steps: Optional[List[str]] = [] allow_anonymous_telemetry: Optional[bool] = True - default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k", - "gpt-4", "claude-2", "ggml"] = 'gpt-4' + llm: LLM = OpenAI(default_model="gpt-4") temperature: Optional[float] = 0.5 custom_commands: Optional[List[CustomCommand]] = [CustomCommand( name="test", @@ -49,7 +49,6 @@ class ContinueConfig(BaseModel): slash_commands: Optional[List[SlashCommand]] = [] on_traceback: Optional[List[OnTracebackSteps]] = [] system_message: Optional[str] = None - azure_openai_info: Optional[AzureInfo] = None context_providers: List[ContextProvider] = [] diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 992bc1cf..92a72b23 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -63,18 +63,8 @@ class Models: self.system_message = sdk.config.system_message @classmethod - async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models": - if sdk.config.default_model == "claude-2": - with_providers.append("anthropic") - - models = Models(sdk, with_providers) - for provider in with_providers: - if provider in MODEL_PROVIDER_TO_ENV_VAR: - env_var = MODEL_PROVIDER_TO_ENV_VAR[provider] - models.provider_keys[provider] = await sdk.get_user_secret( - env_var, f'Please add your {env_var} to the .env file') - - return models + async def create(cls, sdk: "ContinueSDK") -> "Models": + return self.default def __load_openai_model(self, model: str) -> OpenAI: api_key = self.provider_keys["openai"] @@ -90,54 +80,9 @@ class Models: api_key = self.provider_keys["anthropic"] return AnthropicLLM(api_key, model, self.system_message) - @cached_property - def claude2(self): - return self.__load_anthropic_model("claude-2") - - @cached_property - def starcoder(self): - return self.__load_hf_inference_api_model("bigcode/starcoder") - - @cached_property - def gpt35(self): - return self.__load_openai_model("gpt-3.5-turbo") - - @cached_property - def gpt350613(self): - return self.__load_openai_model("gpt-3.5-turbo-0613") - - @cached_property - def gpt3516k(self): - return self.__load_openai_model("gpt-3.5-turbo-16k") - - @cached_property - def gpt4(self): - return self.__load_openai_model("gpt-4") - - @cached_property - def ggml(self): - return GGML(system_message=self.system_message) - - def __model_from_name(self, model_name: str): - if model_name == "starcoder": - return self.starcoder - elif model_name == "gpt-3.5-turbo": - return self.gpt35 - elif model_name == "gpt-3.5-turbo-16k": - return self.gpt3516k - elif model_name == "gpt-4": - return self.gpt4 - elif model_name == "claude-2": - return self.claude2 - elif model_name == "ggml": - return self.ggml - else: - raise Exception(f"Unknown model {model_name}") - @property def default(self): - default_model = self.sdk.config.default_model - return self.__model_from_name(default_model) if default_model is not None else self.gpt4 + return sdk.config.llm if sdk.config.llm is not None else ProxyServer(default_model="gpt-4") class ContinueSDK(AbstractContinueSDK): @@ -174,6 +119,7 @@ class ContinueSDK(AbstractContinueSDK): )) sdk.models = await Models.create(sdk) + await sdk.models.start() return sdk @property @@ -252,6 +198,10 @@ class ContinueSDK(AbstractContinueSDK): path = await self._ensure_absolute_path(path) return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path))) + async def get_api_key(self, env_var: str) -> str: + # TODO support error prompt dynamically set on env_var + return await self.ide.getUserSecret(env_var) + async def get_user_secret(self, env_var: str, prompt: str) -> str: return await self.ide.getUserSecret(env_var) |
