-rw-r--r--  continuedev/src/continuedev/core/config.py                 5
-rw-r--r--  continuedev/src/continuedev/core/sdk.py                    66
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py           8
-rw-r--r--  continuedev/src/continuedev/libs/llm/anthropic.py          10
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py               1
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_inference_api.py   7
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py             13
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py       6
8 files changed, 43 insertions(+), 73 deletions(-)
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index cb9c8977..23f4fe65 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -2,6 +2,7 @@ import json
 import os
 from .main import Step
 from .context import ContextProvider
+from ..libs.llm.openai import OpenAI
 from pydantic import BaseModel, validator
 from typing import List, Literal, Optional, Dict, Type, Union
 import yaml
@@ -38,8 +39,7 @@ class ContinueConfig(BaseModel):
     steps_on_startup: List[Step] = []
     disallowed_steps: Optional[List[str]] = []
     allow_anonymous_telemetry: Optional[bool] = True
-    default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
-                           "gpt-4", "claude-2", "ggml"] = 'gpt-4'
+    llm: LLM = OpenAI(default_model="gpt-4")
     temperature: Optional[float] = 0.5
     custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
         name="test",
@@ -49,7 +49,6 @@ class ContinueConfig(BaseModel):
     slash_commands: Optional[List[SlashCommand]] = []
     on_traceback: Optional[List[OnTracebackSteps]] = []
     system_message: Optional[str] = None
-    azure_openai_info: Optional[AzureInfo] = None
 
     context_providers: List[ContextProvider] = []
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 992bc1cf..92a72b23 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -63,18 +63,8 @@ class Models:
         self.system_message = sdk.config.system_message
 
     @classmethod
-    async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
-        if sdk.config.default_model == "claude-2":
-            with_providers.append("anthropic")
-
-        models = Models(sdk, with_providers)
-        for provider in with_providers:
-            if provider in MODEL_PROVIDER_TO_ENV_VAR:
-                env_var = MODEL_PROVIDER_TO_ENV_VAR[provider]
-                models.provider_keys[provider] = await sdk.get_user_secret(
-                    env_var, f'Please add your {env_var} to the .env file')
-
-        return models
+    async def create(cls, sdk: "ContinueSDK") -> "Models":
+        return self.default
 
     def __load_openai_model(self, model: str) -> OpenAI:
         api_key = self.provider_keys["openai"]
@@ -90,54 +80,9 @@ class Models:
         api_key = self.provider_keys["anthropic"]
         return AnthropicLLM(api_key, model, self.system_message)
 
-    @cached_property
-    def claude2(self):
-        return self.__load_anthropic_model("claude-2")
-
-    @cached_property
-    def starcoder(self):
-        return self.__load_hf_inference_api_model("bigcode/starcoder")
-
-    @cached_property
-    def gpt35(self):
-        return self.__load_openai_model("gpt-3.5-turbo")
-
-    @cached_property
-    def gpt350613(self):
-        return self.__load_openai_model("gpt-3.5-turbo-0613")
-
-    @cached_property
-    def gpt3516k(self):
-        return self.__load_openai_model("gpt-3.5-turbo-16k")
-
-    @cached_property
-    def gpt4(self):
-        return self.__load_openai_model("gpt-4")
-
-    @cached_property
-    def ggml(self):
-        return GGML(system_message=self.system_message)
-
-    def __model_from_name(self, model_name: str):
-        if model_name == "starcoder":
-            return self.starcoder
-        elif model_name == "gpt-3.5-turbo":
-            return self.gpt35
-        elif model_name == "gpt-3.5-turbo-16k":
-            return self.gpt3516k
-        elif model_name == "gpt-4":
-            return self.gpt4
-        elif model_name == "claude-2":
-            return self.claude2
-        elif model_name == "ggml":
-            return self.ggml
-        else:
-            raise Exception(f"Unknown model {model_name}")
-
     @property
     def default(self):
-        default_model = self.sdk.config.default_model
-        return self.__model_from_name(default_model) if default_model is not None else self.gpt4
+        return sdk.config.llm if sdk.config.llm is not None else ProxyServer(default_model="gpt-4")
 
 
 class ContinueSDK(AbstractContinueSDK):
@@ -174,6 +119,7 @@ class ContinueSDK(AbstractContinueSDK):
             ))
 
         sdk.models = await Models.create(sdk)
+        await sdk.models.start()
         return sdk
 
     @property
@@ -252,6 +198,10 @@ class ContinueSDK(AbstractContinueSDK):
         path = await self._ensure_absolute_path(path)
         return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path)))
 
+    async def get_api_key(self, env_var: str) -> str:
+        # TODO support error prompt dynamically set on env_var
+        return await self.ide.getUserSecret(env_var)
+
     async def get_user_secret(self, env_var: str, prompt: str) -> str:
         return await self.ide.getUserSecret(env_var)
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 2766db4b..5641d8a9 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -9,6 +9,14 @@ from pydantic import BaseModel
 class LLM(ABC):
     system_message: Union[str, None] = None
 
+    async def start(self):
+        """Start the connection to the LLM."""
+        raise NotImplementedError
+
+    async def stop(self):
+        """Stop the connection to the LLM."""
+        raise NotImplementedError
+
     async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
         """Return the completion of the text with the given temperature."""
         raise NotImplementedError
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 625d4e57..0067ce3a 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -9,16 +9,18 @@ from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_
 
 class AnthropicLLM(LLM):
-    api_key: str
     default_model: str
     async_client: AsyncAnthropic
 
-    def __init__(self, api_key: str, default_model: str, system_message: str = None):
-        self.api_key = api_key
+    def __init__(self, default_model: str, system_message: str = None):
         self.default_model = default_model
         self.system_message = system_message
 
-        self.async_client = AsyncAnthropic(api_key=api_key)
+    async def start(self):
+        self.async_client = AsyncAnthropic(api_key=await self.sdk.get_api_key("ANTHROPIC_API_KEY"))
+
+    async def stop(self):
+        pass
 
     @cached_property
     def name(self):
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 4889a556..ef771a2e 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -35,6 +35,7 @@ class GGML(LLM):
         messages = compile_chat_messages(
             self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+        # TODO move to single self.session variable (proxy setting etc)
         async with aiohttp.ClientSession() as session:
             async with session.post(f"{SERVER_URL}/v1/completions", json={
                 "messages": messages,
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 36f03270..39b54f0f 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -8,14 +8,15 @@ DEFAULT_MAX_TIME = 120.
 
 class HuggingFaceInferenceAPI(LLM):
-    api_key: str
     model: str
 
-    def __init__(self, api_key: str, model: str, system_message: str = None):
-        self.api_key = api_key
+    def __init__(self, model: str, system_message: str = None):
         self.model = model
         self.system_message = system_message  # TODO: Nothing being done with this
 
+    async def start(self):
+        self.api_key = await self.sdk.get_api_key("HUGGING_FACE_TOKEN")
+
     def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
         """Return the completion of the text with the given temperature."""
         API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index a0773c1d..1a48fa86 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -10,24 +10,27 @@ from ...core.config import AzureInfo
 
 class OpenAI(LLM):
-    api_key: str
     default_model: str
 
-    def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None):
-        self.api_key = api_key
+    def __init__(self, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None):
         self.default_model = default_model
         self.system_message = system_message
         self.azure_info = azure_info
         self.write_log = write_log
 
-        openai.api_key = api_key
+    async def start(self):
+        self.api_key = await sdk.get_api_key("OPENAI_API_KEY")
+        openai.api_key = self.api_key
 
         # Using an Azure OpenAI deployment
-        if azure_info is not None:
+        if self.azure_info is not None:
             openai.api_type = "azure"
             openai.api_base = azure_info.endpoint
             openai.api_version = azure_info.api_version
 
+    async def stop(self):
+        pass
+
     @cached_property
     def name(self):
         return self.default_model
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index f9e3fa01..c0e2a403 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -30,6 +30,12 @@ class ProxyServer(LLM):
         self.name = default_model
         self.write_log = write_log
 
+    async def start(self):
+        pass
+
+    async def stop(self):
+        pass
+
     @property
     def default_args(self):
         return {**DEFAULT_ARGS, "model": self.default_model}
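Taken together, this diff replaces the string-valued default_model setting with a full LLM object (llm: LLM) on ContinueConfig, and gives every LLM subclass an async start()/stop() lifecycle in which API keys are fetched through the new ContinueSDK.get_api_key helper instead of being passed to constructors. As a rough sketch only (not part of this commit), a user config under the new scheme might look like the following; the config = ContinueConfig(...) convention and the import paths are assumptions inferred from the files touched above:

    # Hypothetical user config under the new scheme. The ContinueConfig
    # fields come from this diff; the module paths mirror the files above.
    from continuedev.src.continuedev.core.config import ContinueConfig
    from continuedev.src.continuedev.libs.llm.openai import OpenAI

    config = ContinueConfig(
        # A full LLM object replaces the old default_model string.
        llm=OpenAI(default_model="gpt-4"),
        temperature=0.5,
    )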

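The new base-class contract in libs/llm/__init__.py implies that a provider implements start()/stop() rather than receiving its key in __init__. Below is a minimal sketch of a conforming subclass, under the assumption (which AnthropicLLM's start() above also makes, though this commit never shows it) that the framework attaches self.sdk to the instance before calling start(); MyProvider and its environment variable are hypothetical:

    # Minimal sketch of the new LLM lifecycle contract. MyProvider, the
    # MY_PROVIDER_API_KEY variable, and the self.sdk attribute being set
    # by the framework are assumptions, not part of this commit.
    from continuedev.src.continuedev.libs.llm import LLM


    class MyProvider(LLM):
        default_model: str

        def __init__(self, default_model: str, system_message: str = None):
            self.default_model = default_model
            self.system_message = system_message

        async def start(self):
            # The API key is pulled from the IDE at startup rather than
            # passed into __init__.
            self.api_key = await self.sdk.get_api_key("MY_PROVIDER_API_KEY")

        async def stop(self):
            # Close any open clients or sessions here; a no-op in this sketch.
            pass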