From 0dfdd4c52a9d686af54346ade35e0bcff226c8b9 Mon Sep 17 00:00:00 2001 From: Nate Sesti <33237525+sestinj@users.noreply.github.com> Date: Fri, 29 Sep 2023 20:20:45 -0700 Subject: Model config UI (#522) * feat: :sparkles: improved model selection * feat: :sparkles: add max_tokens option to LLM class * docs: :memo: update reference with max_tokens * feat: :loud_sound: add context to dev data loggign * feat: :sparkles: final work on model config ui --- extension/react-app/src/pages/modelconfig.tsx | 261 ++++++++++++++++++++++++++ extension/react-app/src/pages/models.tsx | 152 +++------------ 2 files changed, 286 insertions(+), 127 deletions(-) create mode 100644 extension/react-app/src/pages/modelconfig.tsx (limited to 'extension/react-app/src/pages') diff --git a/extension/react-app/src/pages/modelconfig.tsx b/extension/react-app/src/pages/modelconfig.tsx new file mode 100644 index 00000000..97e2d76c --- /dev/null +++ b/extension/react-app/src/pages/modelconfig.tsx @@ -0,0 +1,261 @@ +import React, { useCallback, useContext, useEffect, useState } from "react"; +import ModelCard from "../components/ModelCard"; +import styled from "styled-components"; +import { ArrowLeftIcon } from "@heroicons/react/24/outline"; +import { + TextInput, + defaultBorderRadius, + lightGray, + vscBackground, +} from "../components"; +import { Form, useNavigate } from "react-router-dom"; +import { useDispatch, useSelector } from "react-redux"; +import { GUIClientContext } from "../App"; +import { setShowDialog } from "../redux/slices/uiStateSlice"; +import { useParams } from "react-router-dom"; +import { + MODEL_INFO, + MODEL_PROVIDER_TAG_COLORS, + ModelInfo, +} from "../util/modelData"; +import { RootStore } from "../redux/store"; +import StyledMarkdownPreview from "../components/StyledMarkdownPreview"; +import { getFontSize } from "../util"; +import { FormProvider, useForm } from "react-hook-form"; + +const GridDiv = styled.div` + display: grid; + grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)); + grid-gap: 2rem; + padding: 1rem; + justify-items: center; + align-items: center; +`; + +const CustomModelButton = styled.div<{ disabled: boolean }>` + border: 1px solid ${lightGray}; + border-radius: ${defaultBorderRadius}; + padding: 4px 8px; + display: flex; + justify-content: center; + align-items: center; + width: 100%; + transition: all 0.5s; + + ${(props) => + props.disabled + ? ` + opacity: 0.5; + ` + : ` + &:hover { + border: 1px solid #be1b55; + background-color: #be1b5522; + cursor: pointer; + } + `} +`; + +function ModelConfig() { + const formMethods = useForm(); + const { modelName } = useParams(); + + const [modelInfo, setModelInfo] = useState(undefined); + + useEffect(() => { + if (modelName) { + setModelInfo(MODEL_INFO[modelName]); + } + }, [modelName]); + + const client = useContext(GUIClientContext); + const dispatch = useDispatch(); + const navigate = useNavigate(); + const vscMediaUrl = useSelector( + (state: RootStore) => state.config.vscMediaUrl + ); + + const disableModelCards = useCallback(() => { + return ( + modelInfo?.collectInputFor?.some((d) => { + if (!d.required) return false; + const val = formMethods.watch(d.key); + return ( + typeof val === "undefined" || (typeof val === "string" && val === "") + ); + }) || false + ); + }, [modelInfo, formMethods]); + + return ( + +
+
+ navigate("/models")} + className="inline-block ml-4 cursor-pointer" + /> +

+ Configure Model +

+
+ +
+
+ {vscMediaUrl && ( + + )} +

{modelInfo?.title}

+
+ {modelInfo?.tags?.map((tag) => { + return ( + + {tag} + + ); + })} + +
+ + {(modelInfo?.collectInputFor?.filter((d) => d.required).length || 0) > + 0 && ( + <> +

Enter required parameters

+ + {modelInfo?.collectInputFor?.map((d) => { + return ( +
+ + +
+ ); + })} + + )} + + {(modelInfo?.collectInputFor?.filter((d) => !d.required).length || + 0) > 0 && ( +
+ + Advanced (optional) + + + {modelInfo?.collectInputFor?.map((d) => { + if (d.required) return null; + return ( +
+ + +
+ ); + })} +
+ )} + +

Select a model preset

+
+ + {modelInfo?.packages.map((pkg) => { + return ( + { + if (disableModelCards()) return; + const formParams: any = {}; + for (const d of modelInfo.collectInputFor || []) { + formParams[d.key] = + d.inputType === "text" + ? formMethods.watch(d.key) + : parseInt(formMethods.watch(d.key)); + } + + client?.addModelForRole("*", modelInfo.class, { + ...pkg.params, + ...modelInfo.params, + ...formParams, + }); + navigate("/"); + }} + /> + ); + })} + + { + if (!modelInfo || disableModelCards()) return; + const formParams: any = {}; + for (const d of modelInfo.collectInputFor || []) { + formParams[d.key] = + d.inputType === "text" + ? formMethods.watch(d.key) + : parseInt(formMethods.watch(d.key)); + } + + client?.addModelForRole("*", modelInfo.class, { + ...modelInfo.packages[0]?.params, + ...modelInfo.params, + ...formParams, + }); + navigate("/"); + }} + > +

