diff options
| author | Nate Sesti <sestinj@gmail.com> | 2023-08-26 10:10:06 -0700 | 
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-08-26 10:10:06 -0700 | 
| commit | 044b7caa6b26a5d78ae52faa0ae675abc8c4e161 (patch) | |
| tree | 5727da03069cc6bf83b9dd385003f615124e6fe8 /continuedev/src | |
| parent | 631d141dbd26edb0de3e0e3ed194dbfd3641059f (diff) | |
| download | sncontinue-044b7caa6b26a5d78ae52faa0ae675abc8c4e161.tar.gz sncontinue-044b7caa6b26a5d78ae52faa0ae675abc8c4e161.tar.bz2 sncontinue-044b7caa6b26a5d78ae52faa0ae675abc8c4e161.zip | |
feat: :sparkles: select model from dropdown
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/core/models.py | 34 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py | 2 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/hf_inference_api.py | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/together.py | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/util/edit_config.py | 43 | ||||
| -rw-r--r-- | continuedev/src/continuedev/plugins/steps/setup_model.py | 24 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/gui.py | 83 | 
8 files changed, 179 insertions, 19 deletions
| diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py index e4610d36..a8c97622 100644 --- a/continuedev/src/continuedev/core/models.py +++ b/continuedev/src/continuedev/core/models.py @@ -1,8 +1,15 @@ -from typing import Optional +from typing import List, Optional  from pydantic import BaseModel  from ..libs.llm import LLM +from ..libs.llm.anthropic import AnthropicLLM +from ..libs.llm.ggml import GGML +from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI +from ..libs.llm.ollama import Ollama +from ..libs.llm.openai import OpenAI +from ..libs.llm.replicate import ReplicateLLM +from ..libs.llm.together import TogetherLLM  class ContinueSDK(BaseModel): @@ -18,6 +25,29 @@ ALL_MODEL_ROLES = [      "chat",  ] +MODEL_CLASSES = { +    cls.__name__: cls +    for cls in [ +        OpenAI, +        MaybeProxyOpenAI, +        GGML, +        TogetherLLM, +        AnthropicLLM, +        ReplicateLLM, +        Ollama, +    ] +} + +MODEL_MODULE_NAMES = { +    "OpenAI": "openai", +    "MaybeProxyOpenAI": "maybe_proxy_openai", +    "GGML": "ggml", +    "TogetherLLM": "together", +    "AnthropicLLM": "anthropic", +    "ReplicateLLM": "replicate", +    "Ollama": "ollama", +} +  class Models(BaseModel):      """Main class that holds the current model configuration""" @@ -29,6 +59,8 @@ class Models(BaseModel):      edit: Optional[LLM] = None      chat: Optional[LLM] = None +    unused: List[LLM] = [] +      # TODO namespace these away to not confuse readers,      # or split Models into ModelsConfig, which gets turned into Models      sdk: ContinueSDK = None diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 34c3ab74..b4548ff2 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -17,7 +17,7 @@ class GGML(LLM):      # this is model-specific      max_context_length: int = 2048      server_url: str = "http://localhost:8000" -    verify_ssl: bool = True +    verify_ssl: Optional[bool] = None      requires_write_log = True diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 76331a28..7abd268d 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -1,4 +1,4 @@ -from typing import Any, Coroutine, Dict, Generator, List +from typing import Any, Coroutine, Dict, Generator, List, Optional  import aiohttp  import requests @@ -15,7 +15,7 @@ class HuggingFaceInferenceAPI(LLM):      hf_token: str      max_context_length: int = 2048 -    verify_ssl: bool = True +    verify_ssl: Optional[bool] = None      _client_session: aiohttp.ClientSession = None diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 276cc290..48e773a3 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -49,7 +49,7 @@ class OpenAI(LLM):      api_key: str      model: str      openai_server_info: Optional[OpenAIServerInfo] = None -    verify_ssl: bool = True +    verify_ssl: Optional[bool] = None      ca_bundle_path: Optional[str] = None      requires_write_log = True @@ -73,7 +73,7 @@ class OpenAI(LLM):              if self.openai_server_info.api_version is not None:                  openai.api_version = self.openai_server_info.api_version -        if self.verify_ssl is False: +        if self.verify_ssl is not None and self.verify_ssl is False:              openai.verify_ssl_certs = False          openai.ca_bundle_path = self.ca_bundle_path or certifi.where() diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py index 44f5030c..9a28de2d 100644 --- a/continuedev/src/continuedev/libs/llm/together.py +++ b/continuedev/src/continuedev/libs/llm/together.py @@ -1,5 +1,5 @@  import json -from typing import Any, Coroutine, Dict, Generator, List, Union +from typing import Any, Coroutine, Dict, Generator, List, Optional, Union  import aiohttp @@ -14,7 +14,7 @@ class TogetherLLM(LLM):      model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct"      max_context_length: int = 2048      base_url: str = "https://api.together.xyz" -    verify_ssl: bool = True +    verify_ssl: Optional[bool] = None      _client_session: aiohttp.ClientSession = None diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py index 17ce27ec..7bdffc8e 100644 --- a/continuedev/src/continuedev/libs/util/edit_config.py +++ b/continuedev/src/continuedev/libs/util/edit_config.py @@ -1,5 +1,5 @@  import threading -from typing import Dict, List +from typing import Any, Dict, List  import redbaron @@ -53,6 +53,47 @@ def edit_config_property(key_path: List[str], value: redbaron.RedBaron):              file.write(red.dumps()) +def add_config_import(line: str): +    with edit_lock: +        red = load_red() +        # check if the import already exists +        for node in red: +            if node.type == "import" and node.dumps() == line: +                return +        # if it doesn't exist, add it +        red.insert(1, line) + +        with open(getConfigFilePath(), "w") as file: +            file.write(red.dumps()) + + +filtered_attrs = { +    "requires_api_key", +    "requires_unique_id", +    "requires_write_log", +    "class_name", +    "name", +    "llm", +} + + +def display_val(v: Any): +    if isinstance(v, str): +        return f'"{v}"' +    return str(v) + + +def display_llm_class(llm): +    args = ", ".join( +        [ +            f"{k}={display_val(v)}" +            for k, v in llm.dict().items() +            if k not in filtered_attrs and v is not None +        ] +    ) +    return f"{llm.__class__.__name__}({args})" + +  def create_obj_node(class_name: str, args: Dict[str, str]) -> redbaron.RedBaron:      args = [f"{key}={value}" for key, value in args.items()]      return redbaron.RedBaron(f"{class_name}({', '.join(args)})")[0] diff --git a/continuedev/src/continuedev/plugins/steps/setup_model.py b/continuedev/src/continuedev/plugins/steps/setup_model.py new file mode 100644 index 00000000..1c50c714 --- /dev/null +++ b/continuedev/src/continuedev/plugins/steps/setup_model.py @@ -0,0 +1,24 @@ +from ...core.main import Step +from ...core.sdk import ContinueSDK +from ...libs.util.paths import getConfigFilePath + +MODEL_CLASS_TO_MESSAGE = { +    "OpenAI": "Obtain your OpenAI API key from [here](https://platform.openai.com/account/api-keys) and paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.", +    "MaybeProxyOpenAI": "To get started with OpenAI models, obtain your OpenAI API key from [here](https://platform.openai.com/account/api-keys) and paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.", +    "AnthropicLLM": "To get started with Anthropic, you first need to sign up for the beta [here](https://claude.ai/login) to obtain an API key. Once you have the key, paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.", +    "ReplicateLLM": "To get started with Replicate, sign up to obtain an API key [here](https://replicate.ai/), then paste it into the `api_key` field at config.models.default.api_key in `config.py`.", +    "Ollama": "To get started with Ollama, download the Mac app from [ollama.ai](https://ollama.ai/). Once it is downloaded, be sure to pull at least one model and use its name in the model field in config.py (e.g. `model='codellama'`).", +    "GGML": "GGML models can be run locally using the `llama-cpp-python` library. To learn how to set up a local llama-cpp-python server, read [here](https://github.com/continuedev/ggml-server-example). Once it is started on port 8000, you're all set!", +    "TogetherLLM": "To get started using models from Together, first obtain your Together API key from [here](https://together.ai). Paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then, on their models page, press 'start' on the model of your choice and make sure the `model=` parameter in the config file for the `TogetherLLM` class reflects the name of this model. Finally, reload the VS Code window for changes to take effect.", +} + + +class SetupModelStep(Step): +    model_class: str +    name: str = "Setup model in config.py" + +    async def run(self, sdk: ContinueSDK): +        await sdk.ide.setFileOpen(getConfigFilePath()) +        self.description = MODEL_CLASS_TO_MESSAGE.get( +            self.model_class, "Please finish setting up this model in `config.py`" +        ) diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index bdcaad47..55a5f3b4 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -9,16 +9,20 @@ from starlette.websockets import WebSocketDisconnect, WebSocketState  from uvicorn.main import Server  from ..core.main import ContextItem +from ..core.models import ALL_MODEL_ROLES, MODEL_CLASSES, MODEL_MODULE_NAMES  from ..libs.util.create_async_task import create_async_task  from ..libs.util.edit_config import ( +    add_config_import,      create_float_node,      create_obj_node,      create_string_node, +    display_llm_class,  )  from ..libs.util.logging import logger  from ..libs.util.queue import AsyncSubscriptionQueue  from ..libs.util.telemetry import posthog_logger  from ..plugins.steps.core.core import DisplayErrorStep +from ..plugins.steps.setup_model import SetupModelStep  from .gui_protocol import AbstractGUIProtocolServer  from .session_manager import Session, session_manager @@ -227,16 +231,75 @@ class GUIProtocolServer(AbstractGUIProtocolServer):          )      def set_model_for_role(self, role: str, model_class: str, model: Any): -        prev_model = self.session.autopilot.continue_sdk.models.__getattr__(role) -        if prev_model is not None: -            prev_model.update(model) -        self.session.autopilot.continue_sdk.models.__setattr__(role, model) -        create_async_task( -            self.session.autopilot.set_config_attr( -                ["models", role], create_obj_node(model_class, {**model}) -            ), -            self.on_error, -        ) +        models = self.session.autopilot.continue_sdk.config.models +        unused_models = models.unused +        if role == "*": + +            async def async_stuff(): +                # Clear all of the models and store them in unused_models +                # NOTE: There will be duplicates +                for role in ALL_MODEL_ROLES: +                    prev_model = models.__getattribute__(role) +                    models.__setattr__(role, None) +                    if prev_model is not None: +                        exists = False +                        for other in unused_models: +                            if display_llm_class(prev_model) == display_llm_class( +                                other +                            ): +                                exists = True +                                break +                        if not exists: +                            unused_models.append(prev_model) + +                # Replace default with either new one or existing from unused_models +                for unused_model in unused_models: +                    if model_class == unused_model.__class__.__name__ and ( +                        "model" not in model or model["model"] == unused_model.model +                    ): +                        models.default = unused_model + +                # Set and start the default model if didn't already exist from unused +                if models.default is None: +                    models.default = MODEL_CLASSES[model_class](**model) +                    await self.session.autopilot.continue_sdk.start_model( +                        models.default +                    ) +                    await self.session.autopilot.continue_sdk.run_step( +                        SetupModelStep(model_class=model_class) +                    ) + +                models_args = {} + +                for role in ALL_MODEL_ROLES: +                    val = models.__getattribute__(role) +                    if val is None: +                        continue  # no pun intended + +                    models_args[role] = display_llm_class(val) + +                JOINER = ", " +                models_args[ +                    "unused" +                ] = f"[{JOINER.join([display_llm_class(llm) for llm in unused_models])}]" + +                await self.session.autopilot.set_config_attr( +                    ["models"], +                    create_obj_node("Models", models_args), +                ) + +                add_config_import( +                    f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}" +                ) + +                for role in ALL_MODEL_ROLES: +                    if role != "default": +                        models.__setattr__(role, models.default) + +            create_async_task(async_stuff(), self.on_error) +        else: +            # TODO +            pass      def save_context_group(self, title: str, context_items: List[ContextItem]):          create_async_task( | 
