diff options
Diffstat (limited to 'extension/react-app')
19 files changed, 613 insertions, 218 deletions
| diff --git a/extension/react-app/package-lock.json b/extension/react-app/package-lock.json index fb68081c..a52396ef 100644 --- a/extension/react-app/package-lock.json +++ b/extension/react-app/package-lock.json @@ -11,6 +11,7 @@          "@types/vscode-webview": "^1.57.1",          "@uiw/react-markdown-preview": "^4.1.13",          "downshift": "^7.6.0", +        "lodash": "^4.17.21",          "meilisearch": "^0.33.0",          "posthog-js": "^1.58.0",          "prismjs": "^1.29.0", diff --git a/extension/react-app/package.json b/extension/react-app/package.json index b9f70645..be23b34b 100644 --- a/extension/react-app/package.json +++ b/extension/react-app/package.json @@ -12,6 +12,7 @@      "@types/vscode-webview": "^1.57.1",      "@uiw/react-markdown-preview": "^4.1.13",      "downshift": "^7.6.0", +    "lodash": "^4.17.21",      "meilisearch": "^0.33.0",      "posthog-js": "^1.58.0",      "prismjs": "^1.29.0", diff --git a/extension/react-app/public/logos/mistral.png b/extension/react-app/public/logos/mistral.pngBinary files differ new file mode 100644 index 00000000..0f535f84 --- /dev/null +++ b/extension/react-app/public/logos/mistral.png diff --git a/extension/react-app/public/logos/wizardlm.png b/extension/react-app/public/logos/wizardlm.pngBinary files differ new file mode 100644 index 00000000..a420cf03 --- /dev/null +++ b/extension/react-app/public/logos/wizardlm.png diff --git a/extension/react-app/src/components/ComboBox.tsx b/extension/react-app/src/components/ComboBox.tsx index c08c05de..1d0ca1a5 100644 --- a/extension/react-app/src/components/ComboBox.tsx +++ b/extension/react-app/src/components/ComboBox.tsx @@ -285,15 +285,13 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {    useEffect(() => {      if (!inputRef.current) return; -    if (inputRef.current.scrollHeight > inputRef.current.clientHeight) { -      inputRef.current.style.height = "auto"; -      inputRef.current.style.height = -        Math.min(inputRef.current.scrollHeight, 300) + "px"; -    } +    inputRef.current.style.height = "auto"; +    inputRef.current.style.height = +      Math.min(inputRef.current.scrollHeight, 300) + "px";    }, [      inputRef.current?.scrollHeight,      inputRef.current?.clientHeight, -    props.value, +    inputRef.current?.value,    ]);    // Whether the current input follows an '@' and should be treated as context query @@ -344,7 +342,6 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {    useEffect(() => {      if (!nestedContextProvider) { -      dispatch(setTakenActionTrue(null));        setItems(          contextProviders?.map((provider) => ({            name: provider.display_title, @@ -437,7 +434,6 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {        setNestedContextProvider(undefined);        // Handle slash commands -      dispatch(setTakenActionTrue(null));        setItems(          availableSlashCommands?.filter((slashCommand) => {            const sc = slashCommand.name.toLowerCase(); @@ -445,6 +441,10 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {            return sc.startsWith(iv) && sc !== iv;          }) || []        ); + +      if (inputValue.startsWith("/") || inputValue.startsWith("@")) { +        dispatch(setTakenActionTrue(null)); +      }      },      [        availableSlashCommands, @@ -756,6 +756,8 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {                  props.index                );                inputRef.current?.focus(); +              setPreviewingContextItem(undefined); +              setFocusedContextItem(undefined);              }}              onKeyDown={(e: any) => {                if (e.key === "Backspace") { @@ -880,6 +882,7 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {                paddingLeft: "12px",                cursor: "default",                paddingTop: getFontSize(), +              width: "fit-content",              }}            >              {props.active ? "Using" : "Used"} {selectedContextItems.length}{" "} @@ -937,17 +940,7 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {              {...getInputProps({                onCompositionStart: () => setIsComposing(true),                onCompositionEnd: () => setIsComposing(false), -              onChange: (e) => { -                const target = e.target as HTMLTextAreaElement; -                // Update the height of the textarea to match the content, up to a max of 200px. -                target.style.height = "auto"; -                target.style.height = `${Math.min( -                  target.scrollHeight, -                  300 -                ).toString()}px`; - -                // setShowContextDropdown(target.value.endsWith("@")); -              }, +              onChange: (e) => {},                onFocus: (e) => {                  setInputFocused(true);                  dispatch(setBottomMessage(undefined)); diff --git a/extension/react-app/src/components/ErrorStepContainer.tsx b/extension/react-app/src/components/ErrorStepContainer.tsx index 666780c5..07c0a046 100644 --- a/extension/react-app/src/components/ErrorStepContainer.tsx +++ b/extension/react-app/src/components/ErrorStepContainer.tsx @@ -42,7 +42,7 @@ function ErrorStepContainer(props: ErrorStepContainerProps) {          </HeaderButtonWithText>        </div>        <Div> -        <pre className="overflow-x-scroll"> +        <pre style={{ whiteSpace: "pre-wrap", wordWrap: "break-word" }}>            {props.historyNode.observation?.error as string}          </pre>        </Div> diff --git a/extension/react-app/src/components/Layout.tsx b/extension/react-app/src/components/Layout.tsx index a54c0ed4..db31c8db 100644 --- a/extension/react-app/src/components/Layout.tsx +++ b/extension/react-app/src/components/Layout.tsx @@ -30,6 +30,20 @@ const LayoutTopDiv = styled.div`    border-radius: ${defaultBorderRadius};    scrollbar-base-color: transparent;    scrollbar-width: thin; + +  & * { +    ::-webkit-scrollbar { +      width: 4px; +    } + +    ::-webkit-scrollbar:horizontal { +      height: 4px; +    } + +    ::-webkit-scrollbar-thumb { +      border-radius: 2px; +    } +  }  `;  const BottomMessageDiv = styled.div<{ displayOnBottom: boolean }>` @@ -47,7 +61,6 @@ const BottomMessageDiv = styled.div<{ displayOnBottom: boolean }>`    z-index: 100;    box-shadow: 0px 0px 2px 0px ${vscForeground};    max-height: 35vh; -  overflow: scroll;  `;  const Footer = styled.footer` @@ -131,6 +144,20 @@ const Layout = () => {      };    }, [client, timeline]); +  useEffect(() => { +    const handler = (event: any) => { +      if (event.data.type === "addModel") { +        navigate("/models"); +      } else if (event.data.type === "openSettings") { +        navigate("/settings"); +      } +    }; +    window.addEventListener("message", handler); +    return () => { +      window.removeEventListener("message", handler); +    }; +  }, []); +    return (      <LayoutTopDiv>        <div diff --git a/extension/react-app/src/components/ModelCard.tsx b/extension/react-app/src/components/ModelCard.tsx index d1cb3165..0ab6ac32 100644 --- a/extension/react-app/src/components/ModelCard.tsx +++ b/extension/react-app/src/components/ModelCard.tsx @@ -1,16 +1,16 @@ -import React, { useContext } from "react"; +import React, { useContext, useState } from "react";  import styled from "styled-components";  import { buttonColor, defaultBorderRadius, lightGray } from ".";  import { useSelector } from "react-redux";  import { RootStore } from "../redux/store";  import { BookOpenIcon } from "@heroicons/react/24/outline";  import HeaderButtonWithText from "./HeaderButtonWithText"; -import { MODEL_PROVIDER_TAG_COLORS } from "../util/modelData"; +import { MODEL_PROVIDER_TAG_COLORS, PackageDimension } from "../util/modelData"; +import InfoHover from "./InfoHover"; -const Div = styled.div<{ color: string; disabled: boolean }>` +const Div = styled.div<{ color: string; disabled: boolean; hovered: boolean }>`    border: 1px solid ${lightGray};    border-radius: ${defaultBorderRadius}; -  padding: 4px 8px;    position: relative;    width: 100%;    transition: all 0.5s; @@ -20,13 +20,45 @@ const Div = styled.div<{ color: string; disabled: boolean }>`        ? `      opacity: 0.5;      ` -      : ` -  &:hover { +      : props.hovered +      ? `      border: 1px solid ${props.color};      background-color: ${props.color}22; +    cursor: pointer;` +      : ""} +`; + +const DimensionsDiv = styled.div` +  display: flex; +  justify-content: flex-end; +  margin-left: auto; +  padding: 4px; +  /* width: fit-content; */ + +  border-top: 1px solid ${lightGray}; +`; + +const DimensionOptionDiv = styled.div<{ selected: boolean }>` +  display: flex; +  flex-direction: column; +  align-items: center; +  margin-right: 8px; +  background-color: ${lightGray}; +  padding: 4px; +  border-radius: ${defaultBorderRadius}; +  outline: 0.5px solid ${lightGray}; + +  ${(props) => +    props.selected && +    ` +    background-color: ${buttonColor}; +    color: white; +  `} + +  &:hover {      cursor: pointer; +    outline: 1px solid ${buttonColor};    } -  `}  `;  interface ModelCardProps { @@ -35,8 +67,12 @@ interface ModelCardProps {    tags?: string[];    refUrl?: string;    icon?: string; -  onClick: (e: React.MouseEvent<HTMLDivElement, MouseEvent>) => void; +  onClick: ( +    e: React.MouseEvent<HTMLDivElement, MouseEvent>, +    dimensionChoices?: string[] +  ) => void;    disabled?: boolean; +  dimensions?: PackageDimension[];  }  function ModelCard(props: ModelCardProps) { @@ -44,53 +80,103 @@ function ModelCard(props: ModelCardProps) {      (state: RootStore) => state.config.vscMediaUrl    ); +  const [dimensionChoices, setDimensionChoices] = useState<string[]>( +    props.dimensions?.map((d) => Object.keys(d.options)[0]) || [] +  ); + +  const [hovered, setHovered] = useState(false); +    return (      <Div        disabled={props.disabled || false}        color={buttonColor} -      onClick={props.disabled ? undefined : (e) => props.onClick(e)} +      hovered={hovered}      > -      <div style={{ display: "flex", alignItems: "center" }}> -        {vscMediaUrl && props.icon && ( -          <img -            src={`${vscMediaUrl}/logos/${props.icon}`} -            height="24px" -            style={{ marginRight: "10px" }} -          /> -        )} -        <h3>{props.title}</h3> -      </div> -      {props.tags?.map((tag) => { -        return ( -          <span +      <div +        onMouseEnter={() => setHovered(true)} +        onMouseLeave={() => setHovered(false)} +        className="px-2 py-1" +        onClick={ +          props.disabled +            ? undefined +            : (e) => { +                if ((e.target as any).closest("a")) { +                  return; +                } +                props.onClick(e, dimensionChoices); +              } +        } +      > +        <div style={{ display: "flex", alignItems: "center" }}> +          {vscMediaUrl && props.icon && ( +            <img +              src={`${vscMediaUrl}/logos/${props.icon}`} +              height="24px" +              style={{ marginRight: "10px" }} +            /> +          )} +          <h3>{props.title}</h3> +        </div> +        {props.tags?.map((tag) => { +          return ( +            <span +              style={{ +                backgroundColor: `${MODEL_PROVIDER_TAG_COLORS[tag]}55`, +                color: "white", +                padding: "2px 4px", +                borderRadius: defaultBorderRadius, +                marginRight: "4px", +              }} +            > +              {tag} +            </span> +          ); +        })} +        <p>{props.description}</p> + +        {props.refUrl && ( +          <a              style={{ -              backgroundColor: `${MODEL_PROVIDER_TAG_COLORS[tag]}55`, -              color: "white", -              padding: "2px 4px", -              borderRadius: defaultBorderRadius, -              marginRight: "4px", +              position: "absolute", +              right: "8px", +              top: "8px",              }} +            href={props.refUrl} +            target="_blank"            > -            {tag} -          </span> -        ); -      })} -      <p>{props.description}</p> +            <HeaderButtonWithText text="Read the docs"> +              <BookOpenIcon width="1.6em" height="1.6em" /> +            </HeaderButtonWithText> +          </a> +        )} +      </div> -      {props.refUrl && ( -        <a -          style={{ -            position: "absolute", -            right: "8px", -            top: "8px", -          }} -          href={props.refUrl} -          target="_blank" -        > -          <HeaderButtonWithText text="Read the docs"> -            <BookOpenIcon width="1.6em" height="1.6em" /> -          </HeaderButtonWithText> -        </a> +      {props.dimensions?.length && ( +        <DimensionsDiv> +          {props.dimensions?.map((dimension, i) => { +            return ( +              <div className="flex items-center"> +                <InfoHover msg={dimension.description} /> +                <p className="mx-2 text-sm my-0 py-0">{dimension.name}</p> +                {Object.keys(dimension.options).map((key) => { +                  return ( +                    <DimensionOptionDiv +                      onClick={(e) => { +                        e.stopPropagation(); +                        const newChoices = [...dimensionChoices]; +                        newChoices[i] = key; +                        setDimensionChoices(newChoices); +                      }} +                      selected={dimensionChoices[i] === key} +                    > +                      {key} +                    </DimensionOptionDiv> +                  ); +                })} +              </div> +            ); +          })} +        </DimensionsDiv>        )}      </Div>    ); diff --git a/extension/react-app/src/components/ModelSettings.tsx b/extension/react-app/src/components/ModelSettings.tsx index 4b9d5e64..3f9414b1 100644 --- a/extension/react-app/src/components/ModelSettings.tsx +++ b/extension/react-app/src/components/ModelSettings.tsx @@ -3,7 +3,7 @@ import { LLM } from "../../../schema/LLM";  import {    Label,    Select, -  TextInput, +  Input,    defaultBorderRadius,    lightGray,    vscForeground, @@ -58,7 +58,7 @@ function ModelSettings(props: { llm: any | undefined; role: string }) {              {typeof modelOptions.api_key !== undefined && (                <>                  <Label fontSize={getFontSize()}>API Key</Label> -                <TextInput +                <Input                    type="text"                    defaultValue={props.llm.api_key}                    placeholder="API Key" @@ -69,7 +69,7 @@ function ModelSettings(props: { llm: any | undefined; role: string }) {              {modelOptions.model && (                <>                  <Label fontSize={getFontSize()}>Model</Label> -                <TextInput +                <Input                    type="text"                    defaultValue={props.llm.model}                    placeholder="Model" diff --git a/extension/react-app/src/components/Suggestions.tsx b/extension/react-app/src/components/Suggestions.tsx index bdda7579..5779eea8 100644 --- a/extension/react-app/src/components/Suggestions.tsx +++ b/extension/react-app/src/components/Suggestions.tsx @@ -16,6 +16,7 @@ import { useSelector } from "react-redux";  import { RootStore } from "../redux/store";  import HeaderButtonWithText from "./HeaderButtonWithText";  import { getFontSize } from "../util"; +import { usePostHog } from "posthog-js/react";  const Div = styled.div<{ isDisabled: boolean }>`    border-radius: ${defaultBorderRadius}; @@ -159,6 +160,7 @@ const TutorialDiv = styled.div`  `;  function SuggestionsArea(props: { onClick: (textInput: string) => void }) { +  const posthog = usePostHog();    const [stage, setStage] = useState(      parseInt(localStorage.getItem("stage") || "0")    ); @@ -207,8 +209,18 @@ function SuggestionsArea(props: { onClick: (textInput: string) => void }) {              className="absolute right-1 top-1 cursor-pointer"              text="Close Tutorial"              onClick={() => { -              console.log("HIDE");                setHide(true); +              const tutorialClosedCount = parseInt( +                localStorage.getItem("tutorialClosedCount") || "0" +              ); +              localStorage.setItem( +                "tutorialClosedCount", +                (tutorialClosedCount + 1).toString() +              ); +              posthog?.capture("tutorial_closed", { +                stage, +                tutorialClosedCount, +              });              }}            >              <XMarkIcon width="1.2em" height="1.2em" /> @@ -219,8 +231,9 @@ function SuggestionsArea(props: { onClick: (textInput: string) => void }) {                  disabled={!codeIsHighlighted}                  {...suggestion}                  onClick={() => { -                  if (stage > 0 && !codeIsHighlighted) return; +                  if (!codeIsHighlighted) return;                    props.onClick(suggestion.textInput); +                  posthog?.capture("tutorial_stage_complete", { stage });                    setStage(stage + 1);                    localStorage.setItem("stage", (stage + 1).toString());                    setHide(true); diff --git a/extension/react-app/src/components/dialogs/AddContextGroupDialog.tsx b/extension/react-app/src/components/dialogs/AddContextGroupDialog.tsx index 9cd0a95e..a6cf151c 100644 --- a/extension/react-app/src/components/dialogs/AddContextGroupDialog.tsx +++ b/extension/react-app/src/components/dialogs/AddContextGroupDialog.tsx @@ -1,5 +1,5 @@  import { useContext } from "react"; -import { Button, TextInput } from ".."; +import { Button, Input } from "..";  import { GUIClientContext } from "../../App";  import { useDispatch } from "react-redux";  import { @@ -27,7 +27,7 @@ function AddContextGroupDialog({    return (      <div className="p-4"> -      <TextInput +      <Input          defaultValue="My Context Group"          type="text"          ref={(input) => { diff --git a/extension/react-app/src/components/dialogs/FTCDialog.tsx b/extension/react-app/src/components/dialogs/FTCDialog.tsx index 3ea753bc..5fa2d4e6 100644 --- a/extension/react-app/src/components/dialogs/FTCDialog.tsx +++ b/extension/react-app/src/components/dialogs/FTCDialog.tsx @@ -1,6 +1,6 @@  import React, { useContext } from "react";  import styled from "styled-components"; -import { Button, TextInput } from ".."; +import { Button, Input } from "..";  import { useNavigate } from "react-router-dom";  import { GUIClientContext } from "../../App";  import { useDispatch } from "react-redux"; @@ -37,7 +37,7 @@ function FTCDialog() {          OpenAIFreeTrial object.        </p> -      <TextInput +      <Input          type="text"          placeholder="Enter your OpenAI API key"          value={apiKey} @@ -46,6 +46,7 @@ function FTCDialog() {        <GridDiv>          <Button            onClick={() => { +            dispatch(setShowDialog(false));              navigate("/models");            }}          > diff --git a/extension/react-app/src/components/index.ts b/extension/react-app/src/components/index.ts index 9d9b7c40..12b84759 100644 --- a/extension/react-app/src/components/index.ts +++ b/extension/react-app/src/components/index.ts @@ -10,9 +10,10 @@ export const vscBackgroundTransparent = "#1e1e1ede";  export const buttonColor = "#1bbe84";  export const buttonColorHover = "#1bbe84a8"; -export const secondaryDark = "var(--vscode-list-hoverBackground)"; -export const vscBackground = "var(--vscode-editor-background)"; -export const vscForeground = "var(--vscode-editor-foreground)"; +export const secondaryDark = +  "var(--vscode-list-hoverBackground, rgb(45 45 45))"; +export const vscBackground = "var(--vscode-editor-background, rgb(30 30 30))"; +export const vscForeground = "var(--vscode-editor-foreground, white)";  export const Button = styled.button`    padding: 10px 12px; @@ -92,7 +93,7 @@ export const H3 = styled.h3`    width: fit-content;  `; -export const TextInput = styled.input.attrs({ type: "text" })` +export const Input = styled.input`    width: 100%;    padding: 8px 12px;    margin: 8px 0; @@ -106,6 +107,10 @@ export const TextInput = styled.input.attrs({ type: "text" })`    &:focus {      background: ${secondaryDark};    } + +  &:invalid { +    outline: 1px solid red; +  }  `;  export const NumberInput = styled.input.attrs({ type: "number" })` diff --git a/extension/react-app/src/pages/gui.tsx b/extension/react-app/src/pages/gui.tsx index b8199c19..637896c6 100644 --- a/extension/react-app/src/pages/gui.tsx +++ b/extension/react-app/src/pages/gui.tsx @@ -1,5 +1,5 @@  import styled from "styled-components"; -import { TextInput, defaultBorderRadius, lightGray } from "../components"; +import { Input, defaultBorderRadius, lightGray, vscBackground } from "../components";  import { FullState } from "../../../schema/FullState";  import {    useEffect, @@ -58,7 +58,7 @@ const TopGuiDiv = styled.div`    }  `; -const TitleTextInput = styled(TextInput)` +const TitleTextInput = styled(Input)`    border: none;    outline: none; @@ -109,6 +109,10 @@ const GUIHeaderDiv = styled.div`    padding-left: 8px;    padding-right: 8px;    border-bottom: 0.5px solid ${lightGray}; +  position: sticky; +  top: 0; +  z-index: 100; +  background-color: ${vscBackground};  `;  interface GUIProps { @@ -480,7 +484,7 @@ function GUI(props: GUIProps) {    useEffect(() => {      const timeout = setTimeout(() => {        setShowLoading(true); -    }, 10000); +    }, 15_000);      return () => {        clearTimeout(timeout); diff --git a/extension/react-app/src/pages/history.tsx b/extension/react-app/src/pages/history.tsx index 63024e36..7c76cb53 100644 --- a/extension/react-app/src/pages/history.tsx +++ b/extension/react-app/src/pages/history.tsx @@ -17,6 +17,9 @@ const Tr = styled.tr`    }    overflow-wrap: anywhere; + +  border-bottom: 1px solid ${secondaryDark}; +  border-top: 1px solid ${secondaryDark};  `;  const parseDate = (date: string): Date => { @@ -44,7 +47,6 @@ const TdDiv = styled.div`    padding-right: 1rem;    padding-top: 0.5rem;    padding-bottom: 0.5rem; -  border-bottom: 1px solid ${secondaryDark};  `;  function lastPartOfPath(path: string): string { @@ -155,7 +157,7 @@ function History() {        )}        <div> -        <table className="w-full"> +        <table className="w-full border-spacing-0 border-collapse">            <tbody>              {filteredAndSortedSessions.map((session, index) => {                const prevDate = diff --git a/extension/react-app/src/pages/modelconfig.tsx b/extension/react-app/src/pages/modelconfig.tsx index 97e2d76c..00d9d9bf 100644 --- a/extension/react-app/src/pages/modelconfig.tsx +++ b/extension/react-app/src/pages/modelconfig.tsx @@ -3,7 +3,7 @@ import ModelCard from "../components/ModelCard";  import styled from "styled-components";  import { ArrowLeftIcon } from "@heroicons/react/24/outline";  import { -  TextInput, +  Input,    defaultBorderRadius,    lightGray,    vscBackground, @@ -22,6 +22,7 @@ import { RootStore } from "../redux/store";  import StyledMarkdownPreview from "../components/StyledMarkdownPreview";  import { getFontSize } from "../util";  import { FormProvider, useForm } from "react-hook-form"; +import _ from "lodash";  const GridDiv = styled.div`    display: grid; @@ -151,22 +152,28 @@ function ModelConfig() {              <>                <h3 className="mb-2">Enter required parameters</h3> -              {modelInfo?.collectInputFor?.map((d) => { -                return ( -                  <div> -                    <label htmlFor={d.key}>{d.key}</label> -                    <TextInput -                      id={d.key} -                      className="border-2 border-gray-200 rounded-md p-2 m-2" -                      placeholder={d.key} -                      defaultValue={d.defaultValue} -                      {...formMethods.register(d.key, { -                        required: true, -                      })} -                    /> -                  </div> -                ); -              })} +              {modelInfo?.collectInputFor +                ?.filter((d) => d.required) +                .map((d) => { +                  return ( +                    <div> +                      <label htmlFor={d.key}>{d.key}</label> +                      <Input +                        type={d.inputType} +                        id={d.key} +                        className="border-2 border-gray-200 rounded-md p-2 m-2" +                        placeholder={d.key} +                        defaultValue={d.defaultValue} +                        min={d.min} +                        max={d.max} +                        step={d.step} +                        {...formMethods.register(d.key, { +                          required: true, +                        })} +                      /> +                    </div> +                  ); +                })}              </>            )} @@ -182,11 +189,15 @@ function ModelConfig() {                  return (                    <div>                      <label htmlFor={d.key}>{d.key}</label> -                    <TextInput +                    <Input +                      type={d.inputType}                        id={d.key}                        className="border-2 border-gray-200 rounded-md p-2 m-2"                        placeholder={d.key}                        defaultValue={d.defaultValue} +                      min={d.min} +                      max={d.max} +                      step={d.step}                        {...formMethods.register(d.key, {                          required: false,                        })} @@ -209,19 +220,29 @@ function ModelConfig() {                  tags={pkg.tags}                  refUrl={pkg.refUrl}                  icon={pkg.icon || modelInfo.icon} -                onClick={(e) => { +                dimensions={pkg.dimensions} +                onClick={(e, dimensionChoices) => {                    if (disableModelCards()) return;                    const formParams: any = {};                    for (const d of modelInfo.collectInputFor || []) {                      formParams[d.key] =                        d.inputType === "text"                          ? formMethods.watch(d.key) -                        : parseInt(formMethods.watch(d.key)); +                        : parseFloat(formMethods.watch(d.key));                    }                    client?.addModelForRole("*", modelInfo.class, {                      ...pkg.params,                      ...modelInfo.params, +                    ..._.merge( +                      {}, +                      ...(pkg.dimensions?.map((dimension, i) => { +                        if (!dimensionChoices?.[i]) return {}; +                        return { +                          ...dimension.options[dimensionChoices[i]], +                        }; +                      }) || []) +                    ),                      ...formParams,                    });                    navigate("/"); @@ -239,7 +260,7 @@ function ModelConfig() {                  formParams[d.key] =                    d.inputType === "text"                      ? formMethods.watch(d.key) -                    : parseInt(formMethods.watch(d.key)); +                    : parseFloat(formMethods.watch(d.key));                }                client?.addModelForRole("*", modelInfo.class, { diff --git a/extension/react-app/src/pages/models.tsx b/extension/react-app/src/pages/models.tsx index a9a97a13..75c76d67 100644 --- a/extension/react-app/src/pages/models.tsx +++ b/extension/react-app/src/pages/models.tsx @@ -51,9 +51,6 @@ function Models() {              icon={modelInfo.icon}              refUrl={`https://continue.dev/docs/reference/Models/${modelInfo.class.toLowerCase()}`}              onClick={(e) => { -              if ((e.target as any).closest("a")) { -                return; -              }                navigate(`/modelconfig/${name}`);              }}            /> diff --git a/extension/react-app/src/pages/settings.tsx b/extension/react-app/src/pages/settings.tsx index cb269d7b..060a5b75 100644 --- a/extension/react-app/src/pages/settings.tsx +++ b/extension/react-app/src/pages/settings.tsx @@ -1,4 +1,4 @@ -import React, { useContext } from "react"; +import React, { useContext, useEffect } from "react";  import { GUIClientContext } from "../App";  import { useDispatch, useSelector } from "react-redux";  import { RootStore } from "../redux/store"; @@ -113,6 +113,13 @@ function Settings() {      navigate("/");    }; +  useEffect(() => { +    if (!config) return; + +    formMethods.setValue("system_message", config.system_message); +    formMethods.setValue("temperature", config.temperature); +  }, [config]); +    return (      <FormProvider {...formMethods}>        <div className="overflow-scroll"> @@ -145,7 +152,6 @@ function Settings() {                <TextArea                  placeholder="Enter a system message (e.g. 'Always respond in German')"                  {...formMethods.register("system_message")} -                defaultValue={config.system_message}                />                <Hr /> @@ -164,7 +170,6 @@ function Settings() {                    min="0"                    max="1"                    step="0.01" -                  defaultValue={config.temperature}                    {...formMethods.register("temperature")}                  />                  <p>1</p> diff --git a/extension/react-app/src/util/modelData.ts b/extension/react-app/src/util/modelData.ts index 91259446..035e4af2 100644 --- a/extension/react-app/src/util/modelData.ts +++ b/extension/react-app/src/util/modelData.ts @@ -1,3 +1,17 @@ +import _ from "lodash"; + +function updatedObj(old: any, pathToValue: { [key: string]: any }) { +  const newObject = _.cloneDeep(old); +  for (const key in pathToValue) { +    if (typeof pathToValue[key] === "function") { +      _.updateWith(newObject, key, pathToValue[key]); +    } else { +      _.updateWith(newObject, key, (__) => pathToValue[key]); +    } +  } +  return newObject; +} +  export enum ModelProviderTag {    "Requires API Key" = "Requires API Key",    "Local" = "Local", @@ -14,6 +28,7 @@ MODEL_PROVIDER_TAG_COLORS[ModelProviderTag["Free"]] = "#ffff00";  export enum CollectInputType {    "text" = "text",    "number" = "number", +  "range" = "range",  }  export interface InputDescriptor { @@ -38,6 +53,64 @@ const contextLengthInput: InputDescriptor = {    defaultValue: 2048,    required: false,  }; +const temperatureInput: InputDescriptor = { +  inputType: CollectInputType.number, +  key: "temperature", +  label: "Temperature", +  defaultValue: undefined, +  required: false, +  min: 0.0, +  max: 1.0, +  step: 0.01, +}; +const topPInput: InputDescriptor = { +  inputType: CollectInputType.number, +  key: "top_p", +  label: "Top-P", +  defaultValue: undefined, +  required: false, +  min: 0, +  max: 1, +  step: 0.01, +}; +const topKInput: InputDescriptor = { +  inputType: CollectInputType.number, +  key: "top_k", +  label: "Top-K", +  defaultValue: undefined, +  required: false, +  min: 0, +  max: 1, +  step: 0.01, +}; +const presencePenaltyInput: InputDescriptor = { +  inputType: CollectInputType.number, +  key: "presence_penalty", +  label: "Presence Penalty", +  defaultValue: undefined, +  required: false, +  min: 0, +  max: 1, +  step: 0.01, +}; +const FrequencyPenaltyInput: InputDescriptor = { +  inputType: CollectInputType.number, +  key: "frequency_penalty", +  label: "Frequency Penalty", +  defaultValue: undefined, +  required: false, +  min: 0, +  max: 1, +  step: 0.01, +}; +const completionParamsInputs = [ +  contextLengthInput, +  temperatureInput, +  topKInput, +  topPInput, +  presencePenaltyInput, +  FrequencyPenaltyInput, +];  const serverUrlInput = {    inputType: CollectInputType.text, @@ -59,6 +132,14 @@ export interface ModelInfo {    collectInputFor?: InputDescriptor[];  } +// A dimension is like parameter count - 7b, 13b, 34b, etc. +// You would set options to the field that should be changed for that option in the params field of ModelPackage +export interface PackageDimension { +  name: string; +  description: string; +  options: { [key: string]: { [key: string]: any } }; +} +  export interface ModelPackage {    collectInputFor?: InputDescriptor[];    description: string; @@ -75,100 +156,189 @@ export interface ModelPackage {      replace?: [string, string][];      [key: string]: any;    }; +  dimensions?: PackageDimension[];  } -const codeLlama7bInstruct: ModelPackage = { -  title: "CodeLlama-7b-Instruct", -  description: "A 7b parameter model tuned for code generation", +enum ChatTemplates { +  "alpaca" = "template_alpaca_messages", +  "llama2" = "llama2_template_messages", +  "sqlcoder" = "sqlcoder_template_messages", +} + +const codeLlamaInstruct: ModelPackage = { +  title: "CodeLlama Instruct", +  description: +    "A model from Meta, fine-tuned for code generation and conversation",    refUrl: "",    params: {      title: "CodeLlama-7b-Instruct",      model: "codellama:7b-instruct",      context_length: 2048, -    template_messages: "llama2_template_messages", +    template_messages: ChatTemplates.llama2,    },    icon: "meta.svg", +  dimensions: [ +    { +      name: "Parameter Count", +      description: "The number of parameters in the model", +      options: { +        "7b": { +          model: "codellama:7b-instruct", +          title: "CodeLlama-7b-Instruct", +        }, +        "13b": { +          model: "codellama:13b-instruct", +          title: "CodeLlama-13b-Instruct", +        }, +        "34b": { +          model: "codellama:34b-instruct", +          title: "CodeLlama-34b-Instruct", +        }, +      }, +    }, +  ],  }; -const codeLlama13bInstruct: ModelPackage = { -  title: "CodeLlama-13b-Instruct", -  description: "A 13b parameter model tuned for code generation", + +const llama2Chat: ModelPackage = { +  title: "Llama2 Chat", +  description: "The latest Llama model from Meta, fine-tuned for chat",    refUrl: "",    params: { -    title: "CodeLlama13b-Instruct", -    model: "codellama13b-instruct", +    title: "Llama2-7b-Chat", +    model: "llama2:7b-chat",      context_length: 2048, -    template_messages: "llama2_template_messages", +    template_messages: ChatTemplates.llama2,    },    icon: "meta.svg", +  dimensions: [ +    { +      name: "Parameter Count", +      description: "The number of parameters in the model", +      options: { +        "7b": { +          model: "llama2:7b-chat", +          title: "Llama2-7b-Chat", +        }, +        "13b": { +          model: "llama2:13b-chat", +          title: "Llama2-13b-Chat", +        }, +        "34b": { +          model: "llama2:34b-chat", +          title: "Llama2-34b-Chat", +        }, +      }, +    }, +  ],  }; -const codeLlama34bInstruct: ModelPackage = { -  title: "CodeLlama-34b-Instruct", -  description: "A 34b parameter model tuned for code generation", + +const wizardCoder: ModelPackage = { +  title: "WizardCoder", +  description: +    "A CodeLlama-based code generation model from WizardLM, focused on Python",    refUrl: "",    params: { -    title: "CodeLlama-34b-Instruct", -    model: "codellama:34b-instruct", +    title: "WizardCoder-7b-Python", +    model: "wizardcoder:7b-python",      context_length: 2048, -    template_messages: "llama2_template_messages", +    template_messages: ChatTemplates.alpaca,    }, -  icon: "meta.svg", +  icon: "wizardlm.png", +  dimensions: [ +    { +      name: "Parameter Count", +      description: "The number of parameters in the model", +      options: { +        "7b": { +          model: "wizardcoder:7b-python", +          title: "WizardCoder-7b-Python", +        }, +        "13b": { +          model: "wizardcoder:13b-python", +          title: "WizardCoder-13b-Python", +        }, +        "34b": { +          model: "wizardcoder:34b-python", +          title: "WizardCoder-34b-Python", +        }, +      }, +    }, +  ],  }; -const llama2Chat7b: ModelPackage = { -  title: "Llama2-7b-Chat", -  description: "A 7b parameter model fine-tuned for chat", -  refUrl: "", +const phindCodeLlama: ModelPackage = { +  title: "Phind CodeLlama (34b)", +  description: "A finetune of CodeLlama by Phind",    params: { -    title: "Llama2-7b-Chat", -    model: "llama2:7b-chat", +    title: "Phind CodeLlama", +    model: "phind-codellama",      context_length: 2048, -    template_messages: "llama2_template_messages", +    template_messages: ChatTemplates.llama2,    }, -  icon: "meta.svg",  }; -const llama2Chat13b: ModelPackage = { -  title: "Llama2-13b-Chat", -  description: "A 13b parameter model fine-tuned for chat", -  refUrl: "", + +const mistral: ModelPackage = { +  title: "Mistral (7b)", +  description: +    "A 7b parameter base model created by Mistral AI, very competent for code generation and other tasks",    params: { -    title: "Llama2-13b-Chat", -    model: "llama2:13b-chat", +    title: "Mistral", +    model: "mistral",      context_length: 2048, -    template_messages: "llama2_template_messages", +    template_messages: ChatTemplates.llama2,    }, -  icon: "meta.svg", +  icon: "mistral.png",  }; -const llama2Chat34b: ModelPackage = { -  title: "Llama2-34b-Chat", -  description: "A 34b parameter model fine-tuned for chat", -  refUrl: "", + +const sqlCoder: ModelPackage = { +  title: "SQLCoder", +  description: +    "A finetune of StarCoder by Defog.ai, focused specifically on SQL",    params: { -    title: "Llama2-34b-Chat", -    model: "llama2:34b-chat", +    title: "SQLCoder", +    model: "sqlcoder",      context_length: 2048, -    template_messages: "llama2_template_messages", +    template_messages: ChatTemplates.sqlcoder,    }, -  icon: "meta.svg", +  dimensions: [ +    { +      name: "Parameter Count", +      description: "The number of parameters in the model", +      options: { +        "7b": { +          model: "sqlcoder:7b", +          title: "SQLCoder-7b", +        }, +        "13b": { +          model: "sqlcoder:15b", +          title: "SQLCoder-15b", +        }, +      }, +    }, +  ],  }; -const codeLlamaPackages = [ -  codeLlama7bInstruct, -  codeLlama13bInstruct, -  codeLlama34bInstruct, -]; - -const llama2Packages = [llama2Chat7b, llama2Chat13b, llama2Chat34b]; -const llama2FamilyPackage = { -  title: "Llama2 or CodeLlama", -  description: "Any model using the Llama2 or CodeLlama chat template", +const codeup: ModelPackage = { +  title: "CodeUp (13b)", +  description: "An open-source coding model based on Llama2",    params: { -    model: "llama2", +    title: "CodeUp", +    model: "codeup",      context_length: 2048, -    template_messages: "llama2_template_messages", +    template_messages: ChatTemplates.llama2,    }, -  icon: "meta.svg",  }; +const osModels = [ +  codeLlamaInstruct, +  llama2Chat, +  wizardCoder, +  phindCodeLlama, +  sqlCoder, +  mistral, +  codeup, +]; +  const gpt4: ModelPackage = {    title: "GPT-4",    description: "The latest model from OpenAI", @@ -192,6 +362,23 @@ const gpt35turbo: ModelPackage = {    },  }; +const OLLAMA_TO_REPLICATE_MODEL_NAMES: { [key: string]: string } = { +  "codellama:7b-instruct": +    "meta/codellama-7b-instruct:6527b83e01e41412db37de5110a8670e3701ee95872697481a355e05ce12af0e", +  "codellama:13b-instruct": +    "meta/codellama-13b-instruct:1f01a52ff933873dff339d5fb5e1fd6f24f77456836f514fa05e91c1a42699c7", +  "codellama:34b-instruct": +    "meta/codellama-34b-instruct:8281a5c610f6e88237ff3ddaf3c33b56f60809e2bdd19fbec2fda742aa18167e", +  "llama2:7b-chat": +    "meta/llama-2-7b-chat:8e6975e5ed6174911a6ff3d60540dfd4844201974602551e10e9e87ab143d81e", +  "llama2:13b-chat": +    "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", +}; + +function replicateConvertModelName(model: string): string { +  return OLLAMA_TO_REPLICATE_MODEL_NAMES[model] || model; +} +  export const MODEL_INFO: { [key: string]: ModelInfo } = {    openai: {      title: "OpenAI", @@ -210,6 +397,7 @@ export const MODEL_INFO: { [key: string]: ModelInfo } = {          placeholder: "Enter your OpenAI API key",          required: true,        }, +      ...completionParamsInputs,      ],    },    anthropic: { @@ -229,6 +417,7 @@ export const MODEL_INFO: { [key: string]: ModelInfo } = {          placeholder: "Enter your Anthropic API key",          required: true,        }, +      ...completionParamsInputs,      ],      packages: [        { @@ -251,17 +440,8 @@ export const MODEL_INFO: { [key: string]: ModelInfo } = {        'To get started with Ollama, follow these steps:\n1. Download from [ollama.ai](https://ollama.ai/) and open the application\n2. Open a terminal and run `ollama pull <MODEL_NAME>`. Example model names are `codellama:7b-instruct` or `llama2:7b-text`. You can find the full list [here](https://ollama.ai/library).\n3. Make sure that the model name used in step 2 is the same as the one in config.py (e.g. `model="codellama:7b-instruct"`)\n4. Once the model has finished downloading, you can start asking questions through Continue.',      icon: "ollama.png",      tags: [ModelProviderTag["Local"], ModelProviderTag["Open-Source"]], -    packages: [ -      ...codeLlamaPackages.map((p) => ({ -        ...p, -        refUrl: "https://ollama.ai/library/codellama", -      })), -      ...llama2Packages.map((p) => ({ -        ...p, -        refUrl: "https://ollama.ai/library/llama2", -      })), -    ], -    collectInputFor: [contextLengthInput], +    packages: osModels, +    collectInputFor: [...completionParamsInputs],    },    together: {      title: "TogetherAI", @@ -285,32 +465,51 @@ export const MODEL_INFO: { [key: string]: ModelInfo } = {          placeholder: "Enter your TogetherAI API key",          required: true,        }, +      ...completionParamsInputs,      ],      packages: [ -      ...codeLlamaPackages.map((p) => { -        return { -          ...p, -          params: { -            ...p.params, -            model: -              "togethercomputer/" + -              p.params.model.replace("llama2", "llama-2").replace(":", "-"), -          }, -        }; +      updatedObj(llama2Chat, { +        "dimensions[0].options": (options: any) => +          _.mapValues(options, (option) => { +            return _.assign({}, option, { +              model: +                "togethercomputer/" + +                option.model.replace("llama2", "llama-2").replace(":", "-"), +            }); +          }), +      }), +      updatedObj(codeLlamaInstruct, { +        "dimensions[0].options": (options: any) => +          _.mapValues(options, (option) => { +            return _.assign({}, option, { +              model: +                "togethercomputer/" + +                option.model +                  .replace("codellama", "CodeLlama") +                  .replace(":", "-") +                  .replace("instruct", "Instruct"), +            }); +          }),        }), -      ...llama2Packages.map((p) => { -        return { -          ...p, -          params: { -            ...p.params, -            model: -              "togethercomputer/" + -              p.params.model -                .replace("codellama", "CodeLlama") -                .replace(":", "-") -                .replace("instruct", "Instruct"), +      updatedObj(wizardCoder, { +        "params.model": "WizardLM/WizardCoder-15B-V1.0", +        "params.title": "WizardCoder-15b", +        "dimensions[0].options": { +          "15b": { +            model: "WizardLM/WizardCoder-15B-V1.0", +            title: "WizardCoder-15b",            }, -        }; +          "34b (Python)": { +            model: "WizardLM/WizardCoder-Python-34B-V1.0", +            title: "WizardCoder-34b-Python", +          }, +        }, +      }), +      updatedObj(phindCodeLlama, { +        "params.model": "Phind/Phind-CodeLlama-34B-Python-v1", +      }), +      updatedObj(mistral, { +        "params.model": "mistralai/Mistral-7B-Instruct-v0.1",        }),      ].map((p) => {        p.params.context_length = 4096; @@ -329,8 +528,8 @@ export const MODEL_INFO: { [key: string]: ModelInfo } = {      params: {        server_url: "http://localhost:1234",      }, -    packages: [llama2FamilyPackage], -    collectInputFor: [contextLengthInput], +    packages: osModels, +    collectInputFor: [...completionParamsInputs],    },    replicate: {      title: "Replicate", @@ -348,23 +547,62 @@ export const MODEL_INFO: { [key: string]: ModelInfo } = {          placeholder: "Enter your Replicate API key",          required: true,        }, +      ...completionParamsInputs,      ],      icon: "replicate.png",      tags: [        ModelProviderTag["Requires API Key"],        ModelProviderTag["Open-Source"],      ], -    packages: [...codeLlamaPackages, ...llama2Packages].map((p) => { -      return { -        ...p, -        params: { -          ...p.params, -          model: -            "meta/" + -            p.params.model.replace(":", "-").replace("llama2", "llama-2"), -        }, -      }; -    }), +    packages: [ +      ...[codeLlamaInstruct, llama2Chat] +        .map((p: ModelPackage) => { +          if (p.title === "Llama2 Chat") { +            return updatedObj(p, { +              "dimensions[0].options.34b": undefined, +              "dimensions[0].options.70b": { +                model: +                  "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3", +                title: "Llama2-70b-Chat", +              }, +            }); +          } +          return p; +        }) +        .map((p) => { +          return updatedObj(p, { +            "params.model": (model: string) => { +              return replicateConvertModelName(model); +            }, +            "dimensions[0].options": (options: any) => { +              const newOptions: any = {}; +              for (const key in options) { +                newOptions[key] = { +                  ...options[key], +                  model: replicateConvertModelName(options[key]?.model), +                }; +              } +              return newOptions; +            }, +          }); +        }), +      updatedObj(wizardCoder, { +        title: "WizardCoder (15b)", +        "params.model": +          "andreasjansson/wizardcoder-python-34b-v1-gguf:67eed332a5389263b8ede41be3ee7dc119fa984e2bde287814c4abed19a45e54", +        dimensions: undefined, +      }), +      updatedObj(sqlCoder, { +        dimensions: undefined, +        title: "SQLCoder (15b)", +        "params.model": +          "gregwdata/defog-sqlcoder-q8:0a9abc0d143072fd5d8920ad90b8fbaafaf16b10ffdad24bd897b5bffacfce0b", +      }), +      updatedObj(mistral, { +        "params.model": +          "a16z-infra/mistral-7b-instruct-v0.1:83b6a56e7c828e667f21fd596c338fd4f0039b46bcfa18d973e8e70e455fda70", +      }), +    ],    },    llamacpp: {      title: "llama.cpp", @@ -384,8 +622,8 @@ export const MODEL_INFO: { [key: string]: ModelInfo } = {  After it's up and running, you can start using Continue.`,      icon: "llamacpp.png",      tags: [ModelProviderTag.Local, ModelProviderTag["Open-Source"]], -    packages: [llama2FamilyPackage], -    collectInputFor: [contextLengthInput], +    packages: osModels, +    collectInputFor: [...completionParamsInputs],    },    palm: {      title: "Google PaLM API", @@ -426,9 +664,9 @@ After it's up and running, you can start using Continue.`,        "HuggingFace Text Generation Inference is an advanced, highly-performant option for serving open-source models to multiple people. To get started, follow the [Quick Tour](https://huggingface.co/docs/text-generation-inference/quicktour) on their website to set up the Docker container. Make sure to enter the server URL below that corresponds to the host and port you set up for the Docker container.",      icon: "hf.png",      tags: [ModelProviderTag.Local, ModelProviderTag["Open-Source"]], -    packages: [llama2FamilyPackage], +    packages: osModels,      collectInputFor: [ -      contextLengthInput, +      ...completionParamsInputs,        { ...serverUrlInput, defaultValue: "http://localhost:8080" },      ],    }, @@ -451,11 +689,11 @@ After it's up and running, you can start using Continue.`,          ...serverUrlInput,          defaultValue: "http://localhost:8000",        }, -      contextLengthInput, +      ...completionParamsInputs,      ],      icon: "openai.svg",      tags: [ModelProviderTag.Local, ModelProviderTag["Open-Source"]], -    packages: [llama2FamilyPackage], +    packages: osModels,    },    freetrial: {      title: "GPT-4 limited free trial", @@ -467,5 +705,6 @@ After it's up and running, you can start using Continue.`,      icon: "openai.svg",      tags: [ModelProviderTag.Free],      packages: [gpt4, gpt35turbo], +    collectInputFor: [...completionParamsInputs],    },  }; | 
