diff options
author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-08-22 13:12:58 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-22 13:12:58 -0700 |
commit | 5eec484dc79bb56dabf9a56af0dbe6bc95227d39 (patch) | |
tree | 07edc01475cfe0ba69372f537d36aa3294680b7d /continuedev/src/continuedev/core | |
parent | b6435e1e479edb1e4f049098dc8522e944317f2a (diff) | |
download | sncontinue-5eec484dc79bb56dabf9a56af0dbe6bc95227d39.tar.gz sncontinue-5eec484dc79bb56dabf9a56af0dbe6bc95227d39.tar.bz2 sncontinue-5eec484dc79bb56dabf9a56af0dbe6bc95227d39.zip |
Config UI (#399)
* feat: :sparkles: UI for config!
* feat: :sparkles: (latent) edit models in settings
Diffstat (limited to 'continuedev/src/continuedev/core')
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 7 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/config.py | 8 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/main.py | 22 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/models.py | 94 |
4 files changed, 80 insertions, 51 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 7b0661a5..a1b21903 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -5,11 +5,13 @@ import traceback from functools import cached_property from typing import Callable, Coroutine, Dict, List, Optional +import redbaron from aiohttp import ClientPayloadError from openai import error as openai_errors from pydantic import root_validator from ..libs.util.create_async_task import create_async_task +from ..libs.util.edit_config import edit_config_property from ..libs.util.logging import logger from ..libs.util.paths import getSavedContextGroupsPath from ..libs.util.queue import AsyncSubscriptionQueue @@ -158,6 +160,7 @@ class Autopilot(ContinueBaseModel): if self.context_manager is not None else [], session_info=self.session_info, + config=self.continue_sdk.config, saved_context_groups=self._saved_context_groups, ) self.full_state = full_state @@ -542,6 +545,10 @@ class Autopilot(ContinueBaseModel): await self.context_manager.select_context_item(id, query) await self.update_subscribers() + async def set_config_attr(self, key_path: List[str], value: redbaron.RedBaron): + edit_config_property(key_path, value) + await self.update_subscribers() + _saved_context_groups: Dict[str, List[ContextItem]] = {} def _persist_context_groups(self): diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index f5bf81fb..62e9c690 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -14,6 +14,14 @@ class SlashCommand(BaseModel): step: Type[Step] params: Optional[Dict] = {} + def dict(self, *args, **kwargs): + return { + "name": self.name, + "description": self.description, + "params": self.params, + "step": self.step.__name__, + } + class CustomCommand(BaseModel): name: str diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index e4ee7668..bf098be9 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -277,6 +277,19 @@ class SessionInfo(ContinueBaseModel): date_created: str +class ContinueConfig(ContinueBaseModel): + system_message: str + temperature: float + + class Config: + extra = "allow" + + def dict(self, **kwargs): + original_dict = super().dict(**kwargs) + original_dict.pop("policy", None) + return original_dict + + class FullState(ContinueBaseModel): """A full state of the program, including the history""" @@ -287,19 +300,16 @@ class FullState(ContinueBaseModel): adding_highlighted_code: bool selected_context_items: List[ContextItem] session_info: Optional[SessionInfo] = None + config: ContinueConfig saved_context_groups: Dict[str, List[ContextItem]] = {} class ContinueSDK: - pass + ... class Models: - pass - - -class ContinueConfig: - pass + ... class Policy(ContinueBaseModel): diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py index 52a52b1d..e4610d36 100644 --- a/continuedev/src/continuedev/core/models.py +++ b/continuedev/src/continuedev/core/models.py @@ -1,10 +1,24 @@ -from typing import Any, Optional +from typing import Optional from pydantic import BaseModel from ..libs.llm import LLM +class ContinueSDK(BaseModel): + pass + + +ALL_MODEL_ROLES = [ + "default", + "small", + "medium", + "large", + "edit", + "chat", +] + + class Models(BaseModel): """Main class that holds the current model configuration""" @@ -12,57 +26,47 @@ class Models(BaseModel): small: Optional[LLM] = None medium: Optional[LLM] = None large: Optional[LLM] = None + edit: Optional[LLM] = None + chat: Optional[LLM] = None # TODO namespace these away to not confuse readers, # or split Models into ModelsConfig, which gets turned into Models - sdk: "ContinueSDK" = None - system_message: Any = None - - """ - Better to have sdk.llm.stream_chat(messages, model="claude-2"). - Then you also don't care that it' async. - And it's easier to add more models. - And intermediate shared code is easier to add. - And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo" - PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model. - Easy to reason about, can place anywhere. - And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model. - This can all happen inside of Models? - - class Prompt: - def __init__(self, ...info): - '''take whatever info is needed to describe the prompt''' - - def to_string(self, model: str) -> str: - '''depending on the model, return the single prompt string''' - """ + sdk: ContinueSDK = None + + def dict(self, **kwargs): + original_dict = super().dict(**kwargs) + original_dict.pop("sdk", None) + return original_dict + + @property + def all_models(self): + models = [getattr(self, role) for role in ALL_MODEL_ROLES] + return [model for model in models if model is not None] + + @property + def system_message(self) -> Optional[str]: + if self.sdk: + return self.sdk.config.system_message + return None + + def set_system_message(self, msg: str): + for model in self.all_models: + model.system_message = msg async def start(self, sdk: "ContinueSDK"): """Start each of the LLMs, or fall back to default""" self.sdk = sdk - self.system_message = self.sdk.config.system_message - await sdk.start_model(self.default) - if self.small: - await sdk.start_model(self.small) - else: - self.small = self.default - - if self.medium: - await sdk.start_model(self.medium) - else: - self.medium = self.default - - if self.large: - await sdk.start_model(self.large) - else: - self.large = self.default + + for role in ALL_MODEL_ROLES: + model = getattr(self, role) + if model is None: + setattr(self, role, self.default) + else: + await sdk.start_model(model) + + self.set_system_message(self.system_message) async def stop(self, sdk: "ContinueSDK"): """Stop each LLM (if it's not the default, which is shared)""" - await self.default.stop() - if self.small is not self.default: - await self.small.stop() - if self.medium is not self.default: - await self.medium.stop() - if self.large is not self.default: - await self.large.stop() + for model in self.all_models: + await model.stop() |