diff options
author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-09-29 20:20:45 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-29 20:20:45 -0700 |
commit | 0dfdd4c52a9d686af54346ade35e0bcff226c8b9 (patch) | |
tree | d4f98c7809ddfc7ed14e3be36fe921cc418a8917 /extension/react-app/src/pages | |
parent | 64558321addcc80de9137cf9c9ef1bf7ed85ffa5 (diff) | |
download | sncontinue-0dfdd4c52a9d686af54346ade35e0bcff226c8b9.tar.gz sncontinue-0dfdd4c52a9d686af54346ade35e0bcff226c8b9.tar.bz2 sncontinue-0dfdd4c52a9d686af54346ade35e0bcff226c8b9.zip |
Model config UI (#522)
* feat: :sparkles: improved model selection
* feat: :sparkles: add max_tokens option to LLM class
* docs: :memo: update reference with max_tokens
* feat: :loud_sound: add context to dev data loggign
* feat: :sparkles: final work on model config ui
Diffstat (limited to 'extension/react-app/src/pages')
-rw-r--r-- | extension/react-app/src/pages/modelconfig.tsx | 261 | ||||
-rw-r--r-- | extension/react-app/src/pages/models.tsx | 152 |
2 files changed, 286 insertions, 127 deletions
diff --git a/extension/react-app/src/pages/modelconfig.tsx b/extension/react-app/src/pages/modelconfig.tsx new file mode 100644 index 00000000..97e2d76c --- /dev/null +++ b/extension/react-app/src/pages/modelconfig.tsx @@ -0,0 +1,261 @@ +import React, { useCallback, useContext, useEffect, useState } from "react"; +import ModelCard from "../components/ModelCard"; +import styled from "styled-components"; +import { ArrowLeftIcon } from "@heroicons/react/24/outline"; +import { + TextInput, + defaultBorderRadius, + lightGray, + vscBackground, +} from "../components"; +import { Form, useNavigate } from "react-router-dom"; +import { useDispatch, useSelector } from "react-redux"; +import { GUIClientContext } from "../App"; +import { setShowDialog } from "../redux/slices/uiStateSlice"; +import { useParams } from "react-router-dom"; +import { + MODEL_INFO, + MODEL_PROVIDER_TAG_COLORS, + ModelInfo, +} from "../util/modelData"; +import { RootStore } from "../redux/store"; +import StyledMarkdownPreview from "../components/StyledMarkdownPreview"; +import { getFontSize } from "../util"; +import { FormProvider, useForm } from "react-hook-form"; + +const GridDiv = styled.div` + display: grid; + grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)); + grid-gap: 2rem; + padding: 1rem; + justify-items: center; + align-items: center; +`; + +const CustomModelButton = styled.div<{ disabled: boolean }>` + border: 1px solid ${lightGray}; + border-radius: ${defaultBorderRadius}; + padding: 4px 8px; + display: flex; + justify-content: center; + align-items: center; + width: 100%; + transition: all 0.5s; + + ${(props) => + props.disabled + ? ` + opacity: 0.5; + ` + : ` + &:hover { + border: 1px solid #be1b55; + background-color: #be1b5522; + cursor: pointer; + } + `} +`; + +function ModelConfig() { + const formMethods = useForm(); + const { modelName } = useParams(); + + const [modelInfo, setModelInfo] = useState<ModelInfo | undefined>(undefined); + + useEffect(() => { + if (modelName) { + setModelInfo(MODEL_INFO[modelName]); + } + }, [modelName]); + + const client = useContext(GUIClientContext); + const dispatch = useDispatch(); + const navigate = useNavigate(); + const vscMediaUrl = useSelector( + (state: RootStore) => state.config.vscMediaUrl + ); + + const disableModelCards = useCallback(() => { + return ( + modelInfo?.collectInputFor?.some((d) => { + if (!d.required) return false; + const val = formMethods.watch(d.key); + return ( + typeof val === "undefined" || (typeof val === "string" && val === "") + ); + }) || false + ); + }, [modelInfo, formMethods]); + + return ( + <FormProvider {...formMethods}> + <div className="overflow-y-scroll"> + <div + className="items-center flex m-0 p-0 sticky top-0" + style={{ + borderBottom: `0.5px solid ${lightGray}`, + backgroundColor: vscBackground, + zIndex: 2, + }} + > + <ArrowLeftIcon + width="1.2em" + height="1.2em" + onClick={() => navigate("/models")} + className="inline-block ml-4 cursor-pointer" + /> + <h3 className="text-lg font-bold m-2 inline-block"> + Configure Model + </h3> + </div> + + <div className="px-2"> + <div style={{ display: "flex", alignItems: "center" }}> + {vscMediaUrl && ( + <img + src={`${vscMediaUrl}/logos/${modelInfo?.icon}`} + height="24px" + style={{ marginRight: "10px" }} + /> + )} + <h2>{modelInfo?.title}</h2> + </div> + {modelInfo?.tags?.map((tag) => { + return ( + <span + style={{ + backgroundColor: `${MODEL_PROVIDER_TAG_COLORS[tag]}55`, + color: "white", + padding: "2px 4px", + borderRadius: defaultBorderRadius, + marginRight: "4px", + }} + > + {tag} + </span> + ); + })} + <StyledMarkdownPreview + className="mt-2" + fontSize={getFontSize()} + source={modelInfo?.longDescription || modelInfo?.description || ""} + wrapperElement={{ + "data-color-mode": "dark", + }} + maxHeight={200} + /> + <br /> + + {(modelInfo?.collectInputFor?.filter((d) => d.required).length || 0) > + 0 && ( + <> + <h3 className="mb-2">Enter required parameters</h3> + + {modelInfo?.collectInputFor?.map((d) => { + return ( + <div> + <label htmlFor={d.key}>{d.key}</label> + <TextInput + id={d.key} + className="border-2 border-gray-200 rounded-md p-2 m-2" + placeholder={d.key} + defaultValue={d.defaultValue} + {...formMethods.register(d.key, { + required: true, + })} + /> + </div> + ); + })} + </> + )} + + {(modelInfo?.collectInputFor?.filter((d) => !d.required).length || + 0) > 0 && ( + <details> + <summary className="mb-2"> + <b>Advanced (optional)</b> + </summary> + + {modelInfo?.collectInputFor?.map((d) => { + if (d.required) return null; + return ( + <div> + <label htmlFor={d.key}>{d.key}</label> + <TextInput + id={d.key} + className="border-2 border-gray-200 rounded-md p-2 m-2" + placeholder={d.key} + defaultValue={d.defaultValue} + {...formMethods.register(d.key, { + required: false, + })} + /> + </div> + ); + })} + </details> + )} + + <h3 className="mb-2">Select a model preset</h3> + </div> + <GridDiv> + {modelInfo?.packages.map((pkg) => { + return ( + <ModelCard + disabled={disableModelCards()} + title={pkg.title} + description={pkg.description} + tags={pkg.tags} + refUrl={pkg.refUrl} + icon={pkg.icon || modelInfo.icon} + onClick={(e) => { + if (disableModelCards()) return; + const formParams: any = {}; + for (const d of modelInfo.collectInputFor || []) { + formParams[d.key] = + d.inputType === "text" + ? formMethods.watch(d.key) + : parseInt(formMethods.watch(d.key)); + } + + client?.addModelForRole("*", modelInfo.class, { + ...pkg.params, + ...modelInfo.params, + ...formParams, + }); + navigate("/"); + }} + /> + ); + })} + + <CustomModelButton + disabled={disableModelCards()} + onClick={(e) => { + if (!modelInfo || disableModelCards()) return; + const formParams: any = {}; + for (const d of modelInfo.collectInputFor || []) { + formParams[d.key] = + d.inputType === "text" + ? formMethods.watch(d.key) + : parseInt(formMethods.watch(d.key)); + } + + client?.addModelForRole("*", modelInfo.class, { + ...modelInfo.packages[0]?.params, + ...modelInfo.params, + ...formParams, + }); + navigate("/"); + }} + > + <h3 className="text-center my-2">Configure Model in config.py</h3> + </CustomModelButton> + </GridDiv> + </div> + </FormProvider> + ); +} + +export default ModelConfig; diff --git a/extension/react-app/src/pages/models.tsx b/extension/react-app/src/pages/models.tsx index c20d820c..a9a97a13 100644 --- a/extension/react-app/src/pages/models.tsx +++ b/extension/react-app/src/pages/models.tsx @@ -1,131 +1,13 @@ -import React from "react"; -import ModelCard, { ModelInfo, ModelTag } from "../components/ModelCard"; +import React, { useContext } from "react"; +import ModelCard from "../components/ModelCard"; import styled from "styled-components"; import { ArrowLeftIcon } from "@heroicons/react/24/outline"; import { lightGray, vscBackground } from "../components"; import { useNavigate } from "react-router-dom"; - -const MODEL_INFO: ModelInfo[] = [ - { - title: "OpenAI", - class: "OpenAI", - description: "Use gpt-4, gpt-3.5-turbo, or any other OpenAI model", - args: { - model: "gpt-4", - api_key: "", - title: "OpenAI", - }, - icon: "openai.svg", - tags: [ModelTag["Requires API Key"]], - }, - { - title: "Anthropic", - class: "AnthropicLLM", - description: - "Claude-2 is a highly capable model with a 100k context length", - args: { - model: "claude-2", - api_key: "<ANTHROPIC_API_KEY>", - title: "Anthropic", - }, - icon: "anthropic.png", - tags: [ModelTag["Requires API Key"]], - }, - { - title: "Ollama", - class: "Ollama", - description: - "One of the fastest ways to get started with local models on Mac or Linux", - args: { - model: "codellama", - title: "Ollama", - }, - icon: "ollama.png", - tags: [ModelTag["Local"], ModelTag["Open-Source"]], - }, - { - title: "TogetherAI", - class: "TogetherLLM", - description: - "Use the TogetherAI API for extremely fast streaming of open-source models", - args: { - model: "togethercomputer/CodeLlama-13b-Instruct", - api_key: "<TOGETHER_API_KEY>", - title: "TogetherAI", - }, - icon: "together.png", - tags: [ModelTag["Requires API Key"], ModelTag["Open-Source"]], - }, - { - title: "LM Studio", - class: "GGML", - description: - "One of the fastest ways to get started with local models on Mac or Windows", - args: { - server_url: "http://localhost:1234", - title: "LM Studio", - }, - icon: "lmstudio.png", - tags: [ModelTag["Local"], ModelTag["Open-Source"]], - }, - { - title: "Replicate", - class: "ReplicateLLM", - description: "Use the Replicate API to run open-source models", - args: { - model: - "replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781", - api_key: "<REPLICATE_API_KEY>", - title: "Replicate", - }, - icon: "replicate.png", - tags: [ModelTag["Requires API Key"], ModelTag["Open-Source"]], - }, - { - title: "llama.cpp", - class: "LlamaCpp", - description: "If you are running the llama.cpp server from source", - args: { - title: "llama.cpp", - }, - icon: "llamacpp.png", - tags: [ModelTag.Local, ModelTag["Open-Source"]], - }, - { - title: "HuggingFace TGI", - class: "HuggingFaceTGI", - description: - "HuggingFace Text Generation Inference is an advanced, highly performant option for serving open-source models to multiple people", - args: { - title: "HuggingFace TGI", - }, - icon: "hf.png", - tags: [ModelTag.Local, ModelTag["Open-Source"]], - }, - { - title: "Other OpenAI-compatible API", - class: "GGML", - description: - "If you are using any other OpenAI-compatible API, for example text-gen-webui, FastChat, LocalAI, or llama-cpp-python, you can simply enter your server URL", - args: { - server_url: "<SERVER_URL>", - }, - icon: "openai.svg", - tags: [ModelTag.Local, ModelTag["Open-Source"]], - }, - { - title: "GPT-4 limited free trial", - class: "OpenAIFreeTrial", - description: - "New users can try out Continue with GPT-4 using a proxy server that securely makes calls to OpenAI using our API key", - args: { - model: "gpt-4", - title: "GPT-4 Free Trial", - }, - icon: "openai.svg", - tags: [ModelTag.Free], - }, -]; +import { useDispatch } from "react-redux"; +import { GUIClientContext } from "../App"; +import { setShowDialog } from "../redux/slices/uiStateSlice"; +import { MODEL_INFO } from "../util/modelData"; const GridDiv = styled.div` display: grid; @@ -138,6 +20,8 @@ const GridDiv = styled.div` function Models() { const navigate = useNavigate(); + const client = useContext(GUIClientContext); + const dispatch = useDispatch(); return ( <div className="overflow-y-scroll"> <div @@ -154,11 +38,25 @@ function Models() { onClick={() => navigate("/")} className="inline-block ml-4 cursor-pointer" /> - <h3 className="text-lg font-bold m-2 inline-block">Add a new model</h3> + <h3 className="text-lg font-bold m-2 inline-block"> + Select LLM Provider + </h3> </div> <GridDiv> - {MODEL_INFO.map((model) => ( - <ModelCard modelInfo={model} /> + {Object.entries(MODEL_INFO).map(([name, modelInfo]) => ( + <ModelCard + title={modelInfo.title} + description={modelInfo.description} + tags={modelInfo.tags} + icon={modelInfo.icon} + refUrl={`https://continue.dev/docs/reference/Models/${modelInfo.class.toLowerCase()}`} + onClick={(e) => { + if ((e.target as any).closest("a")) { + return; + } + navigate(`/modelconfig/${name}`); + }} + /> ))} </GridDiv> </div> |