From 2f792f46026a6bb3c3580f2521b01ecb8c68117c Mon Sep 17 00:00:00 2001
From: Nate Sesti <sestinj@gmail.com>
Date: Fri, 1 Sep 2023 18:31:33 -0700
Subject: feat: :sparkles: improved model dropdown

---
 extension/react-app/src/components/ModelSelect.tsx | 76 +++++++++++++++++-----
 .../src/hooks/AbstractContinueGUIClientProtocol.ts |  2 +-
 .../src/hooks/ContinueGUIClientProtocol.ts         |  8 ++-
 3 files changed, 68 insertions(+), 18 deletions(-)

(limited to 'extension/react-app/src')

diff --git a/extension/react-app/src/components/ModelSelect.tsx b/extension/react-app/src/components/ModelSelect.tsx
index ea979da7..1cbf3f0e 100644
--- a/extension/react-app/src/components/ModelSelect.tsx
+++ b/extension/react-app/src/components/ModelSelect.tsx
@@ -73,12 +73,11 @@ const MODEL_INFO: { title: string; class: string; args: any }[] = [
 
 const Select = styled.select`
   border: none;
-  width: fit-content;
+  width: 25vw;
   background-color: ${secondaryDark};
   color: ${vscForeground};
   border-radius: ${defaultBorderRadius};
   padding: 6px;
-  /* box-shadow: 0px 0px 1px 0px ${vscForeground}; */
   max-height: 35vh;
   overflow: scroll;
   cursor: pointer;
@@ -89,34 +88,81 @@ const Select = styled.select`
   }
 `;
 
+function modelSelectTitle(model: any): string {
+  if (model.title) return model.title;
+  if (model.model !== undefined && model.model.trim() !== "") {
+    if (model.class_name) {
+      return `${model.class_name} - ${model.model}`;
+    }
+    return model.model;
+  }
+  return model.class_name;
+}
+
 function ModelSelect(props: {}) {
   const client = useContext(GUIClientContext);
   const defaultModel = useSelector(
     (state: RootStore) => (state.serverState.config as any)?.models?.default
   );
+  const unusedModels = useSelector(
+    (state: RootStore) => (state.serverState.config as any)?.models?.unused
+  );
 
   return (
     <Select
+      value={JSON.stringify({
+        t: "default",
+        idx: -1,
+      })}
       defaultValue={0}
       onChange={(e) => {
-        const model = MODEL_INFO[parseInt(e.target.value)];
-        client?.setModelForRole("*", model.class, model.args);
+        const value = JSON.parse(e.target.value);
+        if (value.t === "unused") {
+          client?.setModelForRoleFromIndex("*", value.idx);
+        } else if (value.t === "new") {
+          const model = MODEL_INFO[value.idx];
+          client?.addModelForRole("*", model.class, model.args);
+        }
       }}
     >
-      {MODEL_INFO.map((model, idx) => {
-        return (
+      <optgroup label="My Saved Models">
+        {defaultModel && (
           <option
-            selected={
-              defaultModel?.class_name === model.class &&
-              (!defaultModel?.model?.startsWith("gpt") ||
-                defaultModel?.model === model.args.model)
-            }
-            value={idx}
+            value={JSON.stringify({
+              t: "default",
+              idx: -1,
+            })}
           >
-            {model.title}
+            {modelSelectTitle(defaultModel)}
           </option>
-        );
-      })}
+        )}
+        {unusedModels?.map((model: any, idx: number) => {
+          return (
+            <option
+              value={JSON.stringify({
+                t: "unused",
+                idx,
+              })}
+            >
+              {modelSelectTitle(model)}
+            </option>
+          );
+        })}
+      </optgroup>
+      <optgroup label="Add New Model">
+        {MODEL_INFO.map((model, idx) => {
+          return (
+            <option
+              value={JSON.stringify({
+                t: "new",
+                idx,
+              })}
+            >
+              {model.title}
+            </option>
+          );
+        })}
+      </optgroup>
     </Select>
   );
 }
diff --git a/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts
index f8c11527..9944f221 100644
--- a/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts
+++ b/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts
@@ -41,7 +41,7 @@ abstract class AbstractContinueGUIClientProtocol {
 
   abstract setTemperature(temperature: number): void;
 
-  abstract setModelForRole(
+  abstract addModelForRole(
     role: string,
     model_class: string,
     model: string
diff --git a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts
index ce9b2a0a..fe1b654b 100644
--- a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts
+++ b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts
@@ -141,8 +141,12 @@ class ContinueGUIClientProtocol extends AbstractContinueGUIClientProtocol {
     this.messenger?.send("set_temperature", { temperature });
   }
 
-  setModelForRole(role: string, model_class: string, model: any): void {
-    this.messenger?.send("set_model_for_role", { role, model, model_class });
+  addModelForRole(role: string, model_class: string, model: any): void {
+    this.messenger?.send("add_model_for_role", { role, model, model_class });
+  }
+
+  setModelForRoleFromIndex(role: string, index: number): void {
+    this.messenger?.send("set_model_for_role_from_index", { role, index });
   }
 
   saveContextGroup(title: string, contextItems: ContextItem[]): void {
-- 
cgit v1.2.3-70-g09d2