summaryrefslogtreecommitdiff
path: root/extension/react-app/src/util/modelData.ts
diff options
context:
space:
mode:
Diffstat (limited to 'extension/react-app/src/util/modelData.ts')
-rw-r--r--extension/react-app/src/util/modelData.ts451
1 files changed, 345 insertions, 106 deletions
diff --git a/extension/react-app/src/util/modelData.ts b/extension/react-app/src/util/modelData.ts
index 91259446..035e4af2 100644
--- a/extension/react-app/src/util/modelData.ts
+++ b/extension/react-app/src/util/modelData.ts
@@ -1,3 +1,17 @@
+import _ from "lodash";
+
+function updatedObj(old: any, pathToValue: { [key: string]: any }) {
+ const newObject = _.cloneDeep(old);
+ for (const key in pathToValue) {
+ if (typeof pathToValue[key] === "function") {
+ _.updateWith(newObject, key, pathToValue[key]);
+ } else {
+ _.updateWith(newObject, key, (__) => pathToValue[key]);
+ }
+ }
+ return newObject;
+}
+
export enum ModelProviderTag {
"Requires API Key" = "Requires API Key",
"Local" = "Local",
@@ -14,6 +28,7 @@ MODEL_PROVIDER_TAG_COLORS[ModelProviderTag["Free"]] = "#ffff00";
export enum CollectInputType {
"text" = "text",
"number" = "number",
+ "range" = "range",
}
export interface InputDescriptor {
@@ -38,6 +53,64 @@ const contextLengthInput: InputDescriptor = {
defaultValue: 2048,
required: false,
};
+const temperatureInput: InputDescriptor = {
+ inputType: CollectInputType.number,
+ key: "temperature",
+ label: "Temperature",
+ defaultValue: undefined,
+ required: false,
+ min: 0.0,
+ max: 1.0,
+ step: 0.01,
+};
+const topPInput: InputDescriptor = {
+ inputType: CollectInputType.number,
+ key: "top_p",
+ label: "Top-P",
+ defaultValue: undefined,
+ required: false,
+ min: 0,
+ max: 1,
+ step: 0.01,
+};
+const topKInput: InputDescriptor = {
+ inputType: CollectInputType.number,
+ key: "top_k",
+ label: "Top-K",
+ defaultValue: undefined,
+ required: false,
+ min: 0,
+ max: 1,
+ step: 0.01,
+};
+const presencePenaltyInput: InputDescriptor = {
+ inputType: CollectInputType.number,
+ key: "presence_penalty",
+ label: "Presence Penalty",
+ defaultValue: undefined,
+ required: false,
+ min: 0,
+ max: 1,
+ step: 0.01,
+};
+const FrequencyPenaltyInput: InputDescriptor = {
+ inputType: CollectInputType.number,
+ key: "frequency_penalty",
+ label: "Frequency Penalty",
+ defaultValue: undefined,
+ required: false,
+ min: 0,
+ max: 1,
+ step: 0.01,
+};
+const completionParamsInputs = [
+ contextLengthInput,
+ temperatureInput,
+ topKInput,
+ topPInput,
+ presencePenaltyInput,
+ FrequencyPenaltyInput,
+];
const serverUrlInput = {
inputType: CollectInputType.text,
@@ -59,6 +132,14 @@ export interface ModelInfo {
collectInputFor?: InputDescriptor[];
}
+// A dimension is like parameter count - 7b, 13b, 34b, etc.
+// You would set options to the field that should be changed for that option in the params field of ModelPackage
+export interface PackageDimension {
+ name: string;
+ description: string;
+ options: { [key: string]: { [key: string]: any } };
+}
+
export interface ModelPackage {
collectInputFor?: InputDescriptor[];
description: string;
@@ -75,100 +156,189 @@ export interface ModelPackage {
replace?: [string, string][];
[key: string]: any;
};
+ dimensions?: PackageDimension[];
}
-const codeLlama7bInstruct: ModelPackage = {
- title: "CodeLlama-7b-Instruct",
- description: "A 7b parameter model tuned for code generation",
+enum ChatTemplates {
+ "alpaca" = "template_alpaca_messages",
+ "llama2" = "llama2_template_messages",
+ "sqlcoder" = "sqlcoder_template_messages",
+}
+
+const codeLlamaInstruct: ModelPackage = {
+ title: "CodeLlama Instruct",
+ description:
+ "A model from Meta, fine-tuned for code generation and conversation",
refUrl: "",
params: {
title: "CodeLlama-7b-Instruct",
model: "codellama:7b-instruct",
context_length: 2048,
- template_messages: "llama2_template_messages",
+ template_messages: ChatTemplates.llama2,
},
icon: "meta.svg",
+ dimensions: [
+ {
+ name: "Parameter Count",
+ description: "The number of parameters in the model",
+ options: {
+ "7b": {
+ model: "codellama:7b-instruct",
+ title: "CodeLlama-7b-Instruct",
+ },
+ "13b": {
+ model: "codellama:13b-instruct",
+ title: "CodeLlama-13b-Instruct",
+ },
+ "34b": {
+ model: "codellama:34b-instruct",
+ title: "CodeLlama-34b-Instruct",
+ },
+ },
+ },
+ ],
};
-const codeLlama13bInstruct: ModelPackage = {
- title: "CodeLlama-13b-Instruct",
- description: "A 13b parameter model tuned for code generation",
+
+const llama2Chat: ModelPackage = {
+ title: "Llama2 Chat",
+ description: "The latest Llama model from Meta, fine-tuned for chat",
refUrl: "",
params: {
- title: "CodeLlama13b-Instruct",
- model: "codellama13b-instruct",
+ title: "Llama2-7b-Chat",
+ model: "llama2:7b-chat",
context_length: 2048,
- template_messages: "llama2_template_messages",
+ template_messages: ChatTemplates.llama2,
},
icon: "meta.svg",
+ dimensions: [
+ {
+ name: "Parameter Count",
+ description: "The number of parameters in the model",
+ options: {
+ "7b": {
+ model: "llama2:7b-chat",
+ title: "Llama2-7b-Chat",
+ },
+ "13b": {
+ model: "llama2:13b-chat",
+ title: "Llama2-13b-Chat",
+ },
+ "34b": {
+ model: "llama2:34b-chat",
+ title: "Llama2-34b-Chat",
+ },
+ },
+ },
+ ],
};
-const codeLlama34bInstruct: ModelPackage = {
- title: "CodeLlama-34b-Instruct",
- description: "A 34b parameter model tuned for code generation",
+
+const wizardCoder: ModelPackage = {
+ title: "WizardCoder",
+ description:
+ "A CodeLlama-based code generation model from WizardLM, focused on Python",
refUrl: "",
params: {
- title: "CodeLlama-34b-Instruct",
- model: "codellama:34b-instruct",
+ title: "WizardCoder-7b-Python",
+ model: "wizardcoder:7b-python",
context_length: 2048,
- template_messages: "llama2_template_messages",
+ template_messages: ChatTemplates.alpaca,
},
- icon: "meta.svg",
+ icon: "wizardlm.png",
+ dimensions: [
+ {
+ name: "Parameter Count",
+ description: "The number of parameters in the model",
+ options: {
+ "7b": {
+ model: "wizardcoder:7b-python",
+ title: "WizardCoder-7b-Python",
+ },
+ "13b": {
+ model: "wizardcoder:13b-python",
+ title: "WizardCoder-13b-Python",
+ },
+ "34b": {
+ model: "wizardcoder:34b-python",
+ title: "WizardCoder-34b-Python",
+ },
+ },
+ },
+ ],
};
-const llama2Chat7b: ModelPackage = {
- title: "Llama2-7b-Chat",
- description: "A 7b parameter model fine-tuned for chat",
- refUrl: "",
+const phindCodeLlama: ModelPackage = {
+ title: "Phind CodeLlama (34b)",
+ description: "A finetune of CodeLlama by Phind",
params: {
- title: "Llama2-7b-Chat",
- model: "llama2:7b-chat",
+ title: "Phind CodeLlama",
+ model: "phind-codellama",
context_length: 2048,
- template_messages: "llama2_template_messages",
+ template_messages: ChatTemplates.llama2,
},
- icon: "meta.svg",
};
-const llama2Chat13b: ModelPackage = {
- title: "Llama2-13b-Chat",
- description: "A 13b parameter model fine-tuned for chat",
- refUrl: "",
+
+const mistral: ModelPackage = {
+ title: "Mistral (7b)",
+ description:
+ "A 7b parameter base model created by Mistral AI, very competent for code generation and other tasks",
params: {
- title: "Llama2-13b-Chat",
- model: "llama2:13b-chat",
+ title: "Mistral",
+ model: "mistral",
context_length: 2048,
- template_messages: "llama2_template_messages",
+ template_messages: ChatTemplates.llama2,
},
- icon: "meta.svg",
+ icon: "mistral.png",
};
-const llama2Chat34b: ModelPackage = {
- title: "Llama2-34b-Chat",
- description: "A 34b parameter model fine-tuned for chat",
- refUrl: "",
+
+const sqlCoder: ModelPackage = {
+ title: "SQLCoder",
+ description:
+ "A finetune of StarCoder by Defog.ai, focused specifically on SQL",
params: {
- title: "Llama2-34b-Chat",
- model: "llama2:34b-chat",
+ title: "SQLCoder",
+ model: "sqlcoder",
context_length: 2048,
- template_messages: "llama2_template_messages",
+ template_messages: ChatTemplates.sqlcoder,
},
- icon: "meta.svg",
+ dimensions: [
+ {
+ name: "Parameter Count",
+ description: "The number of parameters in the model",
+ options: {
+ "7b": {
+ model: "sqlcoder:7b",
+ title: "SQLCoder-7b",
+ },
+ "13b": {
+ model: "sqlcoder:15b",
+ title: "SQLCoder-15b",
+ },
+ },
+ },
+ ],
};
-const codeLlamaPackages = [
- codeLlama7bInstruct,
- codeLlama13bInstruct,
- codeLlama34bInstruct,
-];
-
-const llama2Packages = [llama2Chat7b, llama2Chat13b, llama2Chat34b];
-const llama2FamilyPackage = {
- title: "Llama2 or CodeLlama",
- description: "Any model using the Llama2 or CodeLlama chat template",
+const codeup: ModelPackage = {
+ title: "CodeUp (13b)",
+ description: "An open-source coding model based on Llama2",
params: {
- model: "llama2",
+ title: "CodeUp",
+ model: "codeup",
context_length: 2048,
- template_messages: "llama2_template_messages",
+ template_messages: ChatTemplates.llama2,
},
- icon: "meta.svg",
};
+const osModels = [
+ codeLlamaInstruct,
+ llama2Chat,
+ wizardCoder,
+ phindCodeLlama,
+ sqlCoder,
+ mistral,
+ codeup,
+];
+
const gpt4: ModelPackage = {
title: "GPT-4",
description: "The latest model from OpenAI",
@@ -192,6 +362,23 @@ const gpt35turbo: ModelPackage = {
},
};
+const OLLAMA_TO_REPLICATE_MODEL_NAMES: { [key: string]: string } = {
+ "codellama:7b-instruct":
+ "meta/codellama-7b-instruct:6527b83e01e41412db37de5110a8670e3701ee95872697481a355e05ce12af0e",
+ "codellama:13b-instruct":
+ "meta/codellama-13b-instruct:1f01a52ff933873dff339d5fb5e1fd6f24f77456836f514fa05e91c1a42699c7",
+ "codellama:34b-instruct":
+ "meta/codellama-34b-instruct:8281a5c610f6e88237ff3ddaf3c33b56f60809e2bdd19fbec2fda742aa18167e",
+ "llama2:7b-chat":
+ "meta/llama-2-7b-chat:8e6975e5ed6174911a6ff3d60540dfd4844201974602551e10e9e87ab143d81e",
+ "llama2:13b-chat":
+ "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
+};
+
+function replicateConvertModelName(model: string): string {
+ return OLLAMA_TO_REPLICATE_MODEL_NAMES[model] || model;
+}
+
export const MODEL_INFO: { [key: string]: ModelInfo } = {
openai: {
title: "OpenAI",
@@ -210,6 +397,7 @@ export const MODEL_INFO: { [key: string]: ModelInfo } = {
placeholder: "Enter your OpenAI API key",
required: true,
},
+ ...completionParamsInputs,
],
},
anthropic: {
@@ -229,6 +417,7 @@ export const MODEL_INFO: { [key: string]: ModelInfo } = {
placeholder: "Enter your Anthropic API key",
required: true,
},
+ ...completionParamsInputs,
],
packages: [
{
@@ -251,17 +440,8 @@ export const MODEL_INFO: { [key: string]: ModelInfo } = {
'To get started with Ollama, follow these steps:\n1. Download from [ollama.ai](https://ollama.ai/) and open the application\n2. Open a terminal and run `ollama pull <MODEL_NAME>`. Example model names are `codellama:7b-instruct` or `llama2:7b-text`. You can find the full list [here](https://ollama.ai/library).\n3. Make sure that the model name used in step 2 is the same as the one in config.py (e.g. `model="codellama:7b-instruct"`)\n4. Once the model has finished downloading, you can start asking questions through Continue.',
icon: "ollama.png",
tags: [ModelProviderTag["Local"], ModelProviderTag["Open-Source"]],
- packages: [
- ...codeLlamaPackages.map((p) => ({
- ...p,
- refUrl: "https://ollama.ai/library/codellama",
- })),
- ...llama2Packages.map((p) => ({
- ...p,
- refUrl: "https://ollama.ai/library/llama2",
- })),
- ],
- collectInputFor: [contextLengthInput],
+ packages: osModels,
+ collectInputFor: [...completionParamsInputs],
},
together: {
title: "TogetherAI",
@@ -285,32 +465,51 @@ export const MODEL_INFO: { [key: string]: ModelInfo } = {
placeholder: "Enter your TogetherAI API key",
required: true,
},
+ ...completionParamsInputs,
],
packages: [
- ...codeLlamaPackages.map((p) => {
- return {
- ...p,
- params: {
- ...p.params,
- model:
- "togethercomputer/" +
- p.params.model.replace("llama2", "llama-2").replace(":", "-"),
- },
- };
+ updatedObj(llama2Chat, {
+ "dimensions[0].options": (options: any) =>
+ _.mapValues(options, (option) => {
+ return _.assign({}, option, {
+ model:
+ "togethercomputer/" +
+ option.model.replace("llama2", "llama-2").replace(":", "-"),
+ });
+ }),
+ }),
+ updatedObj(codeLlamaInstruct, {
+ "dimensions[0].options": (options: any) =>
+ _.mapValues(options, (option) => {
+ return _.assign({}, option, {
+ model:
+ "togethercomputer/" +
+ option.model
+ .replace("codellama", "CodeLlama")
+ .replace(":", "-")
+ .replace("instruct", "Instruct"),
+ });
+ }),
}),
- ...llama2Packages.map((p) => {
- return {
- ...p,
- params: {
- ...p.params,
- model:
- "togethercomputer/" +
- p.params.model
- .replace("codellama", "CodeLlama")
- .replace(":", "-")
- .replace("instruct", "Instruct"),
+ updatedObj(wizardCoder, {
+ "params.model": "WizardLM/WizardCoder-15B-V1.0",
+ "params.title": "WizardCoder-15b",
+ "dimensions[0].options": {
+ "15b": {
+ model: "WizardLM/WizardCoder-15B-V1.0",
+ title: "WizardCoder-15b",
},
- };
+ "34b (Python)": {
+ model: "WizardLM/WizardCoder-Python-34B-V1.0",
+ title: "WizardCoder-34b-Python",
+ },
+ },
+ }),
+ updatedObj(phindCodeLlama, {
+ "params.model": "Phind/Phind-CodeLlama-34B-Python-v1",
+ }),
+ updatedObj(mistral, {
+ "params.model": "mistralai/Mistral-7B-Instruct-v0.1",
}),
].map((p) => {
p.params.context_length = 4096;
@@ -329,8 +528,8 @@ export const MODEL_INFO: { [key: string]: ModelInfo } = {
params: {
server_url: "http://localhost:1234",
},
- packages: [llama2FamilyPackage],
- collectInputFor: [contextLengthInput],
+ packages: osModels,
+ collectInputFor: [...completionParamsInputs],
},
replicate: {
title: "Replicate",
@@ -348,23 +547,62 @@ export const MODEL_INFO: { [key: string]: ModelInfo } = {
placeholder: "Enter your Replicate API key",
required: true,
},
+ ...completionParamsInputs,
],
icon: "replicate.png",
tags: [
ModelProviderTag["Requires API Key"],
ModelProviderTag["Open-Source"],
],
- packages: [...codeLlamaPackages, ...llama2Packages].map((p) => {
- return {
- ...p,
- params: {
- ...p.params,
- model:
- "meta/" +
- p.params.model.replace(":", "-").replace("llama2", "llama-2"),
- },
- };
- }),
+ packages: [
+ ...[codeLlamaInstruct, llama2Chat]
+ .map((p: ModelPackage) => {
+ if (p.title === "Llama2 Chat") {
+ return updatedObj(p, {
+ "dimensions[0].options.34b": undefined,
+ "dimensions[0].options.70b": {
+ model:
+ "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3",
+ title: "Llama2-70b-Chat",
+ },
+ });
+ }
+ return p;
+ })
+ .map((p) => {
+ return updatedObj(p, {
+ "params.model": (model: string) => {
+ return replicateConvertModelName(model);
+ },
+ "dimensions[0].options": (options: any) => {
+ const newOptions: any = {};
+ for (const key in options) {
+ newOptions[key] = {
+ ...options[key],
+ model: replicateConvertModelName(options[key]?.model),
+ };
+ }
+ return newOptions;
+ },
+ });
+ }),
+ updatedObj(wizardCoder, {
+ title: "WizardCoder (15b)",
+ "params.model":
+ "andreasjansson/wizardcoder-python-34b-v1-gguf:67eed332a5389263b8ede41be3ee7dc119fa984e2bde287814c4abed19a45e54",
+ dimensions: undefined,
+ }),
+ updatedObj(sqlCoder, {
+ dimensions: undefined,
+ title: "SQLCoder (15b)",
+ "params.model":
+ "gregwdata/defog-sqlcoder-q8:0a9abc0d143072fd5d8920ad90b8fbaafaf16b10ffdad24bd897b5bffacfce0b",
+ }),
+ updatedObj(mistral, {
+ "params.model":
+ "a16z-infra/mistral-7b-instruct-v0.1:83b6a56e7c828e667f21fd596c338fd4f0039b46bcfa18d973e8e70e455fda70",
+ }),
+ ],
},
llamacpp: {
title: "llama.cpp",
@@ -384,8 +622,8 @@ export const MODEL_INFO: { [key: string]: ModelInfo } = {
After it's up and running, you can start using Continue.`,
icon: "llamacpp.png",
tags: [ModelProviderTag.Local, ModelProviderTag["Open-Source"]],
- packages: [llama2FamilyPackage],
- collectInputFor: [contextLengthInput],
+ packages: osModels,
+ collectInputFor: [...completionParamsInputs],
},
palm: {
title: "Google PaLM API",
@@ -426,9 +664,9 @@ After it's up and running, you can start using Continue.`,
"HuggingFace Text Generation Inference is an advanced, highly-performant option for serving open-source models to multiple people. To get started, follow the [Quick Tour](https://huggingface.co/docs/text-generation-inference/quicktour) on their website to set up the Docker container. Make sure to enter the server URL below that corresponds to the host and port you set up for the Docker container.",
icon: "hf.png",
tags: [ModelProviderTag.Local, ModelProviderTag["Open-Source"]],
- packages: [llama2FamilyPackage],
+ packages: osModels,
collectInputFor: [
- contextLengthInput,
+ ...completionParamsInputs,
{ ...serverUrlInput, defaultValue: "http://localhost:8080" },
],
},
@@ -451,11 +689,11 @@ After it's up and running, you can start using Continue.`,
...serverUrlInput,
defaultValue: "http://localhost:8000",
},
- contextLengthInput,
+ ...completionParamsInputs,
],
icon: "openai.svg",
tags: [ModelProviderTag.Local, ModelProviderTag["Open-Source"]],
- packages: [llama2FamilyPackage],
+ packages: osModels,
},
freetrial: {
title: "GPT-4 limited free trial",
@@ -467,5 +705,6 @@ After it's up and running, you can start using Continue.`,
icon: "openai.svg",
tags: [ModelProviderTag.Free],
packages: [gpt4, gpt35turbo],
+ collectInputFor: [...completionParamsInputs],
},
};