diff options
Diffstat (limited to 'extension/react-app/src/components/ModelSelect.tsx')
-rw-r--r-- | extension/react-app/src/components/ModelSelect.tsx | 164 |
1 files changed, 110 insertions, 54 deletions
diff --git a/extension/react-app/src/components/ModelSelect.tsx b/extension/react-app/src/components/ModelSelect.tsx index 317c164a..dc58da9e 100644 --- a/extension/react-app/src/components/ModelSelect.tsx +++ b/extension/react-app/src/components/ModelSelect.tsx @@ -1,18 +1,21 @@ import styled from "styled-components"; import { defaultBorderRadius, + lightGray, secondaryDark, vscBackground, vscForeground, } from "."; -import { useContext, useEffect } from "react"; +import { useContext } from "react"; import { GUIClientContext } from "../App"; import { RootStore } from "../redux/store"; -import { useSelector } from "react-redux"; +import { useDispatch, useSelector } from "react-redux"; +import { PlusIcon } from "@heroicons/react/24/outline"; +import { setDialogMessage, setShowDialog } from "../redux/slices/uiStateSlice"; const MODEL_INFO: { title: string; class: string; args: any }[] = [ { - title: "gpt-4", + title: "OpenAI", class: "OpenAI", args: { model: "gpt-4", @@ -20,7 +23,7 @@ const MODEL_INFO: { title: string; class: string; args: any }[] = [ }, }, { - title: "claude-2", + title: "Anthropic", class: "AnthropicLLM", args: { model: "claude-2", @@ -28,11 +31,6 @@ const MODEL_INFO: { title: string; class: string; args: any }[] = [ }, }, { - title: "GGML", - class: "GGML", - args: {}, - }, - { title: "Ollama", class: "Ollama", args: { @@ -40,6 +38,14 @@ const MODEL_INFO: { title: string; class: string; args: any }[] = [ }, }, { + title: "TogetherAI", + class: "TogetherLLM", + args: { + model: "togethercomputer/CodeLlama-13b-Instruct", + api_key: "<TOGETHER_API_KEY>", + }, + }, + { title: "Replicate", class: "ReplicateLLM", args: { @@ -49,20 +55,19 @@ const MODEL_INFO: { title: string; class: string; args: any }[] = [ }, }, { - title: "TogetherAI", - class: "TogetherLLM", - args: { - model: "togethercomputer/CodeLlama-13b-Instruct", - api_key: "<TOGETHER_API_KEY>", - }, - }, - { title: "llama.cpp", class: "LlamaCpp", args: {}, }, { - title: "gpt-4 (limited free usage)", + title: "Other OpenAI-compatible API", + class: "GGML", + args: { + server_url: "<SERVER_URL>", + }, + }, + { + title: "Continue Free Trial (gpt-4)", class: "MaybeProxyOpenAI", args: { model: "gpt-4", @@ -70,21 +75,56 @@ const MODEL_INFO: { title: string; class: string; args: any }[] = [ }, ]; +const GridDiv = styled.div` + display: grid; + grid-template-columns: 1fr auto; + align-items: center; + border: 0.5px solid ${lightGray}; + border-radius: ${defaultBorderRadius}; + overflow: hidden; +`; + const Select = styled.select` border: none; - width: 25vw; - background-color: ${secondaryDark}; + max-width: 25vw; + background-color: ${vscBackground}; color: ${vscForeground}; - border-radius: ${defaultBorderRadius}; padding: 6px; max-height: 35vh; overflow: scroll; cursor: pointer; - margin-right: auto; &:focus { outline: none; } + &:hover { + background-color: ${secondaryDark}; + } +`; + +const StyledPlusIcon = styled(PlusIcon)` + cursor: pointer; + margin: 0px; + padding-left: 4px; + padding-right: 4px; + height: 100%; + + &:hover { + background-color: ${secondaryDark}; + } + border-left: 0.5px solid ${lightGray}; +`; + +const NewProviderDiv = styled.div` + cursor: pointer; + padding: 8px; + padding-left: 16px; + padding-right: 16px; + border-top: 0.5px solid ${lightGray}; + + &:hover { + background-color: ${secondaryDark}; + } `; function modelSelectTitle(model: any): string { @@ -99,6 +139,7 @@ function modelSelectTitle(model: any): string { } function ModelSelect(props: {}) { + const dispatch = useDispatch(); const client = useContext(GUIClientContext); const defaultModel = useSelector( (state: RootStore) => (state.serverState.config as any)?.models?.default @@ -108,23 +149,20 @@ function ModelSelect(props: {}) { ); return ( - <Select - value={JSON.stringify({ - t: "default", - idx: -1, - })} - defaultValue={0} - onChange={(e) => { - const value = JSON.parse(e.target.value); - if (value.t === "unused") { - client?.setModelForRoleFromIndex("*", value.idx); - } else if (value.t === "new") { - const model = MODEL_INFO[value.idx]; - client?.addModelForRole("*", model.class, model.args); - } - }} - > - <optgroup label="My Saved Models"> + <GridDiv> + <Select + value={JSON.stringify({ + t: "default", + idx: -1, + })} + defaultValue={0} + onChange={(e) => { + const value = JSON.parse(e.target.value); + if (value.t === "unused") { + client?.setModelForRoleFromIndex("*", value.idx); + } + }} + > {defaultModel && ( <option value={JSON.stringify({ @@ -147,22 +185,40 @@ function ModelSelect(props: {}) { </option> ); })} - </optgroup> - <optgroup label="Add New Model"> - {MODEL_INFO.map((model, idx) => { - return ( - <option - value={JSON.stringify({ - t: "new", - idx, - })} - > - {model.title} - </option> + </Select> + + <StyledPlusIcon + width="1.3em" + height="1.3em" + onClick={() => { + dispatch( + setDialogMessage( + <div> + <div className="text-lg font-bold p-2"> + Setup a new model provider + </div> + <br /> + {MODEL_INFO.map((model, idx) => { + return ( + <NewProviderDiv + onClick={() => { + const model = MODEL_INFO[idx]; + client?.addModelForRole("*", model.class, model.args); + dispatch(setShowDialog(false)); + }} + > + {model.title} + </NewProviderDiv> + ); + })} + <br /> + </div> + ) ); - })} - </optgroup> - </Select> + dispatch(setShowDialog(true)); + }} + /> + </GridDiv> ); } |