summaryrefslogtreecommitdiff
path: root/extension/react-app/src/pages
diff options
context:
space:
mode:
authorNate Sesti <33237525+sestinj@users.noreply.github.com>2023-09-29 20:20:45 -0700
committerGitHub <noreply@github.com>2023-09-29 20:20:45 -0700
commit0dfdd4c52a9d686af54346ade35e0bcff226c8b9 (patch)
treed4f98c7809ddfc7ed14e3be36fe921cc418a8917 /extension/react-app/src/pages
parent64558321addcc80de9137cf9c9ef1bf7ed85ffa5 (diff)
downloadsncontinue-0dfdd4c52a9d686af54346ade35e0bcff226c8b9.tar.gz
sncontinue-0dfdd4c52a9d686af54346ade35e0bcff226c8b9.tar.bz2
sncontinue-0dfdd4c52a9d686af54346ade35e0bcff226c8b9.zip
Model config UI (#522)
* feat: :sparkles: improved model selection * feat: :sparkles: add max_tokens option to LLM class * docs: :memo: update reference with max_tokens * feat: :loud_sound: add context to dev data loggign * feat: :sparkles: final work on model config ui
Diffstat (limited to 'extension/react-app/src/pages')
-rw-r--r--extension/react-app/src/pages/modelconfig.tsx261
-rw-r--r--extension/react-app/src/pages/models.tsx152
2 files changed, 286 insertions, 127 deletions
diff --git a/extension/react-app/src/pages/modelconfig.tsx b/extension/react-app/src/pages/modelconfig.tsx
new file mode 100644
index 00000000..97e2d76c
--- /dev/null
+++ b/extension/react-app/src/pages/modelconfig.tsx
@@ -0,0 +1,261 @@
+import React, { useCallback, useContext, useEffect, useState } from "react";
+import ModelCard from "../components/ModelCard";
+import styled from "styled-components";
+import { ArrowLeftIcon } from "@heroicons/react/24/outline";
+import {
+ TextInput,
+ defaultBorderRadius,
+ lightGray,
+ vscBackground,
+} from "../components";
+import { Form, useNavigate } from "react-router-dom";
+import { useDispatch, useSelector } from "react-redux";
+import { GUIClientContext } from "../App";
+import { setShowDialog } from "../redux/slices/uiStateSlice";
+import { useParams } from "react-router-dom";
+import {
+ MODEL_INFO,
+ MODEL_PROVIDER_TAG_COLORS,
+ ModelInfo,
+} from "../util/modelData";
+import { RootStore } from "../redux/store";
+import StyledMarkdownPreview from "../components/StyledMarkdownPreview";
+import { getFontSize } from "../util";
+import { FormProvider, useForm } from "react-hook-form";
+
+const GridDiv = styled.div`
+ display: grid;
+ grid-template-columns: repeat(auto-fill, minmax(300px, 1fr));
+ grid-gap: 2rem;
+ padding: 1rem;
+ justify-items: center;
+ align-items: center;
+`;
+
+const CustomModelButton = styled.div<{ disabled: boolean }>`
+ border: 1px solid ${lightGray};
+ border-radius: ${defaultBorderRadius};
+ padding: 4px 8px;
+ display: flex;
+ justify-content: center;
+ align-items: center;
+ width: 100%;
+ transition: all 0.5s;
+
+ ${(props) =>
+ props.disabled
+ ? `
+ opacity: 0.5;
+ `
+ : `
+ &:hover {
+ border: 1px solid #be1b55;
+ background-color: #be1b5522;
+ cursor: pointer;
+ }
+ `}
+`;
+
+function ModelConfig() {
+ const formMethods = useForm();
+ const { modelName } = useParams();
+
+ const [modelInfo, setModelInfo] = useState<ModelInfo | undefined>(undefined);
+
+ useEffect(() => {
+ if (modelName) {
+ setModelInfo(MODEL_INFO[modelName]);
+ }
+ }, [modelName]);
+
+ const client = useContext(GUIClientContext);
+ const dispatch = useDispatch();
+ const navigate = useNavigate();
+ const vscMediaUrl = useSelector(
+ (state: RootStore) => state.config.vscMediaUrl
+ );
+
+ const disableModelCards = useCallback(() => {
+ return (
+ modelInfo?.collectInputFor?.some((d) => {
+ if (!d.required) return false;
+ const val = formMethods.watch(d.key);
+ return (
+ typeof val === "undefined" || (typeof val === "string" && val === "")
+ );
+ }) || false
+ );
+ }, [modelInfo, formMethods]);
+
+ return (
+ <FormProvider {...formMethods}>
+ <div className="overflow-y-scroll">
+ <div
+ className="items-center flex m-0 p-0 sticky top-0"
+ style={{
+ borderBottom: `0.5px solid ${lightGray}`,
+ backgroundColor: vscBackground,
+ zIndex: 2,
+ }}
+ >
+ <ArrowLeftIcon
+ width="1.2em"
+ height="1.2em"
+ onClick={() => navigate("/models")}
+ className="inline-block ml-4 cursor-pointer"
+ />
+ <h3 className="text-lg font-bold m-2 inline-block">
+ Configure Model
+ </h3>
+ </div>
+
+ <div className="px-2">
+ <div style={{ display: "flex", alignItems: "center" }}>
+ {vscMediaUrl && (
+ <img
+ src={`${vscMediaUrl}/logos/${modelInfo?.icon}`}
+ height="24px"
+ style={{ marginRight: "10px" }}
+ />
+ )}
+ <h2>{modelInfo?.title}</h2>
+ </div>
+ {modelInfo?.tags?.map((tag) => {
+ return (
+ <span
+ style={{
+ backgroundColor: `${MODEL_PROVIDER_TAG_COLORS[tag]}55`,
+ color: "white",
+ padding: "2px 4px",
+ borderRadius: defaultBorderRadius,
+ marginRight: "4px",
+ }}
+ >
+ {tag}
+ </span>
+ );
+ })}
+ <StyledMarkdownPreview
+ className="mt-2"
+ fontSize={getFontSize()}
+ source={modelInfo?.longDescription || modelInfo?.description || ""}
+ wrapperElement={{
+ "data-color-mode": "dark",
+ }}
+ maxHeight={200}
+ />
+ <br />
+
+ {(modelInfo?.collectInputFor?.filter((d) => d.required).length || 0) >
+ 0 && (
+ <>
+ <h3 className="mb-2">Enter required parameters</h3>
+
+ {modelInfo?.collectInputFor?.map((d) => {
+ return (
+ <div>
+ <label htmlFor={d.key}>{d.key}</label>
+ <TextInput
+ id={d.key}
+ className="border-2 border-gray-200 rounded-md p-2 m-2"
+ placeholder={d.key}
+ defaultValue={d.defaultValue}
+ {...formMethods.register(d.key, {
+ required: true,
+ })}
+ />
+ </div>
+ );
+ })}
+ </>
+ )}
+
+ {(modelInfo?.collectInputFor?.filter((d) => !d.required).length ||
+ 0) > 0 && (
+ <details>
+ <summary className="mb-2">
+ <b>Advanced (optional)</b>
+ </summary>
+
+ {modelInfo?.collectInputFor?.map((d) => {
+ if (d.required) return null;
+ return (
+ <div>
+ <label htmlFor={d.key}>{d.key}</label>
+ <TextInput
+ id={d.key}
+ className="border-2 border-gray-200 rounded-md p-2 m-2"
+ placeholder={d.key}
+ defaultValue={d.defaultValue}
+ {...formMethods.register(d.key, {
+ required: false,
+ })}
+ />
+ </div>
+ );
+ })}
+ </details>
+ )}
+
+ <h3 className="mb-2">Select a model preset</h3>
+ </div>
+ <GridDiv>
+ {modelInfo?.packages.map((pkg) => {
+ return (
+ <ModelCard
+ disabled={disableModelCards()}
+ title={pkg.title}
+ description={pkg.description}
+ tags={pkg.tags}
+ refUrl={pkg.refUrl}
+ icon={pkg.icon || modelInfo.icon}
+ onClick={(e) => {
+ if (disableModelCards()) return;
+ const formParams: any = {};
+ for (const d of modelInfo.collectInputFor || []) {
+ formParams[d.key] =
+ d.inputType === "text"
+ ? formMethods.watch(d.key)
+ : parseInt(formMethods.watch(d.key));
+ }
+
+ client?.addModelForRole("*", modelInfo.class, {
+ ...pkg.params,
+ ...modelInfo.params,
+ ...formParams,
+ });
+ navigate("/");
+ }}
+ />
+ );
+ })}
+
+ <CustomModelButton
+ disabled={disableModelCards()}
+ onClick={(e) => {
+ if (!modelInfo || disableModelCards()) return;
+ const formParams: any = {};
+ for (const d of modelInfo.collectInputFor || []) {
+ formParams[d.key] =
+ d.inputType === "text"
+ ? formMethods.watch(d.key)
+ : parseInt(formMethods.watch(d.key));
+ }
+
+ client?.addModelForRole("*", modelInfo.class, {
+ ...modelInfo.packages[0]?.params,
+ ...modelInfo.params,
+ ...formParams,
+ });
+ navigate("/");
+ }}
+ >
+ <h3 className="text-center my-2">Configure Model in config.py</h3>
+ </CustomModelButton>
+ </GridDiv>
+ </div>
+ </FormProvider>
+ );
+}
+
+export default ModelConfig;
diff --git a/extension/react-app/src/pages/models.tsx b/extension/react-app/src/pages/models.tsx
index c20d820c..a9a97a13 100644
--- a/extension/react-app/src/pages/models.tsx
+++ b/extension/react-app/src/pages/models.tsx
@@ -1,131 +1,13 @@
-import React from "react";
-import ModelCard, { ModelInfo, ModelTag } from "../components/ModelCard";
+import React, { useContext } from "react";
+import ModelCard from "../components/ModelCard";
import styled from "styled-components";
import { ArrowLeftIcon } from "@heroicons/react/24/outline";
import { lightGray, vscBackground } from "../components";
import { useNavigate } from "react-router-dom";
-
-const MODEL_INFO: ModelInfo[] = [
- {
- title: "OpenAI",
- class: "OpenAI",
- description: "Use gpt-4, gpt-3.5-turbo, or any other OpenAI model",
- args: {
- model: "gpt-4",
- api_key: "",
- title: "OpenAI",
- },
- icon: "openai.svg",
- tags: [ModelTag["Requires API Key"]],
- },
- {
- title: "Anthropic",
- class: "AnthropicLLM",
- description:
- "Claude-2 is a highly capable model with a 100k context length",
- args: {
- model: "claude-2",
- api_key: "<ANTHROPIC_API_KEY>",
- title: "Anthropic",
- },
- icon: "anthropic.png",
- tags: [ModelTag["Requires API Key"]],
- },
- {
- title: "Ollama",
- class: "Ollama",
- description:
- "One of the fastest ways to get started with local models on Mac or Linux",
- args: {
- model: "codellama",
- title: "Ollama",
- },
- icon: "ollama.png",
- tags: [ModelTag["Local"], ModelTag["Open-Source"]],
- },
- {
- title: "TogetherAI",
- class: "TogetherLLM",
- description:
- "Use the TogetherAI API for extremely fast streaming of open-source models",
- args: {
- model: "togethercomputer/CodeLlama-13b-Instruct",
- api_key: "<TOGETHER_API_KEY>",
- title: "TogetherAI",
- },
- icon: "together.png",
- tags: [ModelTag["Requires API Key"], ModelTag["Open-Source"]],
- },
- {
- title: "LM Studio",
- class: "GGML",
- description:
- "One of the fastest ways to get started with local models on Mac or Windows",
- args: {
- server_url: "http://localhost:1234",
- title: "LM Studio",
- },
- icon: "lmstudio.png",
- tags: [ModelTag["Local"], ModelTag["Open-Source"]],
- },
- {
- title: "Replicate",
- class: "ReplicateLLM",
- description: "Use the Replicate API to run open-source models",
- args: {
- model:
- "replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781",
- api_key: "<REPLICATE_API_KEY>",
- title: "Replicate",
- },
- icon: "replicate.png",
- tags: [ModelTag["Requires API Key"], ModelTag["Open-Source"]],
- },
- {
- title: "llama.cpp",
- class: "LlamaCpp",
- description: "If you are running the llama.cpp server from source",
- args: {
- title: "llama.cpp",
- },
- icon: "llamacpp.png",
- tags: [ModelTag.Local, ModelTag["Open-Source"]],
- },
- {
- title: "HuggingFace TGI",
- class: "HuggingFaceTGI",
- description:
- "HuggingFace Text Generation Inference is an advanced, highly performant option for serving open-source models to multiple people",
- args: {
- title: "HuggingFace TGI",
- },
- icon: "hf.png",
- tags: [ModelTag.Local, ModelTag["Open-Source"]],
- },
- {
- title: "Other OpenAI-compatible API",
- class: "GGML",
- description:
- "If you are using any other OpenAI-compatible API, for example text-gen-webui, FastChat, LocalAI, or llama-cpp-python, you can simply enter your server URL",
- args: {
- server_url: "<SERVER_URL>",
- },
- icon: "openai.svg",
- tags: [ModelTag.Local, ModelTag["Open-Source"]],
- },
- {
- title: "GPT-4 limited free trial",
- class: "OpenAIFreeTrial",
- description:
- "New users can try out Continue with GPT-4 using a proxy server that securely makes calls to OpenAI using our API key",
- args: {
- model: "gpt-4",
- title: "GPT-4 Free Trial",
- },
- icon: "openai.svg",
- tags: [ModelTag.Free],
- },
-];
+import { useDispatch } from "react-redux";
+import { GUIClientContext } from "../App";
+import { setShowDialog } from "../redux/slices/uiStateSlice";
+import { MODEL_INFO } from "../util/modelData";
const GridDiv = styled.div`
display: grid;
@@ -138,6 +20,8 @@ const GridDiv = styled.div`
function Models() {
const navigate = useNavigate();
+ const client = useContext(GUIClientContext);
+ const dispatch = useDispatch();
return (
<div className="overflow-y-scroll">
<div
@@ -154,11 +38,25 @@ function Models() {
onClick={() => navigate("/")}
className="inline-block ml-4 cursor-pointer"
/>
- <h3 className="text-lg font-bold m-2 inline-block">Add a new model</h3>
+ <h3 className="text-lg font-bold m-2 inline-block">
+ Select LLM Provider
+ </h3>
</div>
<GridDiv>
- {MODEL_INFO.map((model) => (
- <ModelCard modelInfo={model} />
+ {Object.entries(MODEL_INFO).map(([name, modelInfo]) => (
+ <ModelCard
+ title={modelInfo.title}
+ description={modelInfo.description}
+ tags={modelInfo.tags}
+ icon={modelInfo.icon}
+ refUrl={`https://continue.dev/docs/reference/Models/${modelInfo.class.toLowerCase()}`}
+ onClick={(e) => {
+ if ((e.target as any).closest("a")) {
+ return;
+ }
+ navigate(`/modelconfig/${name}`);
+ }}
+ />
))}
</GridDiv>
</div>