summary refs log tree commit diff
path: root/continuedev
diff options
context:
space:
mode:
author Nate Sesti <sestinj@gmail.com> 2023-08-26 10:10:06 -0700
committer Nate Sesti <sestinj@gmail.com> 2023-08-26 10:10:06 -0700
commit 044b7caa6b26a5d78ae52faa0ae675abc8c4e161 (patch)
tree 5727da03069cc6bf83b9dd385003f615124e6fe8 /continuedev
parent 631d141dbd26edb0de3e0e3ed194dbfd3641059f (diff)
download sncontinue-044b7caa6b26a5d78ae52faa0ae675abc8c4e161.tar.gz
sncontinue-044b7caa6b26a5d78ae52faa0ae675abc8c4e161.tar.bz2
sncontinue-044b7caa6b26a5d78ae52faa0ae675abc8c4e161.zip
feat: :sparkles: select model from dropdown
Diffstat (limited to 'continuedev')
-rw-r--r-- continuedev/src/continuedev/core/models.py 34
-rw-r--r-- continuedev/src/continuedev/libs/llm/ggml.py 2
-rw-r--r-- continuedev/src/continuedev/libs/llm/hf_inference_api.py 4
-rw-r--r-- continuedev/src/continuedev/libs/llm/openai.py 4
-rw-r--r-- continuedev/src/continuedev/libs/llm/together.py 4
-rw-r--r-- continuedev/src/continuedev/libs/util/edit_config.py 43
-rw-r--r-- continuedev/src/continuedev/plugins/steps/setup_model.py 24
-rw-r--r-- continuedev/src/continuedev/server/gui.py 83
8 files changed, 179 insertions, 19 deletions
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index e4610d36..a8c97622 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -1,8 +1,15 @@
-from typing import Optional
+from typing import List, Optional
from pydantic import BaseModel
from ..libs.llm import LLM
+from ..libs.llm.anthropic import AnthropicLLM
+from ..libs.llm.ggml import GGML
+from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from ..libs.llm.ollama import Ollama
+from ..libs.llm.openai import OpenAI
+from ..libs.llm.replicate import ReplicateLLM
+from ..libs.llm.together import TogetherLLM
class ContinueSDK(BaseModel):
@@ -18,6 +25,29 @@ ALL_MODEL_ROLES = [
"chat",
]
+MODEL_CLASSES = {
+ cls.__name__: cls
+ for cls in [
+ OpenAI,
+ MaybeProxyOpenAI,
+ GGML,
+ TogetherLLM,
+ AnthropicLLM,
+ ReplicateLLM,
+ Ollama,
+ ]
+}
+
+MODEL_MODULE_NAMES = {
+ "OpenAI": "openai",
+ "MaybeProxyOpenAI": "maybe_proxy_openai",
+ "GGML": "ggml",
+ "TogetherLLM": "together",
+ "AnthropicLLM": "anthropic",
+ "ReplicateLLM": "replicate",
+ "Ollama": "ollama",
+}
+
class Models(BaseModel):
"""Main class that holds the current model configuration"""
@@ -29,6 +59,8 @@ class Models(BaseModel):
edit: Optional[LLM] = None
chat: Optional[LLM] = None
+ unused: List[LLM] = []
+
# TODO namespace these away to not confuse readers,
# or split Models into ModelsConfig, which gets turned into Models
sdk: ContinueSDK = None
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 34c3ab74..b4548ff2 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -17,7 +17,7 @@ class GGML(LLM):
# this is model-specific
max_context_length: int = 2048
server_url: str = "http://localhost:8000"
- verify_ssl: bool = True
+ verify_ssl: Optional[bool] = None
requires_write_log = True
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 76331a28..7abd268d 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,4 +1,4 @@
-from typing import Any, Coroutine, Dict, Generator, List
+from typing import Any, Coroutine, Dict, Generator, List, Optional
import aiohttp
import requests
@@ -15,7 +15,7 @@ class HuggingFaceInferenceAPI(LLM):
hf_token: str
max_context_length: int = 2048
- verify_ssl: bool = True
+ verify_ssl: Optional[bool] = None
_client_session: aiohttp.ClientSession = None
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 276cc290..48e773a3 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -49,7 +49,7 @@ class OpenAI(LLM):
api_key: str
model: str
openai_server_info: Optional[OpenAIServerInfo] = None
- verify_ssl: bool = True
+ verify_ssl: Optional[bool] = None
ca_bundle_path: Optional[str] = None
requires_write_log = True
@@ -73,7 +73,7 @@ class OpenAI(LLM):
if self.openai_server_info.api_version is not None:
openai.api_version = self.openai_server_info.api_version
- if self.verify_ssl is False:
+ if self.verify_ssl is not None and self.verify_ssl is False:
openai.verify_ssl_certs = False
openai.ca_bundle_path = self.ca_bundle_path or certifi.where()
diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py
index 44f5030c..9a28de2d 100644
--- a/continuedev/src/continuedev/libs/llm/together.py
+++ b/continuedev/src/continuedev/libs/llm/together.py
@@ -1,5 +1,5 @@
import json
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
import aiohttp
@@ -14,7 +14,7 @@ class TogetherLLM(LLM):
model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct"
max_context_length: int = 2048
base_url: str = "https://api.together.xyz"
- verify_ssl: bool = True
+ verify_ssl: Optional[bool] = None
_client_session: aiohttp.ClientSession = None
diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py
index 17ce27ec..7bdffc8e 100644
--- a/continuedev/src/continuedev/libs/util/edit_config.py
+++ b/continuedev/src/continuedev/libs/util/edit_config.py
@@ -1,5 +1,5 @@
import threading
-from typing import Dict, List
+from typing import Any, Dict, List
import redbaron
@@ -53,6 +53,47 @@ def edit_config_property(key_path: List[str], value: redbaron.RedBaron):
file.write(red.dumps())
+def add_config_import(line: str):
+ with edit_lock:
+ red = load_red()
+ # check if the import already exists
+ for node in red:
+ if node.type == "import" and node.dumps() == line:
+ return
+ # if it doesn't exist, add it
+ red.insert(1, line)
+
+ with open(getConfigFilePath(), "w") as file:
+ file.write(red.dumps())
+
+
+filtered_attrs = {
+ "requires_api_key",
+ "requires_unique_id",
+ "requires_write_log",
+ "class_name",
+ "name",
+ "llm",
+}
+
+
+def display_val(v: Any):
+ if isinstance(v, str):
+ return f'"{v}"'
+ return str(v)
+
+
+def display_llm_class(llm):
+ args = ", ".join(
+ [
+ f"{k}={display_val(v)}"
+ for k, v in llm.dict().items()
+ if k not in filtered_attrs and v is not None
+ ]
+ )
+ return f"{llm.__class__.__name__}({args})"
+
+
def create_obj_node(class_name: str, args: Dict[str, str]) -> redbaron.RedBaron:
args = [f"{key}={value}" for key, value in args.items()]
return redbaron.RedBaron(f"{class_name}({', '.join(args)})")[0]
diff --git a/continuedev/src/continuedev/plugins/steps/setup_model.py b/continuedev/src/continuedev/plugins/steps/setup_model.py
new file mode 100644
index 00000000..1c50c714
--- /dev/null
+++ b/continuedev/src/continuedev/plugins/steps/setup_model.py
@@ -0,0 +1,24 @@
+from ...core.main import Step
+from ...core.sdk import ContinueSDK
+from ...libs.util.paths import getConfigFilePath
+
+MODEL_CLASS_TO_MESSAGE = {
+ "OpenAI": "Obtain your OpenAI API key from [here](https://platform.openai.com/account/api-keys) and paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.",
+ "MaybeProxyOpenAI": "To get started with OpenAI models, obtain your OpenAI API key from [here](https://platform.openai.com/account/api-keys) and paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.",
+ "AnthropicLLM": "To get started with Anthropic, you first need to sign up for the beta [here](https://claude.ai/login) to obtain an API key. Once you have the key, paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.",
+ "ReplicateLLM": "To get started with Replicate, sign up to obtain an API key [here](https://replicate.ai/), then paste it into the `api_key` field at config.models.default.api_key in `config.py`.",
+ "Ollama": "To get started with Ollama, download the Mac app from [ollama.ai](https://ollama.ai/). Once it is downloaded, be sure to pull at least one model and use its name in the model field in config.py (e.g. `model='codellama'`).",
+ "GGML": "GGML models can be run locally using the `llama-cpp-python` library. To learn how to set up a local llama-cpp-python server, read [here](https://github.com/continuedev/ggml-server-example). Once it is started on port 8000, you're all set!",
+ "TogetherLLM": "To get started using models from Together, first obtain your Together API key from [here](https://together.ai). Paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then, on their models page, press 'start' on the model of your choice and make sure the `model=` parameter in the config file for the `TogetherLLM` class reflects the name of this model. Finally, reload the VS Code window for changes to take effect.",
+}
+
+
+class SetupModelStep(Step):
+ model_class: str
+ name: str = "Setup model in config.py"
+
+ async def run(self, sdk: ContinueSDK):
+ await sdk.ide.setFileOpen(getConfigFilePath())
+ self.description = MODEL_CLASS_TO_MESSAGE.get(
+ self.model_class, "Please finish setting up this model in `config.py`"
+ )
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index bdcaad47..55a5f3b4 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -9,16 +9,20 @@ from starlette.websockets import WebSocketDisconnect, WebSocketState
from uvicorn.main import Server
from ..core.main import ContextItem
+from ..core.models import ALL_MODEL_ROLES, MODEL_CLASSES, MODEL_MODULE_NAMES
from ..libs.util.create_async_task import create_async_task
from ..libs.util.edit_config import (
+ add_config_import,
create_float_node,
create_obj_node,
create_string_node,
+ display_llm_class,
)
from ..libs.util.logging import logger
from ..libs.util.queue import AsyncSubscriptionQueue
from ..libs.util.telemetry import posthog_logger
from ..plugins.steps.core.core import DisplayErrorStep
+from ..plugins.steps.setup_model import SetupModelStep
from .gui_protocol import AbstractGUIProtocolServer
from .session_manager import Session, session_manager
@@ -227,16 +231,75 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
)
def set_model_for_role(self, role: str, model_class: str, model: Any):
- prev_model = self.session.autopilot.continue_sdk.models.__getattr__(role)
- if prev_model is not None:
- prev_model.update(model)
- self.session.autopilot.continue_sdk.models.__setattr__(role, model)
- create_async_task(
- self.session.autopilot.set_config_attr(
- ["models", role], create_obj_node(model_class, {**model})
- ),
- self.on_error,
- )
+ models = self.session.autopilot.continue_sdk.config.models
+ unused_models = models.unused
+ if role == "*":
+
+ async def async_stuff():
+ # Clear all of the models and store them in unused_models
+ # NOTE: There will be duplicates
+ for role in ALL_MODEL_ROLES:
+ prev_model = models.__getattribute__(role)
+ models.__setattr__(role, None)
+ if prev_model is not None:
+ exists = False
+ for other in unused_models:
+ if display_llm_class(prev_model) == display_llm_class(
+ other
+ ):
+ exists = True
+ break
+ if not exists:
+ unused_models.append(prev_model)
+
+ # Replace default with either new one or existing from unused_models
+ for unused_model in unused_models:
+ if model_class == unused_model.__class__.__name__ and (
+ "model" not in model or model["model"] == unused_model.model
+ ):
+ models.default = unused_model
+
+ # Set and start the default model if didn't already exist from unused
+ if models.default is None:
+ models.default = MODEL_CLASSES[model_class](**model)
+ await self.session.autopilot.continue_sdk.start_model(
+ models.default
+ )
+ await self.session.autopilot.continue_sdk.run_step(
+ SetupModelStep(model_class=model_class)
+ )
+
+ models_args = {}
+
+ for role in ALL_MODEL_ROLES:
+ val = models.__getattribute__(role)
+ if val is None:
+ continue # no pun intended
+
+ models_args[role] = display_llm_class(val)
+
+ JOINER = ", "
+ models_args[
+ "unused"
+ ] = f"[{JOINER.join([display_llm_class(llm) for llm in unused_models])}]"
+
+ await self.session.autopilot.set_config_attr(
+ ["models"],
+ create_obj_node("Models", models_args),
+ )
+
+ add_config_import(
+ f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}"
+ )
+
+ for role in ALL_MODEL_ROLES:
+ if role != "default":
+ models.__setattr__(role, models.default)
+
+ create_async_task(async_stuff(), self.on_error)
+ else:
+ # TODO
+ pass
def save_context_group(self, title: str, context_items: List[ContextItem]):
create_async_task(