summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/core
diff options
context:
space:
mode:
authorLuna <git@l4.pm>2023-07-28 00:41:05 -0300
committerLuna <git@l4.pm>2023-07-28 00:41:05 -0300
commit902ac7f67c2bb7b9e7c17ed15582414db42b5c52 (patch)
tree3985781f18e49566c2fb0a214cc31d386843d8c9 /continuedev/src/continuedev/core
parentd9026e72caa8ff94aa066f16cef677e5de76af07 (diff)
downloadsncontinue-902ac7f67c2bb7b9e7c17ed15582414db42b5c52.tar.gz
sncontinue-902ac7f67c2bb7b9e7c17ed15582414db42b5c52.tar.bz2
sncontinue-902ac7f67c2bb7b9e7c17ed15582414db42b5c52.zip
make Config receive LLM objects
Diffstat (limited to 'continuedev/src/continuedev/core')
-rw-r--r--continuedev/src/continuedev/core/config.py5
-rw-r--r--continuedev/src/continuedev/core/sdk.py66
2 files changed, 10 insertions, 61 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index cb9c8977..23f4fe65 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -2,6 +2,7 @@ import json
import os
from .main import Step
from .context import ContextProvider
+from ..libs.llm.openai import OpenAI
from pydantic import BaseModel, validator
from typing import List, Literal, Optional, Dict, Type, Union
import yaml
@@ -38,8 +39,7 @@ class ContinueConfig(BaseModel):
steps_on_startup: List[Step] = []
disallowed_steps: Optional[List[str]] = []
allow_anonymous_telemetry: Optional[bool] = True
- default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
- "gpt-4", "claude-2", "ggml"] = 'gpt-4'
+ llm: LLM = OpenAI(default_model="gpt-4")
temperature: Optional[float] = 0.5
custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
name="test",
@@ -49,7 +49,6 @@ class ContinueConfig(BaseModel):
slash_commands: Optional[List[SlashCommand]] = []
on_traceback: Optional[List[OnTracebackSteps]] = []
system_message: Optional[str] = None
- azure_openai_info: Optional[AzureInfo] = None
context_providers: List[ContextProvider] = []
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 992bc1cf..92a72b23 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -63,18 +63,8 @@ class Models:
self.system_message = sdk.config.system_message
@classmethod
- async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
- if sdk.config.default_model == "claude-2":
- with_providers.append("anthropic")
-
- models = Models(sdk, with_providers)
- for provider in with_providers:
- if provider in MODEL_PROVIDER_TO_ENV_VAR:
- env_var = MODEL_PROVIDER_TO_ENV_VAR[provider]
- models.provider_keys[provider] = await sdk.get_user_secret(
- env_var, f'Please add your {env_var} to the .env file')
-
- return models
+ async def create(cls, sdk: "ContinueSDK") -> "Models":
+ return self.default
def __load_openai_model(self, model: str) -> OpenAI:
api_key = self.provider_keys["openai"]
@@ -90,54 +80,9 @@ class Models:
api_key = self.provider_keys["anthropic"]
return AnthropicLLM(api_key, model, self.system_message)
- @cached_property
- def claude2(self):
- return self.__load_anthropic_model("claude-2")
-
- @cached_property
- def starcoder(self):
- return self.__load_hf_inference_api_model("bigcode/starcoder")
-
- @cached_property
- def gpt35(self):
- return self.__load_openai_model("gpt-3.5-turbo")
-
- @cached_property
- def gpt350613(self):
- return self.__load_openai_model("gpt-3.5-turbo-0613")
-
- @cached_property
- def gpt3516k(self):
- return self.__load_openai_model("gpt-3.5-turbo-16k")
-
- @cached_property
- def gpt4(self):
- return self.__load_openai_model("gpt-4")
-
- @cached_property
- def ggml(self):
- return GGML(system_message=self.system_message)
-
- def __model_from_name(self, model_name: str):
- if model_name == "starcoder":
- return self.starcoder
- elif model_name == "gpt-3.5-turbo":
- return self.gpt35
- elif model_name == "gpt-3.5-turbo-16k":
- return self.gpt3516k
- elif model_name == "gpt-4":
- return self.gpt4
- elif model_name == "claude-2":
- return self.claude2
- elif model_name == "ggml":
- return self.ggml
- else:
- raise Exception(f"Unknown model {model_name}")
-
@property
def default(self):
- default_model = self.sdk.config.default_model
- return self.__model_from_name(default_model) if default_model is not None else self.gpt4
+ return sdk.config.llm if sdk.config.llm is not None else ProxyServer(default_model="gpt-4")
class ContinueSDK(AbstractContinueSDK):
@@ -174,6 +119,7 @@ class ContinueSDK(AbstractContinueSDK):
))
sdk.models = await Models.create(sdk)
+ await sdk.models.start()
return sdk
@property
@@ -252,6 +198,10 @@ class ContinueSDK(AbstractContinueSDK):
path = await self._ensure_absolute_path(path)
return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path)))
+ async def get_api_key(self, env_var: str) -> str:
+ # TODO support error prompt dynamically set on env_var
+ return await self.ide.getUserSecret(env_var)
+
async def get_user_secret(self, env_var: str, prompt: str) -> str:
return await self.ide.getUserSecret(env_var)