summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--continuedev/src/continuedev/core/autopilot.py3
-rw-r--r--continuedev/src/continuedev/core/context.py2
-rw-r--r--continuedev/src/continuedev/core/models.py34
-rw-r--r--continuedev/src/continuedev/libs/llm/ggml.py117
-rw-r--r--continuedev/src/continuedev/libs/llm/hf_inference_api.py4
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py4
-rw-r--r--continuedev/src/continuedev/libs/llm/together.py4
-rw-r--r--continuedev/src/continuedev/libs/util/commonregex.py2
-rw-r--r--continuedev/src/continuedev/libs/util/edit_config.py43
-rw-r--r--continuedev/src/continuedev/plugins/steps/chat.py3
-rw-r--r--continuedev/src/continuedev/plugins/steps/setup_model.py24
-rw-r--r--continuedev/src/continuedev/server/gui.py83
-rw-r--r--extension/react-app/src/components/Layout.tsx48
-rw-r--r--extension/react-app/src/components/ModelSelect.tsx113
14 files changed, 395 insertions, 89 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 4e7a7cc7..5d8e7f4e 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -507,7 +507,8 @@ class Autopilot(ContinueBaseModel):
async def create_title():
title = await self.continue_sdk.models.medium.complete(
- f'Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: "{user_input}". The title is: '
+ f'Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: "{user_input}". Do not use more than 10 words. The title is: ',
+ max_tokens=20,
)
title = remove_quotes_and_escapes(title)
self.session_info = SessionInfo(
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index d51a32e2..ffe22d63 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -257,6 +257,8 @@ class ContextManager:
await asyncio.wait_for(add_docs(), timeout=5)
except asyncio.TimeoutError:
logger.warning("Failed to add document to meilisearch in 5 seconds")
+ except Exception as e:
+ logger.warning(f"Error adding document to meilisearch: {e}")
@staticmethod
async def delete_documents(ids):
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index e4610d36..a8c97622 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -1,8 +1,15 @@
-from typing import Optional
+from typing import List, Optional
from pydantic import BaseModel
from ..libs.llm import LLM
+from ..libs.llm.anthropic import AnthropicLLM
+from ..libs.llm.ggml import GGML
+from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from ..libs.llm.ollama import Ollama
+from ..libs.llm.openai import OpenAI
+from ..libs.llm.replicate import ReplicateLLM
+from ..libs.llm.together import TogetherLLM
class ContinueSDK(BaseModel):
@@ -18,6 +25,29 @@ ALL_MODEL_ROLES = [
"chat",
]
+MODEL_CLASSES = {
+ cls.__name__: cls
+ for cls in [
+ OpenAI,
+ MaybeProxyOpenAI,
+ GGML,
+ TogetherLLM,
+ AnthropicLLM,
+ ReplicateLLM,
+ Ollama,
+ ]
+}
+
+MODEL_MODULE_NAMES = {
+ "OpenAI": "openai",
+ "MaybeProxyOpenAI": "maybe_proxy_openai",
+ "GGML": "ggml",
+ "TogetherLLM": "together",
+ "AnthropicLLM": "anthropic",
+ "ReplicateLLM": "replicate",
+ "Ollama": "ollama",
+}
+
class Models(BaseModel):
"""Main class that holds the current model configuration"""
@@ -29,6 +59,8 @@ class Models(BaseModel):
edit: Optional[LLM] = None
chat: Optional[LLM] = None
+ unused: List[LLM] = []
+
# TODO namespace these away to not confuse readers,
# or split Models into ModelsConfig, which gets turned into Models
sdk: ContinueSDK = None
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index f2c53e7b..b4548ff2 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -1,28 +1,33 @@
import json
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
+from ..util.count_tokens import (
+ DEFAULT_ARGS,
+ compile_chat_messages,
+ count_tokens,
+ format_chat_messages,
+)
class GGML(LLM):
# this is model-specific
max_context_length: int = 2048
server_url: str = "http://localhost:8000"
- verify_ssl: bool = True
+ verify_ssl: Optional[bool] = None
- _client_session: aiohttp.ClientSession = None
+ requires_write_log = True
+
+ write_log: Optional[Callable[[str], None]] = None
class Config:
arbitrary_types_allowed = True
- async def start(self, **kwargs):
- self._client_session = aiohttp.ClientSession(
- connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
- )
+ async def start(self, write_log: Callable[[str], None], **kwargs):
+ self.write_log = write_log
async def stop(self):
await self._client_session.close()
@@ -60,15 +65,24 @@ class GGML(LLM):
system_message=self.system_message,
)
- async with self._client_session.post(
- f"{self.server_url}/v1/completions", json={"messages": messages, **args}
- ) as resp:
- async for line in resp.content.iter_any():
- if line:
- try:
- yield line.decode("utf-8")
- except:
- raise Exception(str(line))
+ self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+ completion = ""
+ async with aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+ ) as client_session:
+ async with client_session.post(
+ f"{self.server_url}/v1/completions", json={"messages": messages, **args}
+ ) as resp:
+ async for line in resp.content.iter_any():
+ if line:
+ try:
+ chunk = line.decode("utf-8")
+ yield chunk
+ completion += chunk
+ except:
+ raise Exception(str(line))
+
+ self.write_log(f"Completion: \n\n{completion}")
async def stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
@@ -86,31 +100,42 @@ class GGML(LLM):
args["stream"] = True
async def generator():
- async with self._client_session.post(
- f"{self.server_url}/v1/chat/completions",
- json={"messages": messages, **args},
- ) as resp:
- # This is streaming application/json instaed of text/event-stream
- async for line, end in resp.content.iter_chunks():
- json_chunk = line.decode("utf-8")
- if json_chunk.startswith(": ping - ") or json_chunk.startswith(
- "data: [DONE]"
- ):
- continue
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- yield json.loads(chunk[6:])["choices"][0][
- "delta"
- ] # {"role": "assistant", "content": "..."}
+ async with aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+ ) as client_session:
+ async with client_session.post(
+ f"{self.server_url}/v1/chat/completions",
+ json={"messages": messages, **args},
+ ) as resp:
+ # This is streaming application/json instaed of text/event-stream
+ async for line, end in resp.content.iter_chunks():
+ json_chunk = line.decode("utf-8")
+ if json_chunk.startswith(": ping - ") or json_chunk.startswith(
+ "data: [DONE]"
+ ):
+ continue
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ yield json.loads(chunk[6:])["choices"][0][
+ "delta"
+ ] # {"role": "assistant", "content": "..."}
# Because quite often the first attempt fails, and it works thereafter
+ self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+ completion = ""
try:
async for chunk in generator():
yield chunk
+ if "content" in chunk:
+ completion += chunk["content"]
except:
async for chunk in generator():
yield chunk
+ if "content" in chunk:
+ completion += chunk["content"]
+
+ self.write_log(f"Completion: \n\n{completion}")
async def complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
@@ -127,12 +152,18 @@ class GGML(LLM):
# system_message=self.system_message,
# )
- async with self._client_session.post(
- f"{self.server_url}/v1/completions",
- json={
- "prompt": prompt,
- **args,
- },
- ) as resp:
- text = await resp.text()
- return json.loads(text)["choices"][0]["text"]
+ self.write_log(f"Prompt: \n\n{prompt}")
+ async with aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+ ) as client_session:
+ async with client_session.post(
+ f"{self.server_url}/v1/completions",
+ json={
+ "prompt": prompt,
+ **args,
+ },
+ ) as resp:
+ text = await resp.text()
+ completion = json.loads(text)["choices"][0]["text"]
+ self.write_log(f"Completion: \n\n{completion}")
+ return completion
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 76331a28..7abd268d 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,4 +1,4 @@
-from typing import Any, Coroutine, Dict, Generator, List
+from typing import Any, Coroutine, Dict, Generator, List, Optional
import aiohttp
import requests
@@ -15,7 +15,7 @@ class HuggingFaceInferenceAPI(LLM):
hf_token: str
max_context_length: int = 2048
- verify_ssl: bool = True
+ verify_ssl: Optional[bool] = None
_client_session: aiohttp.ClientSession = None
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 276cc290..48e773a3 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -49,7 +49,7 @@ class OpenAI(LLM):
api_key: str
model: str
openai_server_info: Optional[OpenAIServerInfo] = None
- verify_ssl: bool = True
+ verify_ssl: Optional[bool] = None
ca_bundle_path: Optional[str] = None
requires_write_log = True
@@ -73,7 +73,7 @@ class OpenAI(LLM):
if self.openai_server_info.api_version is not None:
openai.api_version = self.openai_server_info.api_version
- if self.verify_ssl is False:
+ if self.verify_ssl is not None and self.verify_ssl is False:
openai.verify_ssl_certs = False
openai.ca_bundle_path = self.ca_bundle_path or certifi.where()
diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py
index 44f5030c..9a28de2d 100644
--- a/continuedev/src/continuedev/libs/llm/together.py
+++ b/continuedev/src/continuedev/libs/llm/together.py
@@ -1,5 +1,5 @@
import json
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
import aiohttp
@@ -14,7 +14,7 @@ class TogetherLLM(LLM):
model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct"
max_context_length: int = 2048
base_url: str = "https://api.together.xyz"
- verify_ssl: bool = True
+ verify_ssl: Optional[bool] = None
_client_session: aiohttp.ClientSession = None
diff --git a/continuedev/src/continuedev/libs/util/commonregex.py b/continuedev/src/continuedev/libs/util/commonregex.py
index 3c4fb38c..9f119122 100644
--- a/continuedev/src/continuedev/libs/util/commonregex.py
+++ b/continuedev/src/continuedev/libs/util/commonregex.py
@@ -57,7 +57,6 @@ regexes = {
"times": time,
"phones": phone,
"phones_with_exts": phones_with_exts,
- "links": link,
"emails": email,
"ips": ip,
"ipv6s": ipv6,
@@ -78,7 +77,6 @@ placeholders = {
"times": "<TIME>",
"phones": "<PHONE>",
"phones_with_exts": "<PHONE_WITH_EXT>",
- "links": "<LINK>",
"emails": "<EMAIL>",
"ips": "<IP>",
"ipv6s": "<IPV6>",
diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py
index 17ce27ec..7bdffc8e 100644
--- a/continuedev/src/continuedev/libs/util/edit_config.py
+++ b/continuedev/src/continuedev/libs/util/edit_config.py
@@ -1,5 +1,5 @@
import threading
-from typing import Dict, List
+from typing import Any, Dict, List
import redbaron
@@ -53,6 +53,47 @@ def edit_config_property(key_path: List[str], value: redbaron.RedBaron):
file.write(red.dumps())
+def add_config_import(line: str):
+ with edit_lock:
+ red = load_red()
+ # check if the import already exists
+ for node in red:
+ if node.type == "import" and node.dumps() == line:
+ return
+ # if it doesn't exist, add it
+ red.insert(1, line)
+
+ with open(getConfigFilePath(), "w") as file:
+ file.write(red.dumps())
+
+
+filtered_attrs = {
+ "requires_api_key",
+ "requires_unique_id",
+ "requires_write_log",
+ "class_name",
+ "name",
+ "llm",
+}
+
+
+def display_val(v: Any):
+ if isinstance(v, str):
+ return f'"{v}"'
+ return str(v)
+
+
+def display_llm_class(llm):
+ args = ", ".join(
+ [
+ f"{k}={display_val(v)}"
+ for k, v in llm.dict().items()
+ if k not in filtered_attrs and v is not None
+ ]
+ )
+ return f"{llm.__class__.__name__}({args})"
+
+
def create_obj_node(class_name: str, args: Dict[str, str]) -> redbaron.RedBaron:
args = [f"{key}={value}" for key, value in args.items()]
return redbaron.RedBaron(f"{class_name}({', '.join(args)})")[0]
diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py
index 63548698..ad09f193 100644
--- a/continuedev/src/continuedev/plugins/steps/chat.py
+++ b/continuedev/src/continuedev/plugins/steps/chat.py
@@ -95,7 +95,8 @@ class SimpleChatStep(Step):
self.name = remove_quotes_and_escapes(
await sdk.models.medium.complete(
- f"{self.description}\n\nHere is a short title for the above chat message:"
+ f"{self.description}\n\nHere is a short title for the above chat message (no more than 10 words):",
+ max_tokens=20,
)
)
diff --git a/continuedev/src/continuedev/plugins/steps/setup_model.py b/continuedev/src/continuedev/plugins/steps/setup_model.py
new file mode 100644
index 00000000..1c50c714
--- /dev/null
+++ b/continuedev/src/continuedev/plugins/steps/setup_model.py
@@ -0,0 +1,24 @@
+from ...core.main import Step
+from ...core.sdk import ContinueSDK
+from ...libs.util.paths import getConfigFilePath
+
+MODEL_CLASS_TO_MESSAGE = {
+ "OpenAI": "Obtain your OpenAI API key from [here](https://platform.openai.com/account/api-keys) and paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.",
+ "MaybeProxyOpenAI": "To get started with OpenAI models, obtain your OpenAI API key from [here](https://platform.openai.com/account/api-keys) and paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.",
+ "AnthropicLLM": "To get started with Anthropic, you first need to sign up for the beta [here](https://claude.ai/login) to obtain an API key. Once you have the key, paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.",
+ "ReplicateLLM": "To get started with Replicate, sign up to obtain an API key [here](https://replicate.ai/), then paste it into the `api_key` field at config.models.default.api_key in `config.py`.",
+ "Ollama": "To get started with Ollama, download the Mac app from [ollama.ai](https://ollama.ai/). Once it is downloaded, be sure to pull at least one model and use its name in the model field in config.py (e.g. `model='codellama'`).",
+ "GGML": "GGML models can be run locally using the `llama-cpp-python` library. To learn how to set up a local llama-cpp-python server, read [here](https://github.com/continuedev/ggml-server-example). Once it is started on port 8000, you're all set!",
+ "TogetherLLM": "To get started using models from Together, first obtain your Together API key from [here](https://together.ai). Paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then, on their models page, press 'start' on the model of your choice and make sure the `model=` parameter in the config file for the `TogetherLLM` class reflects the name of this model. Finally, reload the VS Code window for changes to take effect.",
+}
+
+
+class SetupModelStep(Step):
+ model_class: str
+ name: str = "Setup model in config.py"
+
+ async def run(self, sdk: ContinueSDK):
+ await sdk.ide.setFileOpen(getConfigFilePath())
+ self.description = MODEL_CLASS_TO_MESSAGE.get(
+ self.model_class, "Please finish setting up this model in `config.py`"
+ )
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index bdcaad47..55a5f3b4 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -9,16 +9,20 @@ from starlette.websockets import WebSocketDisconnect, WebSocketState
from uvicorn.main import Server
from ..core.main import ContextItem
+from ..core.models import ALL_MODEL_ROLES, MODEL_CLASSES, MODEL_MODULE_NAMES
from ..libs.util.create_async_task import create_async_task
from ..libs.util.edit_config import (
+ add_config_import,
create_float_node,
create_obj_node,
create_string_node,
+ display_llm_class,
)
from ..libs.util.logging import logger
from ..libs.util.queue import AsyncSubscriptionQueue
from ..libs.util.telemetry import posthog_logger
from ..plugins.steps.core.core import DisplayErrorStep
+from ..plugins.steps.setup_model import SetupModelStep
from .gui_protocol import AbstractGUIProtocolServer
from .session_manager import Session, session_manager
@@ -227,16 +231,75 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
)
def set_model_for_role(self, role: str, model_class: str, model: Any):
- prev_model = self.session.autopilot.continue_sdk.models.__getattr__(role)
- if prev_model is not None:
- prev_model.update(model)
- self.session.autopilot.continue_sdk.models.__setattr__(role, model)
- create_async_task(
- self.session.autopilot.set_config_attr(
- ["models", role], create_obj_node(model_class, {**model})
- ),
- self.on_error,
- )
+ models = self.session.autopilot.continue_sdk.config.models
+ unused_models = models.unused
+ if role == "*":
+
+ async def async_stuff():
+ # Clear all of the models and store them in unused_models
+ # NOTE: There will be duplicates
+ for role in ALL_MODEL_ROLES:
+ prev_model = models.__getattribute__(role)
+ models.__setattr__(role, None)
+ if prev_model is not None:
+ exists = False
+ for other in unused_models:
+ if display_llm_class(prev_model) == display_llm_class(
+ other
+ ):
+ exists = True
+ break
+ if not exists:
+ unused_models.append(prev_model)
+
+ # Replace default with either new one or existing from unused_models
+ for unused_model in unused_models:
+ if model_class == unused_model.__class__.__name__ and (
+ "model" not in model or model["model"] == unused_model.model
+ ):
+ models.default = unused_model
+
+ # Set and start the default model if didn't already exist from unused
+ if models.default is None:
+ models.default = MODEL_CLASSES[model_class](**model)
+ await self.session.autopilot.continue_sdk.start_model(
+ models.default
+ )
+ await self.session.autopilot.continue_sdk.run_step(
+ SetupModelStep(model_class=model_class)
+ )
+
+ models_args = {}
+
+ for role in ALL_MODEL_ROLES:
+ val = models.__getattribute__(role)
+ if val is None:
+ continue # no pun intended
+
+ models_args[role] = display_llm_class(val)
+
+ JOINER = ", "
+ models_args[
+ "unused"
+ ] = f"[{JOINER.join([display_llm_class(llm) for llm in unused_models])}]"
+
+ await self.session.autopilot.set_config_attr(
+ ["models"],
+ create_obj_node("Models", models_args),
+ )
+
+ add_config_import(
+ f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}"
+ )
+
+ for role in ALL_MODEL_ROLES:
+ if role != "default":
+ models.__setattr__(role, models.default)
+
+ create_async_task(async_stuff(), self.on_error)
+ else:
+ # TODO
+ pass
def save_context_group(self, title: str, context_items: List[ContextItem]):
create_async_task(
diff --git a/extension/react-app/src/components/Layout.tsx b/extension/react-app/src/components/Layout.tsx
index c0f0929b..897fd683 100644
--- a/extension/react-app/src/components/Layout.tsx
+++ b/extension/react-app/src/components/Layout.tsx
@@ -22,6 +22,7 @@ import {
} from "@heroicons/react/24/outline";
import HeaderButtonWithText from "./HeaderButtonWithText";
import { useNavigate } from "react-router-dom";
+import ModelSelect from "./ModelSelect";
// #region Styled Components
@@ -138,30 +139,29 @@ const Layout = () => {
{bottomMessage}
</BottomMessageDiv>
<Footer>
- <SparklesIcon
- visibility={
- localStorage.getItem("hideFeature") === "true"
- ? "hidden"
- : "visible"
- }
- className="mr-auto cursor-pointer"
- onClick={() => {
- localStorage.setItem("hideFeature", "true");
- }}
- onMouseEnter={() => {
- dispatch(
- setBottomMessage(
- "🎁 New Feature: Use ⌘⇧R automatically debug errors in the terminal"
- )
- );
- }}
- onMouseLeave={() => {
- dispatch(setBottomMessage(undefined));
- }}
- width="1.3em"
- height="1.3em"
- color="yellow"
- />
+ {localStorage.getItem("hideFeature") === "true" || (
+ <SparklesIcon
+ className="mr-auto cursor-pointer"
+ onClick={() => {
+ localStorage.setItem("hideFeature", "true");
+ }}
+ onMouseEnter={() => {
+ dispatch(
+ setBottomMessage(
+ "🎁 New Feature: Use ⌘⇧R automatically debug errors in the terminal (you can click the sparkle icon to make it go away)"
+ )
+ );
+ }}
+ onMouseLeave={() => {
+ dispatch(setBottomMessage(undefined));
+ }}
+ width="1.3em"
+ height="1.3em"
+ color="yellow"
+ />
+ )}
+
+ <ModelSelect />
<HeaderButtonWithText
onClick={() => {
client?.loadSession(undefined);
diff --git a/extension/react-app/src/components/ModelSelect.tsx b/extension/react-app/src/components/ModelSelect.tsx
new file mode 100644
index 00000000..49788143
--- /dev/null
+++ b/extension/react-app/src/components/ModelSelect.tsx
@@ -0,0 +1,113 @@
+import styled from "styled-components";
+import {
+ defaultBorderRadius,
+ secondaryDark,
+ vscBackground,
+ vscForeground,
+} from ".";
+import { useContext, useEffect } from "react";
+import { GUIClientContext } from "../App";
+import { RootStore } from "../redux/store";
+import { useSelector } from "react-redux";
+
+const MODEL_INFO: { title: string; class: string; args: any }[] = [
+ {
+ title: "gpt-4",
+ class: "MaybeProxyOpenAI",
+ args: {
+ model: "gpt-4",
+ api_key: "",
+ },
+ },
+ {
+ title: "gpt-3.5-turbo",
+ class: "MaybeProxyOpenAI",
+ args: {
+ model: "gpt-3.5-turbo",
+ api_key: "",
+ },
+ },
+ {
+ title: "claude-2",
+ class: "AnthropicLLM",
+ args: {
+ model: "claude-2",
+ api_key: "<ANTHROPIC_API_KEY>",
+ },
+ },
+ {
+ title: "GGML",
+ class: "GGML",
+ args: {},
+ },
+ {
+ title: "Ollama",
+ class: "Ollama",
+ args: {
+ model: "codellama",
+ },
+ },
+ {
+ title: "Replicate",
+ class: "ReplicateLLM",
+ args: {
+ model:
+ "replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781",
+ api_key: "<REPLICATE_API_KEY>",
+ },
+ },
+ {
+ title: "TogetherAI",
+ class: "TogetherLLM",
+ args: {
+ model: "gpt-4",
+ api_key: "<TOGETHER_API_KEY>",
+ },
+ },
+];
+
+const Select = styled.select`
+ border: none;
+ width: fit-content;
+ background-color: ${secondaryDark};
+ color: ${vscForeground};
+ border-radius: ${defaultBorderRadius};
+ padding: 6px;
+ /* box-shadow: 0px 0px 1px 0px ${vscForeground}; */
+ max-height: 35vh;
+ overflow: scroll;
+ cursor: pointer;
+ margin-right: auto;
+
+ &:focus {
+ outline: none;
+ }
+`;
+
+function ModelSelect(props: {}) {
+ const client = useContext(GUIClientContext);
+ const defaultModel = useSelector(
+ (state: RootStore) =>
+ (state.serverState.config as any)?.models?.default?.class_name
+ );
+
+ return (
+ <Select
+ defaultValue={0}
+ onChange={(e) => {
+ const model = MODEL_INFO[parseInt(e.target.value)];
+ client?.setModelForRole("*", model.class, model.args);
+ }}
+ >
+ {MODEL_INFO.map((model, idx) => {
+ return (
+ <option selected={defaultModel === model.class} value={idx}>
+ {model.title}
+ </option>
+ );
+ })}
+ </Select>
+ );
+}
+
+export default ModelSelect;