summaryrefslogtreecommitdiff
path: root/extension/react-app/src/pages/modelconfig.tsx
diff options
context:
space:
mode:
Diffstat (limited to 'extension/react-app/src/pages/modelconfig.tsx')
-rw-r--r--extension/react-app/src/pages/modelconfig.tsx63
1 files changed, 42 insertions, 21 deletions
diff --git a/extension/react-app/src/pages/modelconfig.tsx b/extension/react-app/src/pages/modelconfig.tsx
index 97e2d76c..00d9d9bf 100644
--- a/extension/react-app/src/pages/modelconfig.tsx
+++ b/extension/react-app/src/pages/modelconfig.tsx
@@ -3,7 +3,7 @@ import ModelCard from "../components/ModelCard";
import styled from "styled-components";
import { ArrowLeftIcon } from "@heroicons/react/24/outline";
import {
- TextInput,
+ Input,
defaultBorderRadius,
lightGray,
vscBackground,
@@ -22,6 +22,7 @@ import { RootStore } from "../redux/store";
import StyledMarkdownPreview from "../components/StyledMarkdownPreview";
import { getFontSize } from "../util";
import { FormProvider, useForm } from "react-hook-form";
+import _ from "lodash";
const GridDiv = styled.div`
display: grid;
@@ -151,22 +152,28 @@ function ModelConfig() {
<>
<h3 className="mb-2">Enter required parameters</h3>
- {modelInfo?.collectInputFor?.map((d) => {
- return (
- <div>
- <label htmlFor={d.key}>{d.key}</label>
- <TextInput
- id={d.key}
- className="border-2 border-gray-200 rounded-md p-2 m-2"
- placeholder={d.key}
- defaultValue={d.defaultValue}
- {...formMethods.register(d.key, {
- required: true,
- })}
- />
- </div>
- );
- })}
+ {modelInfo?.collectInputFor
+ ?.filter((d) => d.required)
+ .map((d) => {
+ return (
+ <div>
+ <label htmlFor={d.key}>{d.key}</label>
+ <Input
+ type={d.inputType}
+ id={d.key}
+ className="border-2 border-gray-200 rounded-md p-2 m-2"
+ placeholder={d.key}
+ defaultValue={d.defaultValue}
+ min={d.min}
+ max={d.max}
+ step={d.step}
+ {...formMethods.register(d.key, {
+ required: true,
+ })}
+ />
+ </div>
+ );
+ })}
</>
)}
@@ -182,11 +189,15 @@ function ModelConfig() {
return (
<div>
<label htmlFor={d.key}>{d.key}</label>
- <TextInput
+ <Input
+ type={d.inputType}
id={d.key}
className="border-2 border-gray-200 rounded-md p-2 m-2"
placeholder={d.key}
defaultValue={d.defaultValue}
+ min={d.min}
+ max={d.max}
+ step={d.step}
{...formMethods.register(d.key, {
required: false,
})}
@@ -209,19 +220,29 @@ function ModelConfig() {
tags={pkg.tags}
refUrl={pkg.refUrl}
icon={pkg.icon || modelInfo.icon}
- onClick={(e) => {
+ dimensions={pkg.dimensions}
+ onClick={(e, dimensionChoices) => {
if (disableModelCards()) return;
const formParams: any = {};
for (const d of modelInfo.collectInputFor || []) {
formParams[d.key] =
d.inputType === "text"
? formMethods.watch(d.key)
- : parseInt(formMethods.watch(d.key));
+ : parseFloat(formMethods.watch(d.key));
}
client?.addModelForRole("*", modelInfo.class, {
...pkg.params,
...modelInfo.params,
+ ..._.merge(
+ {},
+ ...(pkg.dimensions?.map((dimension, i) => {
+ if (!dimensionChoices?.[i]) return {};
+ return {
+ ...dimension.options[dimensionChoices[i]],
+ };
+ }) || [])
+ ),
...formParams,
});
navigate("/");
@@ -239,7 +260,7 @@ function ModelConfig() {
formParams[d.key] =
d.inputType === "text"
? formMethods.watch(d.key)
- : parseInt(formMethods.watch(d.key));
+ : parseFloat(formMethods.watch(d.key));
}
client?.addModelForRole("*", modelInfo.class, {