author     sestinj <sestinj@gmail.com>    2023-07-15 15:10:48 -0700
committer  sestinj <sestinj@gmail.com>    2023-07-15 15:10:48 -0700
commit     152e3ae0d5455e621bd37cf7962478e9fa03f5eb (patch)
tree       455c1fffa360aed894d8f745f810af247ddfdf6a /continuedev/src/continuedev/core
parent     abe77c56abd7aea66fa85bd1257f76dc2d435a15 (diff)
parent     48e5c8001e897eb37493357087410ee8f98217fa (diff)
Merge remote origin main
Diffstat (limited to 'continuedev/src/continuedev/core')
-rw-r--r--  continuedev/src/continuedev/core/autopilot.py  10
-rw-r--r--  continuedev/src/continuedev/core/sdk.py        59
2 files changed, 49 insertions, 20 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 82439f49..0696c360 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -50,6 +50,8 @@ class Autopilot(ContinueBaseModel):
     full_state: Union[FullState, None] = None
     _on_update_callbacks: List[Callable[[FullState], None]] = []
 
+    continue_sdk: ContinueSDK = None
+
     _active: bool = False
     _should_halt: bool = False
     _main_user_input_queue: List[str] = []
@@ -57,9 +59,11 @@ class Autopilot(ContinueBaseModel):
     _user_input_queue = AsyncSubscriptionQueue()
     _retry_queue = AsyncSubscriptionQueue()
 
-    @cached_property
-    def continue_sdk(self) -> ContinueSDK:
-        return ContinueSDK(self)
+    @classmethod
+    async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot":
+        autopilot = cls(ide=ide, policy=policy)
+        autopilot.continue_sdk = await ContinueSDK.create(autopilot)
+        return autopilot
 
     class Config:
         arbitrary_types_allowed = True
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index aa2d8892..d73561d2 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -1,6 +1,6 @@
 import asyncio
 from functools import cached_property
-from typing import Coroutine, Union
+from typing import Coroutine, Dict, Union
 import os
 
 from ..steps.core.core import DefaultModelEditCodeStep
@@ -13,7 +13,7 @@ from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
 from ..libs.llm.openai import OpenAI
 from .observation import Observation
 from ..server.ide_protocol import AbstractIdeProtocolServer
-from .main import Context, ContinueCustomException, HighlightedRangeContext, History, Step, ChatMessage, ChatMessageRole
+from .main import Context, ContinueCustomException, History, Step, ChatMessage
 from ..steps.core.core import *
 from ..libs.llm.proxy_server import ProxyServer
 
@@ -22,26 +22,46 @@ class Autopilot:
     pass
 
 
+ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"]
+MODEL_PROVIDER_TO_ENV_VAR = {
+    "openai": "OPENAI_API_KEY",
+    "hf_inference_api": "HUGGING_FACE_TOKEN",
+    "anthropic": "ANTHROPIC_API_KEY"
+}
+
+
 class Models:
-    def __init__(self, sdk: "ContinueSDK"):
+    provider_keys: Dict[ModelProvider, str] = {}
+    model_providers: List[ModelProvider]
+
+    def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
         self.sdk = sdk
+        self.model_providers = model_providers
+
+    @classmethod
+    async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
+        models = Models(sdk, with_providers)
+        for provider in with_providers:
+            if provider in MODEL_PROVIDER_TO_ENV_VAR:
+                env_var = MODEL_PROVIDER_TO_ENV_VAR[provider]
+                models.provider_keys[provider] = await sdk.get_user_secret(
+                    env_var, f'Please add your {env_var} to the .env file')
+
+        return models
 
     def __load_openai_model(self, model: str) -> OpenAI:
-        async def load_openai_model():
-            api_key = await self.sdk.get_user_secret(
-                'OPENAI_API_KEY', 'Enter your OpenAI API key or press enter to try for free')
-            if api_key == "":
-                return ProxyServer(self.sdk.ide.unique_id, model)
-            return OpenAI(api_key=api_key, default_model=model)
-        return asyncio.get_event_loop().run_until_complete(load_openai_model())
+        api_key = self.provider_keys["openai"]
+        if api_key == "":
+            return ProxyServer(self.sdk.ide.unique_id, model)
+        return OpenAI(api_key=api_key, default_model=model)
+
+    def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
+        api_key = self.provider_keys["hf_inference_api"]
+        return HuggingFaceInferenceAPI(api_key=api_key, model=model)
 
     @cached_property
     def starcoder(self):
-        async def load_starcoder():
-            api_key = await self.sdk.get_user_secret(
-                'HUGGING_FACE_TOKEN', 'Please add your Hugging Face token to the .env file')
-            return HuggingFaceInferenceAPI(api_key=api_key)
-        return asyncio.get_event_loop().run_until_complete(load_starcoder())
+        return self.__load_hf_inference_api_model("bigcode/starcoder")
 
     @cached_property
     def gpt35(self):
@@ -74,7 +94,7 @@ class Models:
     @property
     def default(self):
         default_model = self.sdk.config.default_model
-        return self.__model_from_name(default_model) if default_model is not None else self.gpt35
+        return self.__model_from_name(default_model) if default_model is not None else self.gpt4
 
 
 class ContinueSDK(AbstractContinueSDK):
@@ -87,10 +107,15 @@ class ContinueSDK(AbstractContinueSDK):
     def __init__(self, autopilot: Autopilot):
         self.ide = autopilot.ide
         self.__autopilot = autopilot
-        self.models = Models(self)
         self.context = autopilot.context
         self.config = self._load_config()
 
+    @classmethod
+    async def create(cls, autopilot: Autopilot) -> "ContinueSDK":
+        sdk = ContinueSDK(autopilot)
+        sdk.models = await Models.create(sdk)
+        return sdk
+
     config: ContinueConfig
 
     def _load_config(self) -> ContinueConfig:
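
The net effect of this merge is that the lazily cached, synchronous construction of ContinueSDK and Models (which called asyncio.get_event_loop().run_until_complete inside property getters) is replaced by an explicit chain of async factories: Autopilot.create awaits ContinueSDK.create, which awaits Models.create so that provider API keys are resolved up front via sdk.get_user_secret. A minimal sketch of how a caller might drive the new entry point is below; the start_autopilot helper and the import path are illustrative only and not part of this commit, and it assumes a Policy, an AbstractIdeProtocolServer instance, and a FullState are already available in the calling code.

    import asyncio

    # Illustrative import; it mirrors the on-disk location in this repo and may
    # differ from the package's actual import path.
    from continuedev.src.continuedev.core.autopilot import Autopilot


    async def start_autopilot(policy, ide, full_state):
        # Autopilot.create builds the Autopilot, then awaits ContinueSDK.create,
        # which in turn awaits Models.create to fetch provider keys (e.g.
        # OPENAI_API_KEY) once, instead of blocking inside a cached_property.
        autopilot = await Autopilot.create(policy=policy, ide=ide, full_state=full_state)
        return autopilot


    # In the server this would run on an existing event loop; asyncio.run is
    # shown only to make the sketch self-contained.
    # autopilot = asyncio.run(start_autopilot(policy, ide, full_state))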