From 5eec484dc79bb56dabf9a56af0dbe6bc95227d39 Mon Sep 17 00:00:00 2001 From: Nate Sesti <33237525+sestinj@users.noreply.github.com> Date: Tue, 22 Aug 2023 13:12:58 -0700 Subject: Config UI (#399) * feat: :sparkles: UI for config! * feat: :sparkles: (latent) edit models in settings --- continuedev/poetry.lock | 58 ++- continuedev/pyproject.toml | 1 + continuedev/src/continuedev/core/autopilot.py | 7 + continuedev/src/continuedev/core/config.py | 8 + continuedev/src/continuedev/core/main.py | 22 +- continuedev/src/continuedev/core/models.py | 94 ++--- continuedev/src/continuedev/libs/llm/__init__.py | 8 + .../src/continuedev/libs/llm/maybe_proxy_openai.py | 12 +- continuedev/src/continuedev/libs/llm/openai.py | 1 - .../src/continuedev/libs/util/edit_config.py | 75 ++++ .../src/continuedev/models/generate_json_schema.py | 11 +- continuedev/src/continuedev/plugins/steps/chat.py | 2 +- .../src/continuedev/plugins/steps/core/core.py | 8 +- continuedev/src/continuedev/server/gui.py | 43 +++ docs/docs/customization.md | 6 +- extension/react-app/package-lock.json | 23 +- extension/react-app/package.json | 3 +- extension/react-app/src/App.tsx | 5 + .../src/components/HeaderButtonWithText.tsx | 4 +- extension/react-app/src/components/InfoHover.tsx | 19 + extension/react-app/src/components/Layout.tsx | 9 + .../react-app/src/components/ModelSettings.tsx | 107 ++++++ extension/react-app/src/components/index.ts | 50 ++- .../src/hooks/AbstractContinueGUIClientProtocol.ts | 10 + .../src/hooks/ContinueGUIClientProtocol.ts | 12 + extension/react-app/src/pages/settings.tsx | 229 ++++++++++++ .../src/redux/slices/serverStateReducer.ts | 4 + extension/schema/ContinueConfig.d.ts | 175 +++++++++ extension/schema/FullState.d.ts | 8 + extension/schema/LLM.d.ts | 20 + extension/schema/Models.d.ts | 36 ++ schema/json/ContinueConfig.json | 412 +++++++++++++++++++++ schema/json/FullState.json | 24 +- schema/json/LLM.json | 30 ++ schema/json/Models.json | 66 ++++ 35 files changed, 1521 insertions(+), 81 deletions(-) create mode 100644 continuedev/src/continuedev/libs/util/edit_config.py create mode 100644 extension/react-app/src/components/InfoHover.tsx create mode 100644 extension/react-app/src/components/ModelSettings.tsx create mode 100644 extension/react-app/src/pages/settings.tsx create mode 100644 extension/schema/ContinueConfig.d.ts create mode 100644 extension/schema/LLM.d.ts create mode 100644 extension/schema/Models.d.ts create mode 100644 schema/json/ContinueConfig.json create mode 100644 schema/json/LLM.json create mode 100644 schema/json/Models.json diff --git a/continuedev/poetry.lock b/continuedev/poetry.lock index d3140756..aefc7cf9 100644 --- a/continuedev/poetry.lock +++ b/continuedev/poetry.lock @@ -172,6 +172,17 @@ doc = ["packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] test = ["contextlib2", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (<0.15)", "uvloop (>=0.15)"] trio = ["trio (>=0.16,<0.22)"] +[[package]] +name = "appdirs" +version = "1.4.4" +description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +optional = false +python-versions = "*" +files = [ + {file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"}, + {file = "appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"}, +] + [[package]] name = "async-timeout" version = "4.0.2" @@ -212,6 +223,20 @@ files = [ {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, ] +[[package]] +name = "baron" +version = "0.10.1" +description = "Full Syntax Tree for python to make writing refactoring code a realist task" +optional = false +python-versions = "*" +files = [ + {file = "baron-0.10.1-py2.py3-none-any.whl", hash = "sha256:befb33f4b9e832c7cd1e3cf0eafa6dd3cb6ed4cb2544245147c019936f4e0a8a"}, + {file = "baron-0.10.1.tar.gz", hash = "sha256:af822ad44d4eb425c8516df4239ac4fdba9fdb398ef77e4924cd7c9b4045bc2f"}, +] + +[package.dependencies] +rply = "*" + [[package]] name = "beautifulsoup4" version = "4.12.2" @@ -1188,6 +1213,23 @@ files = [ [package.extras] cli = ["click (>=5.0)"] +[[package]] +name = "redbaron" +version = "0.9.2" +description = "Abstraction on top of baron, a FST for python to make writing refactoring code a realistic task" +optional = false +python-versions = "*" +files = [ + {file = "redbaron-0.9.2-py2.py3-none-any.whl", hash = "sha256:d01032b6a848b5521a8d6ef72486315c2880f420956870cdd742e2b5a09b9bab"}, + {file = "redbaron-0.9.2.tar.gz", hash = "sha256:472d0739ca6b2240bb2278ae428604a75472c9c12e86c6321e8c016139c0132f"}, +] + +[package.dependencies] +baron = ">=0.7" + +[package.extras] +notebook = ["pygments"] + [[package]] name = "regex" version = "2023.5.5" @@ -1335,6 +1377,20 @@ files = [ {file = "ripgrepy-2.0.0.tar.gz", hash = "sha256:6dd871bafe859301097354d1f171540fbc9bd38d3f8f52f8a196dc28522085da"}, ] +[[package]] +name = "rply" +version = "0.7.8" +description = "A pure Python Lex/Yacc that works with RPython" +optional = false +python-versions = "*" +files = [ + {file = "rply-0.7.8-py2.py3-none-any.whl", hash = "sha256:28ffd11d656c48aeb8c508eb382acd6a0bd906662624b34388751732a27807e7"}, + {file = "rply-0.7.8.tar.gz", hash = "sha256:2a808ac25a4580a9991fc304d64434e299a8fc75760574492f242cbb5bb301c9"}, +] + +[package.dependencies] +appdirs = "*" + [[package]] name = "six" version = "1.16.0" @@ -1849,4 +1905,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "5500ea86b06a96f5fe45939500936911e622043a67a3a5c3d02473463ff2fd6c" +content-hash = "fe4715494ed91c691ec1eb914373a612e75751e6685678e438b73193879de98d" diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml index 90ff0572..8cdf1197 100644 --- a/continuedev/pyproject.toml +++ b/continuedev/pyproject.toml @@ -31,6 +31,7 @@ socksio = "^1.0.0" ripgrepy = "^2.0.0" bs4 = "^0.0.1" replicate = "^0.11.0" +redbaron = "^0.9.2" [tool.poetry.scripts] typegen = "src.continuedev.models.generate_json_schema:main" diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 7b0661a5..a1b21903 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -5,11 +5,13 @@ import traceback from functools import cached_property from typing import Callable, Coroutine, Dict, List, Optional +import redbaron from aiohttp import ClientPayloadError from openai import error as openai_errors from pydantic import root_validator from ..libs.util.create_async_task import create_async_task +from ..libs.util.edit_config import edit_config_property from ..libs.util.logging import logger from ..libs.util.paths import getSavedContextGroupsPath from ..libs.util.queue import AsyncSubscriptionQueue @@ -158,6 +160,7 @@ class Autopilot(ContinueBaseModel): if self.context_manager is not None else [], session_info=self.session_info, + config=self.continue_sdk.config, saved_context_groups=self._saved_context_groups, ) self.full_state = full_state @@ -542,6 +545,10 @@ class Autopilot(ContinueBaseModel): await self.context_manager.select_context_item(id, query) await self.update_subscribers() + async def set_config_attr(self, key_path: List[str], value: redbaron.RedBaron): + edit_config_property(key_path, value) + await self.update_subscribers() + _saved_context_groups: Dict[str, List[ContextItem]] = {} def _persist_context_groups(self): diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index f5bf81fb..62e9c690 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -14,6 +14,14 @@ class SlashCommand(BaseModel): step: Type[Step] params: Optional[Dict] = {} + def dict(self, *args, **kwargs): + return { + "name": self.name, + "description": self.description, + "params": self.params, + "step": self.step.__name__, + } + class CustomCommand(BaseModel): name: str diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index e4ee7668..bf098be9 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -277,6 +277,19 @@ class SessionInfo(ContinueBaseModel): date_created: str +class ContinueConfig(ContinueBaseModel): + system_message: str + temperature: float + + class Config: + extra = "allow" + + def dict(self, **kwargs): + original_dict = super().dict(**kwargs) + original_dict.pop("policy", None) + return original_dict + + class FullState(ContinueBaseModel): """A full state of the program, including the history""" @@ -287,19 +300,16 @@ class FullState(ContinueBaseModel): adding_highlighted_code: bool selected_context_items: List[ContextItem] session_info: Optional[SessionInfo] = None + config: ContinueConfig saved_context_groups: Dict[str, List[ContextItem]] = {} class ContinueSDK: - pass + ... class Models: - pass - - -class ContinueConfig: - pass + ... class Policy(ContinueBaseModel): diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py index 52a52b1d..e4610d36 100644 --- a/continuedev/src/continuedev/core/models.py +++ b/continuedev/src/continuedev/core/models.py @@ -1,10 +1,24 @@ -from typing import Any, Optional +from typing import Optional from pydantic import BaseModel from ..libs.llm import LLM +class ContinueSDK(BaseModel): + pass + + +ALL_MODEL_ROLES = [ + "default", + "small", + "medium", + "large", + "edit", + "chat", +] + + class Models(BaseModel): """Main class that holds the current model configuration""" @@ -12,57 +26,47 @@ class Models(BaseModel): small: Optional[LLM] = None medium: Optional[LLM] = None large: Optional[LLM] = None + edit: Optional[LLM] = None + chat: Optional[LLM] = None # TODO namespace these away to not confuse readers, # or split Models into ModelsConfig, which gets turned into Models - sdk: "ContinueSDK" = None - system_message: Any = None - - """ - Better to have sdk.llm.stream_chat(messages, model="claude-2"). - Then you also don't care that it' async. - And it's easier to add more models. - And intermediate shared code is easier to add. - And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo" - PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model. - Easy to reason about, can place anywhere. - And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model. - This can all happen inside of Models? - - class Prompt: - def __init__(self, ...info): - '''take whatever info is needed to describe the prompt''' - - def to_string(self, model: str) -> str: - '''depending on the model, return the single prompt string''' - """ + sdk: ContinueSDK = None + + def dict(self, **kwargs): + original_dict = super().dict(**kwargs) + original_dict.pop("sdk", None) + return original_dict + + @property + def all_models(self): + models = [getattr(self, role) for role in ALL_MODEL_ROLES] + return [model for model in models if model is not None] + + @property + def system_message(self) -> Optional[str]: + if self.sdk: + return self.sdk.config.system_message + return None + + def set_system_message(self, msg: str): + for model in self.all_models: + model.system_message = msg async def start(self, sdk: "ContinueSDK"): """Start each of the LLMs, or fall back to default""" self.sdk = sdk - self.system_message = self.sdk.config.system_message - await sdk.start_model(self.default) - if self.small: - await sdk.start_model(self.small) - else: - self.small = self.default - - if self.medium: - await sdk.start_model(self.medium) - else: - self.medium = self.default - - if self.large: - await sdk.start_model(self.large) - else: - self.large = self.default + + for role in ALL_MODEL_ROLES: + model = getattr(self, role) + if model is None: + setattr(self, role, self.default) + else: + await sdk.start_model(model) + + self.set_system_message(self.system_message) async def stop(self, sdk: "ContinueSDK"): """Stop each LLM (if it's not the default, which is shared)""" - await self.default.stop() - if self.small is not self.default: - await self.small.stop() - if self.medium is not self.default: - await self.medium.stop() - if self.large is not self.default: - await self.large.stop() + for model in self.all_models: + await model.stop() diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 4af6b8e2..294e2c8b 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -14,6 +14,14 @@ class LLM(ContinueBaseModel, ABC): class Config: arbitrary_types_allowed = True + extra = "allow" + + def dict(self, **kwargs): + original_dict = super().dict(**kwargs) + original_dict.pop("write_log", None) + original_dict["name"] = self.name + original_dict["class_name"] = self.__class__.__name__ + return original_dict @abstractproperty def name(self): diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py index 65e5db3a..daffe41f 100644 --- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py +++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py @@ -16,9 +16,16 @@ class MaybeProxyOpenAI(LLM): llm: Optional[LLM] = None + def update_llm_properties(self): + if self.llm is not None: + self.llm.system_message = self.system_message + @property def name(self): - return self.llm.name + if self.llm is not None: + return self.llm.name + else: + return None @property def context_length(self): @@ -44,11 +51,13 @@ class MaybeProxyOpenAI(LLM): async def complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: + self.update_llm_properties() return await self.llm.complete(prompt, with_history=with_history, **kwargs) async def stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: + self.update_llm_properties() resp = self.llm.stream_complete(prompt, with_history=with_history, **kwargs) async for item in resp: yield item @@ -56,6 +65,7 @@ class MaybeProxyOpenAI(LLM): async def stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: + self.update_llm_properties() resp = self.llm.stream_chat(messages=messages, **kwargs) async for item in resp: yield item diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index c2d86841..276cc290 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -54,7 +54,6 @@ class OpenAI(LLM): requires_write_log = True - system_message: Optional[str] = None write_log: Optional[Callable[[str], None]] = None async def start( diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py new file mode 100644 index 00000000..5c070bb4 --- /dev/null +++ b/continuedev/src/continuedev/libs/util/edit_config.py @@ -0,0 +1,75 @@ +import threading +from typing import Dict, List + +import redbaron + +config_file_path = "/Users/natesesti/.continue/config.py" + + +def load_red(): + with open(config_file_path, "r") as file: + source_code = file.read() + + red = redbaron.RedBaron(source_code) + return red + + +def get_config_node(red): + for node in red: + if node.type == "assignment" and node.target.value == "config": + return node + else: + raise Exception("Config file appears to be improperly formatted") + + +def edit_property( + args: redbaron.RedBaron, key_path: List[str], value: redbaron.RedBaron +): + for i in range(len(args)): + node = args[i] + if node.type != "call_argument": + continue + + if node.target.value == key_path[0]: + if len(key_path) > 1: + edit_property(node.value.value[1].value, key_path[1:], value) + else: + args[i].value = value + return + + +edit_lock = threading.Lock() + + +def edit_config_property(key_path: List[str], value: redbaron.RedBaron): + with edit_lock: + red = load_red() + config = get_config_node(red) + config_args = config.value.value[1].value + edit_property(config_args, key_path, value) + + with open(config_file_path, "w") as file: + file.write(red.dumps()) + + +def create_obj_node(class_name: str, args: Dict[str, str]) -> redbaron.RedBaron: + args = [f"{key}={value}" for key, value in args.items()] + return redbaron.RedBaron(f"{class_name}({', '.join(args)})")[0] + + +def create_string_node(string: str) -> redbaron.RedBaron: + return redbaron.RedBaron(f'"{string}"')[0] + + +def create_float_node(float: float) -> redbaron.RedBaron: + return redbaron.RedBaron(f"{float}")[0] + + +# Example: +# edit_config_property( +# [ +# "models", +# "default", +# ], +# create_obj_node("OpenAI", {"api_key": '""', "model": '"gpt-4"'}), +# ) diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py index 1c43f0a0..ad727f06 100644 --- a/continuedev/src/continuedev/models/generate_json_schema.py +++ b/continuedev/src/continuedev/models/generate_json_schema.py @@ -2,18 +2,24 @@ import os from pydantic import schema_json_of +from ..core.config import ContinueConfig from ..core.context import ContextItem from ..core.main import FullState, History, HistoryNode, SessionInfo +from ..core.models import Models +from ..libs.llm import LLM from .filesystem import FileEdit, RangeInFile from .filesystem_edit import FileEditWithFullContents -from .main import * +from .main import Position, Range, Traceback, TracebackFrame MODELS_TO_GENERATE = ( [Position, Range, Traceback, TracebackFrame] + [RangeInFile, FileEdit] + [FileEditWithFullContents] + [History, HistoryNode, FullState, SessionInfo] + + [ContinueConfig] + [ContextItem] + + [Models] + + [LLM] ) RENAMES = {"ExampleClass": "RenamedName"} @@ -34,7 +40,10 @@ def main(): try: json = schema_json_of(model, indent=2, title=title) except Exception as e: + import traceback + print(f"Failed to generate json schema for {title}: {e}") + traceback.print_exc() continue # pun intended with open(f"{SCHEMA_DIR}/{title}.json", "w") as f: diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py index 7f2ebef1..63548698 100644 --- a/continuedev/src/continuedev/plugins/steps/chat.py +++ b/continuedev/src/continuedev/plugins/steps/chat.py @@ -76,7 +76,7 @@ class SimpleChatStep(Step): messages = self.messages or await sdk.get_chat_context() - generator = sdk.models.default.stream_chat( + generator = sdk.models.chat.stream_chat( messages, temperature=sdk.config.temperature ) diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index 86569dcb..3de76eaf 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -8,11 +8,7 @@ from typing import Any, Coroutine, List, Union from pydantic import validator from ....core.main import ChatMessage, ContinueCustomException, Step -from ....core.observation import ( - Observation, - TextObservation, - UserInputObservation, -) +from ....core.observation import Observation, TextObservation, UserInputObservation from ....libs.llm.ggml import GGML from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS @@ -246,7 +242,7 @@ class DefaultModelEditCodeStep(Step): # We don't know here all of the functions being passed in. # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion. # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need. - model_to_use = sdk.models.default + model_to_use = sdk.models.edit max_tokens = int(model_to_use.context_length / 2) TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200 diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 5589284a..bdcaad47 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -10,6 +10,11 @@ from uvicorn.main import Server from ..core.main import ContextItem from ..libs.util.create_async_task import create_async_task +from ..libs.util.edit_config import ( + create_float_node, + create_obj_node, + create_string_node, +) from ..libs.util.logging import logger from ..libs.util.queue import AsyncSubscriptionQueue from ..libs.util.telemetry import posthog_logger @@ -105,6 +110,12 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.load_session(data.get("session_id", None)) elif message_type == "edit_step_at_index": self.edit_step_at_index(data.get("user_input", ""), data["index"]) + elif message_type == "set_system_message": + self.set_system_message(data["message"]) + elif message_type == "set_temperature": + self.set_temperature(float(data["temperature"])) + elif message_type == "set_model_for_role": + self.set_model_for_role(data["role"], data["model_class"], data["model"]) elif message_type == "save_context_group": self.save_context_group( data["title"], [ContextItem(**item) for item in data["context_items"]] @@ -195,6 +206,38 @@ class GUIProtocolServer(AbstractGUIProtocolServer): posthog_logger.capture_event("load_session", {"session_id": session_id}) + def set_system_message(self, message: str): + self.session.autopilot.continue_sdk.config.system_message = message + self.session.autopilot.continue_sdk.models.set_system_message(message) + + create_async_task( + self.session.autopilot.set_config_attr( + ["system_message"], create_string_node(message) + ), + self.on_error, + ) + + def set_temperature(self, temperature: float): + self.session.autopilot.continue_sdk.config.temperature = temperature + create_async_task( + self.session.autopilot.set_config_attr( + ["temperature"], create_float_node(temperature) + ), + self.on_error, + ) + + def set_model_for_role(self, role: str, model_class: str, model: Any): + prev_model = self.session.autopilot.continue_sdk.models.__getattr__(role) + if prev_model is not None: + prev_model.update(model) + self.session.autopilot.continue_sdk.models.__setattr__(role, model) + create_async_task( + self.session.autopilot.set_config_attr( + ["models", role], create_obj_node(model_class, {**model}) + ), + self.on_error, + ) + def save_context_group(self, title: str, context_items: List[ContextItem]): create_async_task( self.session.autopilot.save_context_group(title, context_items), diff --git a/docs/docs/customization.md b/docs/docs/customization.md index b7279fe3..3a407510 100644 --- a/docs/docs/customization.md +++ b/docs/docs/customization.md @@ -7,7 +7,7 @@ Continue can be deeply customized by editing the `ContinueConfig` object in `~/. In `config.py`, you'll find the `models` property: ```python -from continuedev.src.continuedev.core.sdk import Models +from continuedev.src.continuedev.core.models import Models config = ContinueConfig( ... @@ -103,7 +103,7 @@ config = ContinueConfig( The Together API is a cloud platform for running large AI models. You can sign up [here](https://api.together.xyz/signup), copy your API key on the initial welcome screen, and then hit the play button on any model from the [Together Models list](https://docs.together.ai/docs/models-inference). Change `~/.continue/config.py` to look like this: ```python -from continuedev.src.continuedev.core.sdk import Models +from continuedev.src.continuedev.core.models import Models from continuedev.src.continuedev.libs.llm.together import TogetherLLM config = ContinueConfig( @@ -122,7 +122,7 @@ config = ContinueConfig( Replicate is a great option for newly released language models or models that you've deployed through their platform. Sign up for an account [here](https://replicate.ai/), copy your API key, and then select any model from the [Replicate Streaming List](https://replicate.com/collections/streaming-language-models). Change `~/.continue/config.py` to look like this: ```python -from continuedev.src.continuedev.core.sdk import Models +from continuedev.src.continuedev.core.models import Models from continuedev.src.continuedev.libs.llm.replicate import ReplicateLLM config = ContinueConfig( diff --git a/extension/react-app/package-lock.json b/extension/react-app/package-lock.json index c2265d15..fb68081c 100644 --- a/extension/react-app/package-lock.json +++ b/extension/react-app/package-lock.json @@ -6,7 +6,6 @@ "packages": { "": { "name": "react-app", - "version": "0.0.0", "dependencies": { "@heroicons/react": "^2.0.18", "@types/vscode-webview": "^1.57.1", @@ -17,6 +16,7 @@ "prismjs": "^1.29.0", "react": "^18.2.0", "react-dom": "^18.2.0", + "react-hook-form": "^7.45.4", "react-redux": "^8.0.5", "react-router-dom": "^6.14.2", "react-switch": "^7.0.0", @@ -3748,6 +3748,21 @@ "react": "^18.2.0" } }, + "node_modules/react-hook-form": { + "version": "7.45.4", + "resolved": "https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.45.4.tgz", + "integrity": "sha512-HGDV1JOOBPZj10LB3+OZgfDBTn+IeEsNOKiq/cxbQAIbKaiJUe/KV8DBUzsx0Gx/7IG/orWqRRm736JwOfUSWQ==", + "engines": { + "node": ">=12.22.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/react-hook-form" + }, + "peerDependencies": { + "react": "^16.8.0 || ^17 || ^18" + } + }, "node_modules/react-is": { "version": "18.2.0", "resolved": "https://registry.npmjs.org/react-is/-/react-is-18.2.0.tgz", @@ -7335,6 +7350,12 @@ "scheduler": "^0.23.0" } }, + "react-hook-form": { + "version": "7.45.4", + "resolved": "https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.45.4.tgz", + "integrity": "sha512-HGDV1JOOBPZj10LB3+OZgfDBTn+IeEsNOKiq/cxbQAIbKaiJUe/KV8DBUzsx0Gx/7IG/orWqRRm736JwOfUSWQ==", + "requires": {} + }, "react-is": { "version": "18.2.0", "resolved": "https://registry.npmjs.org/react-is/-/react-is-18.2.0.tgz", diff --git a/extension/react-app/package.json b/extension/react-app/package.json index 23cdf9bb..b9f70645 100644 --- a/extension/react-app/package.json +++ b/extension/react-app/package.json @@ -17,6 +17,7 @@ "prismjs": "^1.29.0", "react": "^18.2.0", "react-dom": "^18.2.0", + "react-hook-form": "^7.45.4", "react-redux": "^8.0.5", "react-router-dom": "^6.14.2", "react-switch": "^7.0.0", @@ -38,4 +39,4 @@ "typescript": "^4.9.3", "vite": "^4.1.0" } -} \ No newline at end of file +} diff --git a/extension/react-app/src/App.tsx b/extension/react-app/src/App.tsx index 05b322ff..65ad1ddd 100644 --- a/extension/react-app/src/App.tsx +++ b/extension/react-app/src/App.tsx @@ -17,6 +17,7 @@ import { setHighlightedCode } from "./redux/slices/miscSlice"; import { postVscMessage } from "./vscode"; import { createBrowserRouter, RouterProvider } from "react-router-dom"; import ErrorPage from "./pages/error"; +import SettingsPage from "./pages/settings"; const router = createBrowserRouter([ { @@ -36,6 +37,10 @@ const router = createBrowserRouter([ path: "/history", element: , }, + { + path: "/settings", + element: , + }, ], }, ]); diff --git a/extension/react-app/src/components/HeaderButtonWithText.tsx b/extension/react-app/src/components/HeaderButtonWithText.tsx index bcd36972..3122c287 100644 --- a/extension/react-app/src/components/HeaderButtonWithText.tsx +++ b/extension/react-app/src/components/HeaderButtonWithText.tsx @@ -1,7 +1,5 @@ import React, { useState } from "react"; -import { Tooltip } from "react-tooltip"; -import styled from "styled-components"; -import { HeaderButton, StyledTooltip, defaultBorderRadius } from "."; +import { HeaderButton, StyledTooltip } from "."; interface HeaderButtonWithTextProps { text: string; diff --git a/extension/react-app/src/components/InfoHover.tsx b/extension/react-app/src/components/InfoHover.tsx new file mode 100644 index 00000000..2cb8ad71 --- /dev/null +++ b/extension/react-app/src/components/InfoHover.tsx @@ -0,0 +1,19 @@ +import { InformationCircleIcon } from "@heroicons/react/24/outline"; +import { StyledTooltip } from "."; + +const InfoHover = ({ msg }: { msg: string }) => { + const id = "info-hover"; + + return ( + <> + + + + ); +}; + +export default InfoHover; diff --git a/extension/react-app/src/components/Layout.tsx b/extension/react-app/src/components/Layout.tsx index cec3f8e1..c0f0929b 100644 --- a/extension/react-app/src/components/Layout.tsx +++ b/extension/react-app/src/components/Layout.tsx @@ -18,6 +18,7 @@ import { BookOpenIcon, ChatBubbleOvalLeftEllipsisIcon, SparklesIcon, + Cog6ToothIcon, } from "@heroicons/react/24/outline"; import HeaderButtonWithText from "./HeaderButtonWithText"; import { useNavigate } from "react-router-dom"; @@ -193,6 +194,14 @@ const Layout = () => { + { + navigate("/settings"); + }} + text="Settings" + > + + diff --git a/extension/react-app/src/components/ModelSettings.tsx b/extension/react-app/src/components/ModelSettings.tsx new file mode 100644 index 00000000..99200502 --- /dev/null +++ b/extension/react-app/src/components/ModelSettings.tsx @@ -0,0 +1,107 @@ +import styled from "styled-components"; +import { LLM } from "../../../schema/LLM"; +import { + Label, + Select, + TextInput, + defaultBorderRadius, + lightGray, + vscForeground, +} from "."; +import { useState } from "react"; +import { useFormContext } from "react-hook-form"; + +const Div = styled.div<{ dashed: boolean }>` + border: 1px ${(props) => (props.dashed ? "dashed" : "solid")} ${lightGray}; + border-radius: ${defaultBorderRadius}; + padding: 8px; + margin-bottom: 16px; +`; + +type ModelOption = "api_key" | "model" | "context_length"; + +const DefaultModelOptions: { + [key: string]: { [key in ModelOption]?: string }; +} = { + OpenAI: { + api_key: "", + model: "gpt-4", + }, + MaybeProxyOpenAI: { + api_key: "", + model: "gpt-4", + }, + Anthropic: { + api_key: "", + model: "claude-2", + }, + default: { + api_key: "", + model: "gpt-4", + }, +}; + +function ModelSettings(props: { llm: any | undefined; role: string }) { + const [modelOptions, setModelOptions] = useState<{ + [key in ModelOption]?: string; + }>(DefaultModelOptions[props.llm?.class_name || "default"]); + + const { register, setValue, getValues } = useFormContext(); + + return ( +
+ {props.llm ? ( + <> + {props.role}: {props.llm.class_name || "gpt-4"} +
+ {typeof modelOptions.api_key !== undefined && ( + <> + + + + )} + {modelOptions.model && ( + <> + + + + )} + + + ) : ( +
+ Add Model +
+ +
+
+ )} +
+ ); +} + +export default ModelSettings; diff --git a/extension/react-app/src/components/index.ts b/extension/react-app/src/components/index.ts index 6705ceb2..f2e154bc 100644 --- a/extension/react-app/src/components/index.ts +++ b/extension/react-app/src/components/index.ts @@ -40,21 +40,29 @@ export const StyledTooltip = styled(Tooltip)` padding-left: 12px; padding-right: 12px; z-index: 100; + + max-width: 80vw; `; export const TextArea = styled.textarea` - width: 100%; + padding: 8px; + font-family: inherit; border-radius: ${defaultBorderRadius}; - border: none; + margin: 16px auto; + height: auto; + width: calc(100% - 32px); background-color: ${secondaryDark}; - resize: vertical; - - padding: 4px; - caret-color: ${vscForeground}; - color: #{vscForeground}; + color: ${vscForeground}; + z-index: 1; + border: 1px solid transparent; &:focus { - outline: 1px solid ${buttonColor}; + outline: 1px solid ${lightGray}; + border: 1px solid transparent; + } + + &::placeholder { + color: ${lightGray}80; } `; @@ -84,11 +92,33 @@ export const H3 = styled.h3` export const TextInput = styled.input.attrs({ type: "text" })` width: 100%; - padding: 12px 20px; + padding: 8px 12px; + margin: 8px 0; + box-sizing: border-box; + border-radius: ${defaultBorderRadius}; + outline: 1px solid ${lightGray}; + border: none; + background-color: ${vscBackground}; + color: ${vscForeground}; + + &:focus { + background: ${secondaryDark}; + } +`; + +export const Select = styled.select` + padding: 8px 12px; margin: 8px 0; box-sizing: border-box; border-radius: ${defaultBorderRadius}; - border: 2px solid gray; + outline: 1px solid ${lightGray}; + border: none; + background-color: ${vscBackground}; + color: ${vscForeground}; +`; + +export const Label = styled.label` + font-size: 13px; `; const spin = keyframes` diff --git a/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts index c9e7def2..f8c11527 100644 --- a/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts +++ b/extension/react-app/src/hooks/AbstractContinueGUIClientProtocol.ts @@ -37,6 +37,16 @@ abstract class AbstractContinueGUIClientProtocol { abstract editStepAtIndex(userInput: string, index: number): void; + abstract setSystemMessage(message: string): void; + + abstract setTemperature(temperature: number): void; + + abstract setModelForRole( + role: string, + model_class: string, + model: string + ): void; + abstract saveContextGroup(title: string, contextItems: ContextItem[]): void; abstract selectContextGroup(id: string): void; diff --git a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts index b3ac2570..ce9b2a0a 100644 --- a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts +++ b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts @@ -133,6 +133,18 @@ class ContinueGUIClientProtocol extends AbstractContinueGUIClientProtocol { }); } + setSystemMessage(message: string): void { + this.messenger?.send("set_system_message", { message }); + } + + setTemperature(temperature: number): void { + this.messenger?.send("set_temperature", { temperature }); + } + + setModelForRole(role: string, model_class: string, model: any): void { + this.messenger?.send("set_model_for_role", { role, model, model_class }); + } + saveContextGroup(title: string, contextItems: ContextItem[]): void { this.messenger?.send("save_context_group", { context_items: contextItems, diff --git a/extension/react-app/src/pages/settings.tsx b/extension/react-app/src/pages/settings.tsx new file mode 100644 index 00000000..8fd91ff5 --- /dev/null +++ b/extension/react-app/src/pages/settings.tsx @@ -0,0 +1,229 @@ +import React, { useContext, useEffect, useState } from "react"; +import { GUIClientContext } from "../App"; +import { useSelector } from "react-redux"; +import { RootStore } from "../redux/store"; +import { useNavigate } from "react-router-dom"; +import { ContinueConfig } from "../../../schema/ContinueConfig"; +import { + Button, + Select, + TextArea, + lightGray, + secondaryDark, +} from "../components"; +import styled from "styled-components"; +import { ArrowLeftIcon } from "@heroicons/react/24/outline"; +import Loader from "../components/Loader"; +import InfoHover from "../components/InfoHover"; +import { FormProvider, useForm } from "react-hook-form"; +import ModelSettings from "../components/ModelSettings"; + +const Hr = styled.hr` + border: 0.5px solid ${lightGray}; +`; + +const CancelButton = styled(Button)` + background-color: transparent; + color: ${lightGray}; + border: 1px solid ${lightGray}; + &:hover { + background-color: ${lightGray}; + color: black; + } +`; + +const SaveButton = styled(Button)` + &:hover { + opacity: 0.8; + } +`; + +const Slider = styled.input.attrs({ type: "range" })` + --webkit-appearance: none; + width: 100%; + background-color: ${secondaryDark}; + outline: none; + border: none; + opacity: 0.7; + -webkit-transition: 0.2s; + transition: opacity 0.2s; + &:hover { + opacity: 1; + } + &::-webkit-slider-runnable-track { + width: 100%; + height: 8px; + cursor: pointer; + background: ${lightGray}; + border-radius: 4px; + } + &::-webkit-slider-thumb { + -webkit-appearance: none; + appearance: none; + width: 8px; + height: 8px; + cursor: pointer; + margin-top: -3px; + } + &::-moz-range-thumb { + width: 8px; + height: 8px; + cursor: pointer; + margin-top: -3px; + } + + &:focus { + outline: none; + border: none; + } +`; +const ALL_MODEL_ROLES = ["default", "small", "medium", "large", "edit", "chat"]; + +function Settings() { + const formMethods = useForm(); + const onSubmit = (data: ContinueConfig) => console.log(data); + + const navigate = useNavigate(); + const client = useContext(GUIClientContext); + const config = useSelector((state: RootStore) => state.serverState.config); + + const submitChanges = () => { + if (!client) return; + + const systemMessage = formMethods.watch("system_message"); + const temperature = formMethods.watch("temperature"); + // const models = formMethods.watch("models"); + + if (systemMessage) client.setSystemMessage(systemMessage); + if (temperature) client.setTemperature(temperature); + + // if (models) { + // for (const role of ALL_MODEL_ROLES) { + // if (models[role]) { + // client.setModelForRole(role, models[role] as string, models[role]); + // } + // } + // } + }; + + const submitAndLeave = () => { + submitChanges(); + navigate("/"); + }; + + return ( + +
+
+
+ +

Settings

+
+ {config ? ( +
+

+ System Message + +

+