-rw-r--r-- | continuedev/src/continuedev/core/models.py               |  34
-rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py             |   2
-rw-r--r-- | continuedev/src/continuedev/libs/llm/hf_inference_api.py |   4
-rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py           |   4
-rw-r--r-- | continuedev/src/continuedev/libs/llm/together.py         |   4
-rw-r--r-- | continuedev/src/continuedev/libs/util/edit_config.py     |  43
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/setup_model.py |  24
-rw-r--r-- | continuedev/src/continuedev/server/gui.py                |  83
-rw-r--r-- | extension/react-app/src/components/Layout.tsx            |  48
-rw-r--r-- | extension/react-app/src/components/ModelSelect.tsx       | 113
10 files changed, 316 insertions, 43 deletions
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index e4610d36..a8c97622 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -1,8 +1,15 @@
-from typing import Optional
+from typing import List, Optional
 
 from pydantic import BaseModel
 
 from ..libs.llm import LLM
+from ..libs.llm.anthropic import AnthropicLLM
+from ..libs.llm.ggml import GGML
+from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from ..libs.llm.ollama import Ollama
+from ..libs.llm.openai import OpenAI
+from ..libs.llm.replicate import ReplicateLLM
+from ..libs.llm.together import TogetherLLM
 
 
 class ContinueSDK(BaseModel):
@@ -18,6 +25,29 @@ ALL_MODEL_ROLES = [
     "chat",
 ]
 
+MODEL_CLASSES = {
+    cls.__name__: cls
+    for cls in [
+        OpenAI,
+        MaybeProxyOpenAI,
+        GGML,
+        TogetherLLM,
+        AnthropicLLM,
+        ReplicateLLM,
+        Ollama,
+    ]
+}
+
+MODEL_MODULE_NAMES = {
+    "OpenAI": "openai",
+    "MaybeProxyOpenAI": "maybe_proxy_openai",
+    "GGML": "ggml",
+    "TogetherLLM": "together",
+    "AnthropicLLM": "anthropic",
+    "ReplicateLLM": "replicate",
+    "Ollama": "ollama",
+}
+
 
 class Models(BaseModel):
     """Main class that holds the current model configuration"""
@@ -29,6 +59,8 @@ class Models(BaseModel):
     edit: Optional[LLM] = None
     chat: Optional[LLM] = None
 
+    unused: List[LLM] = []
+
     # TODO namespace these away to not confuse readers,
     # or split Models into ModelsConfig, which gets turned into Models
     sdk: ContinueSDK = None
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 34c3ab74..b4548ff2 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -17,7 +17,7 @@ class GGML(LLM):
     # this is model-specific
     max_context_length: int = 2048
     server_url: str = "http://localhost:8000"
-    verify_ssl: bool = True
+    verify_ssl: Optional[bool] = None
 
     requires_write_log = True
 
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 76331a28..7abd268d 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,4 +1,4 @@
-from typing import Any, Coroutine, Dict, Generator, List
+from typing import Any, Coroutine, Dict, Generator, List, Optional
 
 import aiohttp
 import requests
@@ -15,7 +15,7 @@ class HuggingFaceInferenceAPI(LLM):
 
     hf_token: str
     max_context_length: int = 2048
-    verify_ssl: bool = True
+    verify_ssl: Optional[bool] = None
 
     _client_session: aiohttp.ClientSession = None
 
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 276cc290..48e773a3 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -49,7 +49,7 @@ class OpenAI(LLM):
     api_key: str
     model: str
     openai_server_info: Optional[OpenAIServerInfo] = None
-    verify_ssl: bool = True
+    verify_ssl: Optional[bool] = None
     ca_bundle_path: Optional[str] = None
 
     requires_write_log = True
@@ -73,7 +73,7 @@ class OpenAI(LLM):
             if self.openai_server_info.api_version is not None:
                 openai.api_version = self.openai_server_info.api_version
 
-        if self.verify_ssl is False:
+        if self.verify_ssl is not None and self.verify_ssl is False:
            openai.verify_ssl_certs = False
 
         openai.ca_bundle_path = self.ca_bundle_path or certifi.where()
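Across ggml.py, hf_inference_api.py, openai.py, and together.py (below), verify_ssl changes from bool = True to Optional[bool] = None, making the flag tri-state: None means "not configured", so only an explicit False disables certificate verification. (Strictly, `x is False` already excludes None, so the added `is not None` guard in openai.py is redundant but harmless.) A minimal sketch of how an aiohttp-based class could consume the tri-state flag; make_client_session is a hypothetical helper, not code from this commit:

from typing import Optional

import aiohttp


def make_client_session(verify_ssl: Optional[bool]) -> aiohttp.ClientSession:
    # Only an explicit False disables certificate verification;
    # None (unset) and True both keep aiohttp's default behavior.
    if verify_ssl is False:
        return aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(verify_ssl=False)
        )
    return aiohttp.ClientSession()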
diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py
index 44f5030c..9a28de2d 100644
--- a/continuedev/src/continuedev/libs/llm/together.py
+++ b/continuedev/src/continuedev/libs/llm/together.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
 
 import aiohttp
 
@@ -14,7 +14,7 @@ class TogetherLLM(LLM):
     model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct"
     max_context_length: int = 2048
     base_url: str = "https://api.together.xyz"
-    verify_ssl: bool = True
+    verify_ssl: Optional[bool] = None
 
     _client_session: aiohttp.ClientSession = None
 
diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py
index 17ce27ec..7bdffc8e 100644
--- a/continuedev/src/continuedev/libs/util/edit_config.py
+++ b/continuedev/src/continuedev/libs/util/edit_config.py
@@ -1,5 +1,5 @@
 import threading
-from typing import Dict, List
+from typing import Any, Dict, List
 
 import redbaron
 
@@ -53,6 +53,47 @@ def edit_config_property(key_path: List[str], value: redbaron.RedBaron):
         file.write(red.dumps())
 
 
+def add_config_import(line: str):
+    with edit_lock:
+        red = load_red()
+        # check if the import already exists
+        for node in red:
+            if node.type == "import" and node.dumps() == line:
+                return
+        # if it doesn't exist, add it
+        red.insert(1, line)
+
+        with open(getConfigFilePath(), "w") as file:
+            file.write(red.dumps())
+
+
+filtered_attrs = {
+    "requires_api_key",
+    "requires_unique_id",
+    "requires_write_log",
+    "class_name",
+    "name",
+    "llm",
+}
+
+
+def display_val(v: Any):
+    if isinstance(v, str):
+        return f'"{v}"'
+    return str(v)
+
+
+def display_llm_class(llm):
+    args = ", ".join(
+        [
+            f"{k}={display_val(v)}"
+            for k, v in llm.dict().items()
+            if k not in filtered_attrs and v is not None
+        ]
+    )
+    return f"{llm.__class__.__name__}({args})"
+
+
 def create_obj_node(class_name: str, args: Dict[str, str]) -> redbaron.RedBaron:
     args = [f"{key}={value}" for key, value in args.items()]
    return redbaron.RedBaron(f"{class_name}({', '.join(args)})")[0]
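A quick worked example of the new edit_config helpers. The OpenAI(...) instantiation assumes only the required fields shown in this diff; any other non-None pydantic fields on the class would be rendered too:

from continuedev.src.continuedev.libs.llm.openai import OpenAI
from continuedev.src.continuedev.libs.util.edit_config import (
    display_llm_class,
    display_val,
)

print(display_val("gpt-4"))  # '"gpt-4"' -- strings are re-quoted
print(display_val(2048))     # '2048'    -- everything else is str()

llm = OpenAI(api_key="sk-...", model="gpt-4")
# Roughly 'OpenAI(api_key="sk-...", model="gpt-4", ...)': keys listed in
# filtered_attrs and None-valued fields are omitted from the rendering.
print(display_llm_class(llm))

display_llm_class matters because its output is pasted verbatim into config.py by gui.py below, so it must round-trip as valid Python source.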
diff --git a/continuedev/src/continuedev/plugins/steps/setup_model.py b/continuedev/src/continuedev/plugins/steps/setup_model.py
new file mode 100644
index 00000000..1c50c714
--- /dev/null
+++ b/continuedev/src/continuedev/plugins/steps/setup_model.py
@@ -0,0 +1,24 @@
+from ...core.main import Step
+from ...core.sdk import ContinueSDK
+from ...libs.util.paths import getConfigFilePath
+
+MODEL_CLASS_TO_MESSAGE = {
+    "OpenAI": "Obtain your OpenAI API key from [here](https://platform.openai.com/account/api-keys) and paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.",
+    "MaybeProxyOpenAI": "To get started with OpenAI models, obtain your OpenAI API key from [here](https://platform.openai.com/account/api-keys) and paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.",
+    "AnthropicLLM": "To get started with Anthropic, you first need to sign up for the beta [here](https://claude.ai/login) to obtain an API key. Once you have the key, paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.",
+    "ReplicateLLM": "To get started with Replicate, sign up to obtain an API key [here](https://replicate.ai/), then paste it into the `api_key` field at config.models.default.api_key in `config.py`.",
+    "Ollama": "To get started with Ollama, download the Mac app from [ollama.ai](https://ollama.ai/). Once it is downloaded, be sure to pull at least one model and use its name in the model field in config.py (e.g. `model='codellama'`).",
+    "GGML": "GGML models can be run locally using the `llama-cpp-python` library. To learn how to set up a local llama-cpp-python server, read [here](https://github.com/continuedev/ggml-server-example). Once it is started on port 8000, you're all set!",
+    "TogetherLLM": "To get started using models from Together, first obtain your Together API key from [here](https://together.ai). Paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then, on their models page, press 'start' on the model of your choice and make sure the `model=` parameter in the config file for the `TogetherLLM` class reflects the name of this model. Finally, reload the VS Code window for changes to take effect.",
+}
+
+
+class SetupModelStep(Step):
+    model_class: str
+    name: str = "Setup model in config.py"
+
+    async def run(self, sdk: ContinueSDK):
+        await sdk.ide.setFileOpen(getConfigFilePath())
+        self.description = MODEL_CLASS_TO_MESSAGE.get(
+            self.model_class, "Please finish setting up this model in `config.py`"
+        )
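The step opens config.py in the IDE and sets a per-provider setup message as its description; gui.py below runs it whenever a brand-new model class is selected. A usage sketch, assuming an async context with a ContinueSDK instance in scope:

from continuedev.src.continuedev.plugins.steps.setup_model import SetupModelStep

# e.g. inside gui.py, where continue_sdk is a ContinueSDK:
await continue_sdk.run_step(SetupModelStep(model_class="AnthropicLLM"))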
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index bdcaad47..55a5f3b4 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -9,16 +9,20 @@ from starlette.websockets import WebSocketDisconnect, WebSocketState
 from uvicorn.main import Server
 
 from ..core.main import ContextItem
+from ..core.models import ALL_MODEL_ROLES, MODEL_CLASSES, MODEL_MODULE_NAMES
 from ..libs.util.create_async_task import create_async_task
 from ..libs.util.edit_config import (
+    add_config_import,
     create_float_node,
     create_obj_node,
     create_string_node,
+    display_llm_class,
 )
 from ..libs.util.logging import logger
 from ..libs.util.queue import AsyncSubscriptionQueue
 from ..libs.util.telemetry import posthog_logger
 from ..plugins.steps.core.core import DisplayErrorStep
+from ..plugins.steps.setup_model import SetupModelStep
 from .gui_protocol import AbstractGUIProtocolServer
 from .session_manager import Session, session_manager
 
@@ -227,16 +231,75 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
         )
 
     def set_model_for_role(self, role: str, model_class: str, model: Any):
-        prev_model = self.session.autopilot.continue_sdk.models.__getattr__(role)
-        if prev_model is not None:
-            prev_model.update(model)
-        self.session.autopilot.continue_sdk.models.__setattr__(role, model)
-        create_async_task(
-            self.session.autopilot.set_config_attr(
-                ["models", role], create_obj_node(model_class, {**model})
-            ),
-            self.on_error,
-        )
+        models = self.session.autopilot.continue_sdk.config.models
+        unused_models = models.unused
+        if role == "*":
+
+            async def async_stuff():
+                # Clear all of the models and store them in unused_models
+                # NOTE: There will be duplicates
+                for role in ALL_MODEL_ROLES:
+                    prev_model = models.__getattribute__(role)
+                    models.__setattr__(role, None)
+                    if prev_model is not None:
+                        exists = False
+                        for other in unused_models:
+                            if display_llm_class(prev_model) == display_llm_class(
+                                other
+                            ):
+                                exists = True
+                                break
+                        if not exists:
+                            unused_models.append(prev_model)
+
+                # Replace default with either new one or existing from unused_models
+                for unused_model in unused_models:
+                    if model_class == unused_model.__class__.__name__ and (
+                        "model" not in model or model["model"] == unused_model.model
+                    ):
+                        models.default = unused_model
+
+                # Set and start the default model if didn't already exist from unused
+                if models.default is None:
+                    models.default = MODEL_CLASSES[model_class](**model)
+                    await self.session.autopilot.continue_sdk.start_model(
+                        models.default
+                    )
+                    await self.session.autopilot.continue_sdk.run_step(
+                        SetupModelStep(model_class=model_class)
+                    )
+
+                models_args = {}
+
+                for role in ALL_MODEL_ROLES:
+                    val = models.__getattribute__(role)
+                    if val is None:
+                        continue  # no pun intended
+
+                    models_args[role] = display_llm_class(val)
+
+                JOINER = ", "
+                models_args[
+                    "unused"
+                ] = f"[{JOINER.join([display_llm_class(llm) for llm in unused_models])}]"
+
+                await self.session.autopilot.set_config_attr(
+                    ["models"],
+                    create_obj_node("Models", models_args),
+                )
+
+                add_config_import(
+                    f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}"
+                )
+
+                for role in ALL_MODEL_ROLES:
+                    if role != "default":
+                        models.__setattr__(role, models.default)
+
+            create_async_task(async_stuff(), self.on_error)
+        else:
+            # TODO
+            pass
 
     def save_context_group(self, title: str, context_items: List[ContextItem]):
         create_async_task(
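The net effect on config.py when a user picks, say, Ollama while gpt-4 was the default: add_config_import prepends the module import derived from MODEL_MODULE_NAMES, and create_obj_node("Models", models_args) rewrites the models= argument of the config object using the display_llm_class strings. A sketch of the resulting source (the previous default shown here is an assumption, not part of this diff):

from continuedev.src.continuedev.core.models import Models
from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
from continuedev.src.continuedev.libs.llm.ollama import Ollama

# Roughly what ends up inside config.py's ContinueConfig(models=...) after the rewrite:
models = Models(
    default=Ollama(model="codellama"),
    unused=[MaybeProxyOpenAI(model="gpt-4")],
)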
"hidden" - : "visible" - } - className="mr-auto cursor-pointer" - onClick={() => { - localStorage.setItem("hideFeature", "true"); - }} - onMouseEnter={() => { - dispatch( - setBottomMessage( - "🎁 New Feature: Use ⌘⇧R automatically debug errors in the terminal" - ) - ); - }} - onMouseLeave={() => { - dispatch(setBottomMessage(undefined)); - }} - width="1.3em" - height="1.3em" - color="yellow" - /> + {localStorage.getItem("hideFeature") === "true" || ( + <SparklesIcon + className="mr-auto cursor-pointer" + onClick={() => { + localStorage.setItem("hideFeature", "true"); + }} + onMouseEnter={() => { + dispatch( + setBottomMessage( + "🎁 New Feature: Use ⌘⇧R automatically debug errors in the terminal (you can click the sparkle icon to make it go away)" + ) + ); + }} + onMouseLeave={() => { + dispatch(setBottomMessage(undefined)); + }} + width="1.3em" + height="1.3em" + color="yellow" + /> + )} + + <ModelSelect /> <HeaderButtonWithText onClick={() => { client?.loadSession(undefined); diff --git a/extension/react-app/src/components/ModelSelect.tsx b/extension/react-app/src/components/ModelSelect.tsx new file mode 100644 index 00000000..49788143 --- /dev/null +++ b/extension/react-app/src/components/ModelSelect.tsx @@ -0,0 +1,113 @@ +import styled from "styled-components"; +import { + defaultBorderRadius, + secondaryDark, + vscBackground, + vscForeground, +} from "."; +import { useContext, useEffect } from "react"; +import { GUIClientContext } from "../App"; +import { RootStore } from "../redux/store"; +import { useSelector } from "react-redux"; + +const MODEL_INFO: { title: string; class: string; args: any }[] = [ + { + title: "gpt-4", + class: "MaybeProxyOpenAI", + args: { + model: "gpt-4", + api_key: "", + }, + }, + { + title: "gpt-3.5-turbo", + class: "MaybeProxyOpenAI", + args: { + model: "gpt-3.5-turbo", + api_key: "", + }, + }, + { + title: "claude-2", + class: "AnthropicLLM", + args: { + model: "claude-2", + api_key: "<ANTHROPIC_API_KEY>", + }, + }, + { + title: "GGML", + class: "GGML", + args: {}, + }, + { + title: "Ollama", + class: "Ollama", + args: { + model: "codellama", + }, + }, + { + title: "Replicate", + class: "ReplicateLLM", + args: { + model: + "replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781", + api_key: "<REPLICATE_API_KEY>", + }, + }, + { + title: "TogetherAI", + class: "TogetherLLM", + args: { + model: "gpt-4", + api_key: "<TOGETHER_API_KEY>", + }, + }, +]; + +const Select = styled.select` + border: none; + width: fit-content; + background-color: ${secondaryDark}; + color: ${vscForeground}; + border-radius: ${defaultBorderRadius}; + padding: 6px; + /* box-shadow: 0px 0px 1px 0px ${vscForeground}; */ + max-height: 35vh; + overflow: scroll; + cursor: pointer; + margin-right: auto; + + &:focus { + outline: none; + } +`; + +function ModelSelect(props: {}) { + const client = useContext(GUIClientContext); + const defaultModel = useSelector( + (state: RootStore) => + (state.serverState.config as any)?.models?.default?.class_name + ); + + return ( + <Select + defaultValue={0} + onChange={(e) => { + const model = MODEL_INFO[parseInt(e.target.value)]; + client?.setModelForRole("*", model.class, model.args); + }} + > + {MODEL_INFO.map((model, idx) => { + return ( + <option selected={defaultModel === model.class} value={idx}> + {model.title} + </option> + ); + })} + </Select> + ); +} + +export default ModelSelect; |