author | Nate Sesti <sestinj@gmail.com> | 2023-08-26 10:10:06 -0700
committer | Nate Sesti <sestinj@gmail.com> | 2023-08-26 10:10:06 -0700
commit | 044b7caa6b26a5d78ae52faa0ae675abc8c4e161 (patch)
tree | 5727da03069cc6bf83b9dd385003f615124e6fe8 /continuedev/src
parent | 631d141dbd26edb0de3e0e3ed194dbfd3641059f (diff)
download | sncontinue-044b7caa6b26a5d78ae52faa0ae675abc8c4e161.tar.gz sncontinue-044b7caa6b26a5d78ae52faa0ae675abc8c4e161.tar.bz2 sncontinue-044b7caa6b26a5d78ae52faa0ae675abc8c4e161.zip
feat: :sparkles: select model from dropdown
Diffstat (limited to 'continuedev/src')
-rw-r--r-- | continuedev/src/continuedev/core/models.py | 34
-rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py | 2
-rw-r--r-- | continuedev/src/continuedev/libs/llm/hf_inference_api.py | 4
-rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 4
-rw-r--r-- | continuedev/src/continuedev/libs/llm/together.py | 4
-rw-r--r-- | continuedev/src/continuedev/libs/util/edit_config.py | 43
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/setup_model.py | 24
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 83
8 files changed, 179 insertions, 19 deletions
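
The core of the change is the new `set_model_for_role` handler in `gui.py` at the bottom of this diff: when a model is chosen from the dropdown, the GUI sends the role to fill, the name of the LLM class, and its constructor arguments. A rough sketch of the payload shape, inferred from the handler's signature (the exact websocket envelope is not shown in this commit):

```python
# Hypothetical dropdown payload for set_model_for_role(role, model_class, model);
# all values are illustrative.
payload = {
    "role": "*",  # "*" means: replace the model for every role at once
    "model_class": "OpenAI",  # key into MODEL_CLASSES in core/models.py
    "model": {"api_key": "<API_KEY>", "model": "gpt-4"},  # constructor kwargs
}
```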
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index e4610d36..a8c97622 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -1,8 +1,15 @@
-from typing import Optional
+from typing import List, Optional
 
 from pydantic import BaseModel
 
 from ..libs.llm import LLM
+from ..libs.llm.anthropic import AnthropicLLM
+from ..libs.llm.ggml import GGML
+from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from ..libs.llm.ollama import Ollama
+from ..libs.llm.openai import OpenAI
+from ..libs.llm.replicate import ReplicateLLM
+from ..libs.llm.together import TogetherLLM
 
 
 class ContinueSDK(BaseModel):
@@ -18,6 +25,29 @@ ALL_MODEL_ROLES = [
     "chat",
 ]
 
+MODEL_CLASSES = {
+    cls.__name__: cls
+    for cls in [
+        OpenAI,
+        MaybeProxyOpenAI,
+        GGML,
+        TogetherLLM,
+        AnthropicLLM,
+        ReplicateLLM,
+        Ollama,
+    ]
+}
+
+MODEL_MODULE_NAMES = {
+    "OpenAI": "openai",
+    "MaybeProxyOpenAI": "maybe_proxy_openai",
+    "GGML": "ggml",
+    "TogetherLLM": "together",
+    "AnthropicLLM": "anthropic",
+    "ReplicateLLM": "replicate",
+    "Ollama": "ollama",
+}
+
 
 class Models(BaseModel):
     """Main class that holds the current model configuration"""
@@ -29,6 +59,8 @@ class Models(BaseModel):
     edit: Optional[LLM] = None
     chat: Optional[LLM] = None
 
+    unused: List[LLM] = []
+
     # TODO namespace these away to not confuse readers,
     # or split Models into ModelsConfig, which gets turned into Models
     sdk: ContinueSDK = None
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 34c3ab74..b4548ff2 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -17,7 +17,7 @@ class GGML(LLM):
     # this is model-specific
     max_context_length: int = 2048
     server_url: str = "http://localhost:8000"
-    verify_ssl: bool = True
+    verify_ssl: Optional[bool] = None
 
     requires_write_log = True
 
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 76331a28..7abd268d 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,4 +1,4 @@
-from typing import Any, Coroutine, Dict, Generator, List
+from typing import Any, Coroutine, Dict, Generator, List, Optional
 
 import aiohttp
 import requests
@@ -15,7 +15,7 @@ class HuggingFaceInferenceAPI(LLM):
     hf_token: str
 
     max_context_length: int = 2048
-    verify_ssl: bool = True
+    verify_ssl: Optional[bool] = None
 
     _client_session: aiohttp.ClientSession = None
 
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 276cc290..48e773a3 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -49,7 +49,7 @@ class OpenAI(LLM):
     api_key: str
     model: str
     openai_server_info: Optional[OpenAIServerInfo] = None
-    verify_ssl: bool = True
+    verify_ssl: Optional[bool] = None
     ca_bundle_path: Optional[str] = None
 
     requires_write_log = True
@@ -73,7 +73,7 @@ class OpenAI(LLM):
         if self.openai_server_info.api_version is not None:
             openai.api_version = self.openai_server_info.api_version
 
-        if self.verify_ssl is False:
+        if self.verify_ssl is not None and self.verify_ssl is False:
             openai.verify_ssl_certs = False
 
         openai.ca_bundle_path = self.ca_bundle_path or certifi.where()
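
A note on the `verify_ssl` change repeated across the LLM classes in this commit: moving from `bool = True` to `Optional[bool] = None` distinguishes "unset" from an explicit `False`, which also lets `display_llm_class` (added in `edit_config.py` below) omit the field from the generated config whenever the user never touched it. A minimal stdlib sketch of the three-state check; `make_ssl_context` is a hypothetical helper, not part of this diff:

```python
import ssl
from typing import Optional


def make_ssl_context(verify_ssl: Optional[bool]) -> ssl.SSLContext:
    # None leaves certificate verification at the default (enabled);
    # only an explicit False disables it.
    ctx = ssl.create_default_context()
    if verify_ssl is False:
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
    return ctx
```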
diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py
index 44f5030c..9a28de2d 100644
--- a/continuedev/src/continuedev/libs/llm/together.py
+++ b/continuedev/src/continuedev/libs/llm/together.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
 
 import aiohttp
 
@@ -14,7 +14,7 @@ class TogetherLLM(LLM):
     model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct"
     max_context_length: int = 2048
     base_url: str = "https://api.together.xyz"
-    verify_ssl: bool = True
+    verify_ssl: Optional[bool] = None
 
     _client_session: aiohttp.ClientSession = None
 
diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py
index 17ce27ec..7bdffc8e 100644
--- a/continuedev/src/continuedev/libs/util/edit_config.py
+++ b/continuedev/src/continuedev/libs/util/edit_config.py
@@ -1,5 +1,5 @@
 import threading
-from typing import Dict, List
+from typing import Any, Dict, List
 
 import redbaron
 
@@ -53,6 +53,47 @@ def edit_config_property(key_path: List[str], value: redbaron.RedBaron):
         file.write(red.dumps())
 
 
+def add_config_import(line: str):
+    with edit_lock:
+        red = load_red()
+        # check if the import already exists
+        for node in red:
+            if node.type == "import" and node.dumps() == line:
+                return
+        # if it doesn't exist, add it
+        red.insert(1, line)
+
+        with open(getConfigFilePath(), "w") as file:
+            file.write(red.dumps())
+
+
+filtered_attrs = {
+    "requires_api_key",
+    "requires_unique_id",
+    "requires_write_log",
+    "class_name",
+    "name",
+    "llm",
+}
+
+
+def display_val(v: Any):
+    if isinstance(v, str):
+        return f'"{v}"'
+    return str(v)
+
+
+def display_llm_class(llm):
+    args = ", ".join(
+        [
+            f"{k}={display_val(v)}"
+            for k, v in llm.dict().items()
+            if k not in filtered_attrs and v is not None
+        ]
+    )
+    return f"{llm.__class__.__name__}({args})"
+
+
 def create_obj_node(class_name: str, args: Dict[str, str]) -> redbaron.RedBaron:
     args = [f"{key}={value}" for key, value in args.items()]
     return redbaron.RedBaron(f"{class_name}({', '.join(args)})")[0]
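
`display_llm_class` serializes a live LLM object back into the constructor call that belongs in `config.py`, quoting strings via `display_val` and skipping both internal attributes (`filtered_attrs`) and unset (`None`) fields. Roughly, with a hypothetical pydantic stand-in for an LLM subclass:

```python
from typing import Optional

from pydantic import BaseModel


class FakeLLM(BaseModel):  # hypothetical stand-in, not a real Continue class
    api_key: str = "<API_KEY>"
    model: str = "gpt-4"
    system_message: Optional[str] = None  # None fields are omitted


display_val("gpt-4")  # -> '"gpt-4"' (strings are quoted)
display_val(2048)  # -> '2048'
display_llm_class(FakeLLM())  # -> 'FakeLLM(api_key="<API_KEY>", model="gpt-4")'
```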
diff --git a/continuedev/src/continuedev/plugins/steps/setup_model.py b/continuedev/src/continuedev/plugins/steps/setup_model.py
new file mode 100644
index 00000000..1c50c714
--- /dev/null
+++ b/continuedev/src/continuedev/plugins/steps/setup_model.py
@@ -0,0 +1,24 @@
+from ...core.main import Step
+from ...core.sdk import ContinueSDK
+from ...libs.util.paths import getConfigFilePath
+
+MODEL_CLASS_TO_MESSAGE = {
+    "OpenAI": "Obtain your OpenAI API key from [here](https://platform.openai.com/account/api-keys) and paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.",
+    "MaybeProxyOpenAI": "To get started with OpenAI models, obtain your OpenAI API key from [here](https://platform.openai.com/account/api-keys) and paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.",
+    "AnthropicLLM": "To get started with Anthropic, you first need to sign up for the beta [here](https://claude.ai/login) to obtain an API key. Once you have the key, paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.",
+    "ReplicateLLM": "To get started with Replicate, sign up to obtain an API key [here](https://replicate.ai/), then paste it into the `api_key` field at config.models.default.api_key in `config.py`.",
+    "Ollama": "To get started with Ollama, download the Mac app from [ollama.ai](https://ollama.ai/). Once it is downloaded, be sure to pull at least one model and use its name in the model field in config.py (e.g. `model='codellama'`).",
+    "GGML": "GGML models can be run locally using the `llama-cpp-python` library. To learn how to set up a local llama-cpp-python server, read [here](https://github.com/continuedev/ggml-server-example). Once it is started on port 8000, you're all set!",
+    "TogetherLLM": "To get started using models from Together, first obtain your Together API key from [here](https://together.ai). Paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then, on their models page, press 'start' on the model of your choice and make sure the `model=` parameter in the config file for the `TogetherLLM` class reflects the name of this model. Finally, reload the VS Code window for changes to take effect.",
+}
+
+
+class SetupModelStep(Step):
+    model_class: str
+    name: str = "Setup model in config.py"
+
+    async def run(self, sdk: ContinueSDK):
+        await sdk.ide.setFileOpen(getConfigFilePath())
+        self.description = MODEL_CLASS_TO_MESSAGE.get(
+            self.model_class, "Please finish setting up this model in `config.py`"
+        )
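
`SetupModelStep` opens `config.py` in the IDE and sets its `description` to the provider-specific instructions above, falling back to a generic message for unknown classes. In the `gui.py` diff below it runs only when a model had to be constructed fresh rather than recovered from `unused`; a sketch of the call site:

```python
# Inside an async context with a ContinueSDK instance `sdk`
# (in gui.py the call goes through self.session.autopilot.continue_sdk):
await sdk.run_step(SetupModelStep(model_class="OpenAI"))
```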
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index bdcaad47..55a5f3b4 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -9,16 +9,20 @@ from starlette.websockets import WebSocketDisconnect, WebSocketState
 from uvicorn.main import Server
 
 from ..core.main import ContextItem
+from ..core.models import ALL_MODEL_ROLES, MODEL_CLASSES, MODEL_MODULE_NAMES
 from ..libs.util.create_async_task import create_async_task
 from ..libs.util.edit_config import (
+    add_config_import,
     create_float_node,
     create_obj_node,
     create_string_node,
+    display_llm_class,
 )
 from ..libs.util.logging import logger
 from ..libs.util.queue import AsyncSubscriptionQueue
 from ..libs.util.telemetry import posthog_logger
 from ..plugins.steps.core.core import DisplayErrorStep
+from ..plugins.steps.setup_model import SetupModelStep
 from .gui_protocol import AbstractGUIProtocolServer
 from .session_manager import Session, session_manager
 
@@ -227,16 +231,75 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
         )
 
     def set_model_for_role(self, role: str, model_class: str, model: Any):
-        prev_model = self.session.autopilot.continue_sdk.models.__getattr__(role)
-        if prev_model is not None:
-            prev_model.update(model)
-            self.session.autopilot.continue_sdk.models.__setattr__(role, model)
-        create_async_task(
-            self.session.autopilot.set_config_attr(
-                ["models", role], create_obj_node(model_class, {**model})
-            ),
-            self.on_error,
-        )
+        models = self.session.autopilot.continue_sdk.config.models
+        unused_models = models.unused
+        if role == "*":
+
+            async def async_stuff():
+                # Clear all of the models and store them in unused_models
+                # NOTE: There will be duplicates
+                for role in ALL_MODEL_ROLES:
+                    prev_model = models.__getattribute__(role)
+                    models.__setattr__(role, None)
+                    if prev_model is not None:
+                        exists = False
+                        for other in unused_models:
+                            if display_llm_class(prev_model) == display_llm_class(
+                                other
+                            ):
+                                exists = True
+                                break
+                        if not exists:
+                            unused_models.append(prev_model)
+
+                # Replace default with either new one or existing from unused_models
+                for unused_model in unused_models:
+                    if model_class == unused_model.__class__.__name__ and (
+                        "model" not in model or model["model"] == unused_model.model
+                    ):
+                        models.default = unused_model
+
+                # Set and start the default model if didn't already exist from unused
+                if models.default is None:
+                    models.default = MODEL_CLASSES[model_class](**model)
+                    await self.session.autopilot.continue_sdk.start_model(
+                        models.default
+                    )
+                    await self.session.autopilot.continue_sdk.run_step(
+                        SetupModelStep(model_class=model_class)
+                    )
+
+                models_args = {}
+
+                for role in ALL_MODEL_ROLES:
+                    val = models.__getattribute__(role)
+                    if val is None:
+                        continue  # no pun intended
+
+                    models_args[role] = display_llm_class(val)
+
+                JOINER = ", "
+                models_args[
+                    "unused"
+                ] = f"[{JOINER.join([display_llm_class(llm) for llm in unused_models])}]"
+
+                await self.session.autopilot.set_config_attr(
+                    ["models"],
+                    create_obj_node("Models", models_args),
+                )
+
+                add_config_import(
+                    f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}"
+                )
+
+                for role in ALL_MODEL_ROLES:
+                    if role != "default":
+                        models.__setattr__(role, models.default)
+
+            create_async_task(async_stuff(), self.on_error)
+        else:
+            # TODO
+            pass
 
     def save_context_group(self, title: str, context_items: List[ContextItem]):
         create_async_task(
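
Net effect: picking a model from the dropdown rewrites `config.py` through redbaron. `add_config_import` prepends the import line built from `MODEL_MODULE_NAMES`, and `create_obj_node("Models", models_args)` replaces the `models` attribute with a call holding the new `default` plus everything parked in `unused`. For a fresh OpenAI selection the resulting `config.py` would contain something like this (the surrounding `ContinueConfig` comes from Continue's config format and is not part of this diff; values are illustrative):

```python
from continuedev.src.continuedev.libs.llm.openai import OpenAI

config = ContinueConfig(
    models=Models(
        default=OpenAI(api_key="<API_KEY>", model="gpt-4"),
        unused=[],
    ),
)
```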