| author | Nate Sesti <sestinj@gmail.com> | 2023-07-15 14:30:11 -0700 | 
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-07-15 14:30:11 -0700 | 
| commit | 925c3e0ef45d9eb01a8f6c1efd239fa011492bd2 (patch) | |
| tree | cd07a7b0acee3ffbaa3570483bd713032e0de341 /continuedev/src | |
| parent | 6b3d20c943c0c1417b437ad475019bae729103ed (diff) | |
ctrl shortcuts on windows, load models immediately
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 10 |
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 59 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/hf_inference_api.py | 6 |
| -rw-r--r-- | continuedev/src/continuedev/server/ide.py | 4 |
| -rw-r--r-- | continuedev/src/continuedev/server/session_manager.py | 6 |
5 files changed, 59 insertions, 26 deletions
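
The diff below replaces lazy, blocking model loading (coroutines driven by `asyncio.get_event_loop().run_until_complete` inside property getters) with `async` factory classmethods that resolve provider API keys once, up front. A minimal sketch of that factory pattern, using hypothetical names (`ModelRegistry` and a stubbed `get_user_secret`) rather than the repository's actual classes:

```python
import asyncio
from typing import Dict, List


async def get_user_secret(env_var: str, prompt: str) -> str:
    # Stand-in for the SDK's awaitable secret lookup; returns a dummy value here.
    return "dummy-key"


class ModelRegistry:
    """Hypothetical analogue of Models: provider keys are fetched once in
    create(), so later model lookups are plain synchronous attribute access."""

    def __init__(self, providers: List[str]):
        self.providers = providers
        self.provider_keys: Dict[str, str] = {}

    @classmethod
    async def create(cls, providers: List[str] = ["openai"]) -> "ModelRegistry":
        registry = cls(providers)
        for provider in providers:
            # Resolve every key up front instead of lazily inside a property.
            registry.provider_keys[provider] = await get_user_secret(
                f"{provider.upper()}_API_KEY", "Please add your key to the .env file")
        return registry


async def main() -> None:
    registry = await ModelRegistry.create(["openai"])
    print(registry.provider_keys)


if __name__ == "__main__":
    asyncio.run(main())
```

Resolving the keys inside `create()` is what makes "load models immediately" possible: by the time the object exists, every provider key is already in hand, and no property getter has to spin up an event loop.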
```diff
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 82439f49..0696c360 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -50,6 +50,8 @@ class Autopilot(ContinueBaseModel):
     full_state: Union[FullState, None] = None
     _on_update_callbacks: List[Callable[[FullState], None]] = []
 
+    continue_sdk: ContinueSDK = None
+
     _active: bool = False
     _should_halt: bool = False
     _main_user_input_queue: List[str] = []
@@ -57,9 +59,11 @@ class Autopilot(ContinueBaseModel):
     _user_input_queue = AsyncSubscriptionQueue()
     _retry_queue = AsyncSubscriptionQueue()
 
-    @cached_property
-    def continue_sdk(self) -> ContinueSDK:
-        return ContinueSDK(self)
+    @classmethod
+    async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot":
+        autopilot = cls(ide=ide, policy=policy)
+        autopilot.continue_sdk = await ContinueSDK.create(autopilot)
+        return autopilot
 
     class Config:
         arbitrary_types_allowed = True
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index aa2d8892..d73561d2 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -1,6 +1,6 @@
 import asyncio
 from functools import cached_property
-from typing import Coroutine, Union
+from typing import Coroutine, Dict, Union
 import os
 
 from ..steps.core.core import DefaultModelEditCodeStep
@@ -13,7 +13,7 @@ from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
 from ..libs.llm.openai import OpenAI
 from .observation import Observation
 from ..server.ide_protocol import AbstractIdeProtocolServer
-from .main import Context, ContinueCustomException, HighlightedRangeContext, History, Step, ChatMessage, ChatMessageRole
+from .main import Context, ContinueCustomException, History, Step, ChatMessage
 from ..steps.core.core import *
 from ..libs.llm.proxy_server import ProxyServer
 
@@ -22,26 +22,46 @@ class Autopilot:
     pass
 
 
+ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"]
+MODEL_PROVIDER_TO_ENV_VAR = {
+    "openai": "OPENAI_API_KEY",
+    "hf_inference_api": "HUGGING_FACE_TOKEN",
+    "anthropic": "ANTHROPIC_API_KEY"
+}
+
+
 class Models:
-    def __init__(self, sdk: "ContinueSDK"):
+    provider_keys: Dict[ModelProvider, str] = {}
+    model_providers: List[ModelProvider]
+
+    def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
         self.sdk = sdk
+        self.model_providers = model_providers
+
+    @classmethod
+    async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
+        models = Models(sdk, with_providers)
+        for provider in with_providers:
+            if provider in MODEL_PROVIDER_TO_ENV_VAR:
+                env_var = MODEL_PROVIDER_TO_ENV_VAR[provider]
+                models.provider_keys[provider] = await sdk.get_user_secret(
+                    env_var, f'Please add your {env_var} to the .env file')
+
+        return models
 
     def __load_openai_model(self, model: str) -> OpenAI:
-        async def load_openai_model():
-            api_key = await self.sdk.get_user_secret(
-                'OPENAI_API_KEY', 'Enter your OpenAI API key or press enter to try for free')
-            if api_key == "":
-                return ProxyServer(self.sdk.ide.unique_id, model)
-            return OpenAI(api_key=api_key, default_model=model)
-        return asyncio.get_event_loop().run_until_complete(load_openai_model())
+        api_key = self.provider_keys["openai"]
+        if api_key == "":
+            return ProxyServer(self.sdk.ide.unique_id, model)
+        return OpenAI(api_key=api_key, default_model=model)
+
+    def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
+        api_key = self.provider_keys["hf_inference_api"]
+        return HuggingFaceInferenceAPI(api_key=api_key, model=model)
 
     @cached_property
     def starcoder(self):
-        async def load_starcoder():
-            api_key = await self.sdk.get_user_secret(
-                'HUGGING_FACE_TOKEN', 'Please add your Hugging Face token to the .env file')
-            return HuggingFaceInferenceAPI(api_key=api_key)
-        return asyncio.get_event_loop().run_until_complete(load_starcoder())
+        return self.__load_hf_inference_api_model("bigcode/starcoder")
 
     @cached_property
     def gpt35(self):
@@ -74,7 +94,7 @@ class Models:
     @property
     def default(self):
         default_model = self.sdk.config.default_model
-        return self.__model_from_name(default_model) if default_model is not None else self.gpt35
+        return self.__model_from_name(default_model) if default_model is not None else self.gpt4
 
 
 class ContinueSDK(AbstractContinueSDK):
@@ -87,10 +107,15 @@ class ContinueSDK(AbstractContinueSDK):
     def __init__(self, autopilot: Autopilot):
         self.ide = autopilot.ide
         self.__autopilot = autopilot
-        self.models = Models(self)
         self.context = autopilot.context
         self.config = self._load_config()
 
+    @classmethod
+    async def create(cls, autopilot: Autopilot) -> "ContinueSDK":
+        sdk = ContinueSDK(autopilot)
+        sdk.models = await Models.create(sdk)
+        return sdk
+
     config: ContinueConfig
 
     def _load_config(self) -> ContinueConfig:
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 1586c620..803ba122 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -9,7 +9,11 @@ DEFAULT_MAX_TIME = 120.
 
 class HuggingFaceInferenceAPI(LLM):
     api_key: str
-    model: str = "bigcode/starcoder"
+    model: str
+
+    def __init__(self, api_key: str, model: str):
+        self.api_key = api_key
+        self.model = model
 
     def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
         """Return the completion of the text with the given temperature."""
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index 7875c94d..77b13483 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -227,8 +227,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
         })
 
     async def getSessionId(self):
-        session_id = self.session_manager.new_session(
-            self, self.session_id).session_id
+        session_id = (await self.session_manager.new_session(
+            self, self.session_id)).session_id
         await self._send_json("getSessionId", {
             "sessionId": session_id
         })
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index fb8ac386..6d109ca6 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -53,18 +53,18 @@ class SessionManager:
             session_files = os.listdir(sessions_folder)
             if f"{session_id}.json" in session_files and session_id in self.registered_ides:
                 if self.registered_ides[session_id].session_id is not None:
-                    return self.new_session(self.registered_ides[session_id], session_id=session_id)
+                    return await self.new_session(self.registered_ides[session_id], session_id=session_id)
 
             raise KeyError("Session ID not recognized", session_id)
         return self.sessions[session_id]
 
-    def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
+    async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
         full_state = None
         if session_id is not None and os.path.exists(getSessionFilePath(session_id)):
             with open(getSessionFilePath(session_id), "r") as f:
                 full_state = FullState(**json.load(f))
 
-        autopilot = DemoAutopilot(
+        autopilot = await DemoAutopilot.create(
             policy=DemoPolicy(), ide=ide, full_state=full_state)
         session_id = session_id or str(uuid4())
         ide.session_id = session_id
```
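
With these changes the construction chain is awaited end to end: `SessionManager.new_session` awaits `DemoAutopilot.create`, which awaits `ContinueSDK.create`, which awaits `Models.create`. A rough, self-contained sketch of how such a chain composes (the `Fake*` classes below are illustrative stubs, not the project's real types):

```python
import asyncio


class FakeSDK:
    """Stand-in for ContinueSDK; its factory awaits whatever setup it needs."""

    @classmethod
    async def create(cls) -> "FakeSDK":
        sdk = cls()
        # In the real diff this is where Models.create fetches provider keys.
        await asyncio.sleep(0)
        return sdk


class FakeAutopilot:
    """Stand-in for DemoAutopilot; holds the SDK built by its factory."""

    continue_sdk = None

    @classmethod
    async def create(cls) -> "FakeAutopilot":
        autopilot = cls()
        autopilot.continue_sdk = await FakeSDK.create()
        return autopilot


async def new_session() -> FakeAutopilot:
    # Mirrors SessionManager.new_session becoming async so it can await the
    # autopilot factory instead of constructing the object directly.
    return await FakeAutopilot.create()


if __name__ == "__main__":
    asyncio.run(new_session())
```

Because every link in the chain is a coroutine, callers such as `getSessionId` must await it too, which is why the diff wraps the call as `(await self.session_manager.new_session(...)).session_id`.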
