diff options
-rw-r--r-- | continuedev/src/continuedev/core/config.py | 4 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 75 |
2 files changed, 40 insertions, 39 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 23f4fe65..6957ae5e 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -39,7 +39,9 @@ class ContinueConfig(BaseModel): steps_on_startup: List[Step] = [] disallowed_steps: Optional[List[str]] = [] allow_anonymous_telemetry: Optional[bool] = True - llm: LLM = OpenAI(default_model="gpt-4") + models: Models = Models( + default=MaybeProxyOpenAI(model="gpt-4"), + ) temperature: Optional[float] = 0.5 custom_commands: Optional[List[CustomCommand]] = [CustomCommand( name="test", diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 92a72b23..183518ac 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -8,15 +8,11 @@ from .abstract_sdk import AbstractContinueSDK from .config import ContinueConfig from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory from ..models.filesystem import RangeInFile -from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI -from ..libs.llm.openai import OpenAI -from ..libs.llm.anthropic import AnthropicLLM -from ..libs.llm.ggml import GGML +from ..libs.llm import LLM from .observation import Observation from ..server.ide_protocol import AbstractIdeProtocolServer from .main import Context, ContinueCustomException, History, HistoryNode, Step, ChatMessage from ..plugins.steps.core.core import * -from ..libs.llm.proxy_server import ProxyServer from ..libs.util.telemetry import posthog_logger from ..libs.util.paths import getConfigFilePath @@ -25,18 +21,12 @@ class Autopilot: pass -ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"] -MODEL_PROVIDER_TO_ENV_VAR = { - "openai": "OPENAI_API_KEY", - "hf_inference_api": "HUGGING_FACE_TOKEN", - "anthropic": "ANTHROPIC_API_KEY", -} - - class Models: - provider_keys: Dict[ModelProvider, str] = {} - model_providers: List[ModelProvider] - system_message: str + """Main class that holds the current model configuration""" + default: LLM + small: Optional[LLM] = None + medium: Optional[LLM] = None + large: Optional[LLM] = None """ Better to have sdk.llm.stream_chat(messages, model="claude-2"). @@ -57,32 +47,41 @@ class Models: '''depending on the model, return the single prompt string''' """ - def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]): - self.sdk = sdk - self.model_providers = model_providers + def __init__(self, *, default, small=None, medium=None, large=None, custom=None): + self.default = default + self.small = small + self.medium = medium + self.large = large self.system_message = sdk.config.system_message - @classmethod - async def create(cls, sdk: "ContinueSDK") -> "Models": - return self.default + async def start(sdk: "ContinueSDK"): + self.sdk = sdk + await self.default.start() + if self.small: + await self.small.start() + else: + self.small = self.default - def __load_openai_model(self, model: str) -> OpenAI: - api_key = self.provider_keys["openai"] - if api_key == "": - return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log) - return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info, write_log=self.sdk.write_log) + if self.medium: + await self.medium.start() + else: + self.medium = self.default - def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI: - api_key = self.provider_keys["hf_inference_api"] - return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message) + if self.large: + await self.large.start() + else: + self.large = self.default - def __load_anthropic_model(self, model: str) -> AnthropicLLM: - api_key = self.provider_keys["anthropic"] - return AnthropicLLM(api_key, model, self.system_message) + async def stop(sdk: "ContinueSDK"): + await self.default.stop() + if self.small: + await self.small.stop() - @property - def default(self): - return sdk.config.llm if sdk.config.llm is not None else ProxyServer(default_model="gpt-4") + if self.medium: + await self.medium.stop() + + if self.large: + await self.large.stop() class ContinueSDK(AbstractContinueSDK): @@ -118,8 +117,8 @@ class ContinueSDK(AbstractContinueSDK): active=False )) - sdk.models = await Models.create(sdk) - await sdk.models.start() + sdk.models = sdk.config.models + await sdk.models.start(sdk) return sdk @property |