Configure Model in config.py

+
+
+
+
+ ); +} + +export default ModelConfig; diff --git a/extension/react-app/src/pages/models.tsx b/extension/react-app/src/pages/models.tsx index c20d820c..a9a97a13 100644 --- a/extension/react-app/src/pages/models.tsx +++ b/extension/react-app/src/pages/models.tsx @@ -1,131 +1,13 @@ -import React from "react"; -import ModelCard, { ModelInfo, ModelTag } from "../components/ModelCard"; +import React, { useContext } from "react"; +import ModelCard from "../components/ModelCard"; import styled from "styled-components"; import { ArrowLeftIcon } from "@heroicons/react/24/outline"; import { lightGray, vscBackground } from "../components"; import { useNavigate } from "react-router-dom"; - -const MODEL_INFO: ModelInfo[] = [ - { - title: "OpenAI", - class: "OpenAI", - description: "Use gpt-4, gpt-3.5-turbo, or any other OpenAI model", - args: { - model: "gpt-4", - api_key: "", - title: "OpenAI", - }, - icon: "openai.svg", - tags: [ModelTag["Requires API Key"]], - }, - { - title: "Anthropic", - class: "AnthropicLLM", - description: - "Claude-2 is a highly capable model with a 100k context length", - args: { - model: "claude-2", - api_key: "", - title: "Anthropic", - }, - icon: "anthropic.png", - tags: [ModelTag["Requires API Key"]], - }, - { - title: "Ollama", - class: "Ollama", - description: - "One of the fastest ways to get started with local models on Mac or Linux", - args: { - model: "codellama", - title: "Ollama", - }, - icon: "ollama.png", - tags: [ModelTag["Local"], ModelTag["Open-Source"]], - }, - { - title: "TogetherAI", - class: "TogetherLLM", - description: - "Use the TogetherAI API for extremely fast streaming of open-source models", - args: { - model: "togethercomputer/CodeLlama-13b-Instruct", - api_key: "", - title: "TogetherAI", - }, - icon: "together.png", - tags: [ModelTag["Requires API Key"], ModelTag["Open-Source"]], - }, - { - title: "LM Studio", - class: "GGML", - description: - "One of the fastest ways to get started with local models on Mac or Windows", - args: { - server_url: "http://localhost:1234", - title: "LM Studio", - }, - icon: "lmstudio.png", - tags: [ModelTag["Local"], ModelTag["Open-Source"]], - }, - { - title: "Replicate", - class: "ReplicateLLM", - description: "Use the Replicate API to run open-source models", - args: { - model: - "replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781", - api_key: "", - title: "Replicate", - }, - icon: "replicate.png", - tags: [ModelTag["Requires API Key"], ModelTag["Open-Source"]], - }, - { - title: "llama.cpp", - class: "LlamaCpp", - description: "If you are running the llama.cpp server from source", - args: { - title: "llama.cpp", - }, - icon: "llamacpp.png", - tags: [ModelTag.Local, ModelTag["Open-Source"]], - }, - { - title: "HuggingFace TGI", - class: "HuggingFaceTGI", - description: - "HuggingFace Text Generation Inference is an advanced, highly performant option for serving open-source models to multiple people", - args: { - title: "HuggingFace TGI", - }, - icon: "hf.png", - tags: [ModelTag.Local, ModelTag["Open-Source"]], - }, - { - title: "Other OpenAI-compatible API", - class: "GGML", - description: - "If you are using any other OpenAI-compatible API, for example text-gen-webui, FastChat, LocalAI, or llama-cpp-python, you can simply enter your server URL", - args: { - server_url: "", - }, - icon: "openai.svg", - tags: [ModelTag.Local, ModelTag["Open-Source"]], - }, - { - title: "GPT-4 limited free trial", - class: "OpenAIFreeTrial", - description: - "New users can try out Continue with GPT-4 using a proxy server that securely makes calls to OpenAI using our API key", - args: { - model: "gpt-4", - title: "GPT-4 Free Trial", - }, - icon: "openai.svg", - tags: [ModelTag.Free], - }, -]; +import { useDispatch } from "react-redux"; +import { GUIClientContext } from "../App"; +import { setShowDialog } from "../redux/slices/uiStateSlice"; +import { MODEL_INFO } from "../util/modelData"; const GridDiv = styled.div` display: grid; @@ -138,6 +20,8 @@ const GridDiv = styled.div` function Models() { const navigate = useNavigate(); + const client = useContext(GUIClientContext); + const dispatch = useDispatch(); return (
navigate("/")} className="inline-block ml-4 cursor-pointer" /> -

Add a new model

+

+ Select LLM Provider +

- {MODEL_INFO.map((model) => ( - + {Object.entries(MODEL_INFO).map(([name, modelInfo]) => ( + { + if ((e.target as any).closest("a")) { + return; + } + navigate(`/modelconfig/${name}`); + }} + /> ))}
-- cgit v1.2.3-70-g09d2