From 0dfdd4c52a9d686af54346ade35e0bcff226c8b9 Mon Sep 17 00:00:00 2001 From: Nate Sesti <33237525+sestinj@users.noreply.github.com> Date: Fri, 29 Sep 2023 20:20:45 -0700 Subject: Model config UI (#522) * feat: :sparkles: improved model selection * feat: :sparkles: add max_tokens option to LLM class * docs: :memo: update reference with max_tokens * feat: :loud_sound: add context to dev data loggign * feat: :sparkles: final work on model config ui --- extension/react-app/src/pages/modelconfig.tsx | 261 ++++++++++++++++++++++++++ 1 file changed, 261 insertions(+) create mode 100644 extension/react-app/src/pages/modelconfig.tsx (limited to 'extension/react-app/src/pages/modelconfig.tsx') diff --git a/extension/react-app/src/pages/modelconfig.tsx b/extension/react-app/src/pages/modelconfig.tsx new file mode 100644 index 00000000..97e2d76c --- /dev/null +++ b/extension/react-app/src/pages/modelconfig.tsx @@ -0,0 +1,261 @@ +import React, { useCallback, useContext, useEffect, useState } from "react"; +import ModelCard from "../components/ModelCard"; +import styled from "styled-components"; +import { ArrowLeftIcon } from "@heroicons/react/24/outline"; +import { + TextInput, + defaultBorderRadius, + lightGray, + vscBackground, +} from "../components"; +import { Form, useNavigate } from "react-router-dom"; +import { useDispatch, useSelector } from "react-redux"; +import { GUIClientContext } from "../App"; +import { setShowDialog } from "../redux/slices/uiStateSlice"; +import { useParams } from "react-router-dom"; +import { + MODEL_INFO, + MODEL_PROVIDER_TAG_COLORS, + ModelInfo, +} from "../util/modelData"; +import { RootStore } from "../redux/store"; +import StyledMarkdownPreview from "../components/StyledMarkdownPreview"; +import { getFontSize } from "../util"; +import { FormProvider, useForm } from "react-hook-form"; + +const GridDiv = styled.div` + display: grid; + grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)); + grid-gap: 2rem; + padding: 1rem; + justify-items: center; + align-items: center; +`; + +const CustomModelButton = styled.div<{ disabled: boolean }>` + border: 1px solid ${lightGray}; + border-radius: ${defaultBorderRadius}; + padding: 4px 8px; + display: flex; + justify-content: center; + align-items: center; + width: 100%; + transition: all 0.5s; + + ${(props) => + props.disabled + ? ` + opacity: 0.5; + ` + : ` + &:hover { + border: 1px solid #be1b55; + background-color: #be1b5522; + cursor: pointer; + } + `} +`; + +function ModelConfig() { + const formMethods = useForm(); + const { modelName } = useParams(); + + const [modelInfo, setModelInfo] = useState(undefined); + + useEffect(() => { + if (modelName) { + setModelInfo(MODEL_INFO[modelName]); + } + }, [modelName]); + + const client = useContext(GUIClientContext); + const dispatch = useDispatch(); + const navigate = useNavigate(); + const vscMediaUrl = useSelector( + (state: RootStore) => state.config.vscMediaUrl + ); + + const disableModelCards = useCallback(() => { + return ( + modelInfo?.collectInputFor?.some((d) => { + if (!d.required) return false; + const val = formMethods.watch(d.key); + return ( + typeof val === "undefined" || (typeof val === "string" && val === "") + ); + }) || false + ); + }, [modelInfo, formMethods]); + + return ( + +
+
+ navigate("/models")} + className="inline-block ml-4 cursor-pointer" + /> +

+ Configure Model +

+
+ +
+
+ {vscMediaUrl && ( + + )} +

{modelInfo?.title}

+
+ {modelInfo?.tags?.map((tag) => { + return ( + + {tag} + + ); + })} + +
+ + {(modelInfo?.collectInputFor?.filter((d) => d.required).length || 0) > + 0 && ( + <> +

Enter required parameters

+ + {modelInfo?.collectInputFor?.map((d) => { + return ( +
+ + +
+ ); + })} + + )} + + {(modelInfo?.collectInputFor?.filter((d) => !d.required).length || + 0) > 0 && ( +
+ + Advanced (optional) + + + {modelInfo?.collectInputFor?.map((d) => { + if (d.required) return null; + return ( +
+ + +
+ ); + })} +
+ )} + +

Select a model preset

+
+ + {modelInfo?.packages.map((pkg) => { + return ( + { + if (disableModelCards()) return; + const formParams: any = {}; + for (const d of modelInfo.collectInputFor || []) { + formParams[d.key] = + d.inputType === "text" + ? formMethods.watch(d.key) + : parseInt(formMethods.watch(d.key)); + } + + client?.addModelForRole("*", modelInfo.class, { + ...pkg.params, + ...modelInfo.params, + ...formParams, + }); + navigate("/"); + }} + /> + ); + })} + + { + if (!modelInfo || disableModelCards()) return; + const formParams: any = {}; + for (const d of modelInfo.collectInputFor || []) { + formParams[d.key] = + d.inputType === "text" + ? formMethods.watch(d.key) + : parseInt(formMethods.watch(d.key)); + } + + client?.addModelForRole("*", modelInfo.class, { + ...modelInfo.packages[0]?.params, + ...modelInfo.params, + ...formParams, + }); + navigate("/"); + }} + > +

Configure Model in config.py

+
+
+
+
+ ); +} + +export default ModelConfig; -- cgit v1.2.3-70-g09d2