Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/poetry.lock                                    |  58
-rw-r--r-- | continuedev/pyproject.toml                                 |   1
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py              |   7
-rw-r--r-- | continuedev/src/continuedev/core/config.py                 |   8
-rw-r--r-- | continuedev/src/continuedev/core/main.py                   |  22
-rw-r--r-- | continuedev/src/continuedev/core/models.py                 |  94
-rw-r--r-- | continuedev/src/continuedev/libs/llm/__init__.py           |   8
-rw-r--r-- | continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py |  12
-rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py             |   1
-rw-r--r-- | continuedev/src/continuedev/libs/util/edit_config.py       |  75
-rw-r--r-- | continuedev/src/continuedev/models/generate_json_schema.py |  11
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/chat.py          |   2
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/core/core.py     |   8
-rw-r--r-- | continuedev/src/continuedev/server/gui.py                  |  43
14 files changed, 288 insertions, 62 deletions
diff --git a/continuedev/poetry.lock b/continuedev/poetry.lock
index d3140756..aefc7cf9 100644
--- a/continuedev/poetry.lock
+++ b/continuedev/poetry.lock
@@ -173,6 +173,17 @@ test = ["contextlib2", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=
 trio = ["trio (>=0.16,<0.22)"]
 
 [[package]]
+name = "appdirs"
+version = "1.4.4"
+description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
+optional = false
+python-versions = "*"
+files = [
+    {file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"},
+    {file = "appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"},
+]
+
+[[package]]
 name = "async-timeout"
 version = "4.0.2"
 description = "Timeout context manager for asyncio programs"
@@ -213,6 +224,20 @@ files = [
 ]
 
 [[package]]
+name = "baron"
+version = "0.10.1"
+description = "Full Syntax Tree for python to make writing refactoring code a realist task"
+optional = false
+python-versions = "*"
+files = [
+    {file = "baron-0.10.1-py2.py3-none-any.whl", hash = "sha256:befb33f4b9e832c7cd1e3cf0eafa6dd3cb6ed4cb2544245147c019936f4e0a8a"},
+    {file = "baron-0.10.1.tar.gz", hash = "sha256:af822ad44d4eb425c8516df4239ac4fdba9fdb398ef77e4924cd7c9b4045bc2f"},
+]
+
+[package.dependencies]
+rply = "*"
+
+[[package]]
 name = "beautifulsoup4"
 version = "4.12.2"
 description = "Screen-scraping library"
@@ -1189,6 +1214,23 @@ files = [
 cli = ["click (>=5.0)"]
 
 [[package]]
+name = "redbaron"
+version = "0.9.2"
+description = "Abstraction on top of baron, a FST for python to make writing refactoring code a realistic task"
+optional = false
+python-versions = "*"
+files = [
+    {file = "redbaron-0.9.2-py2.py3-none-any.whl", hash = "sha256:d01032b6a848b5521a8d6ef72486315c2880f420956870cdd742e2b5a09b9bab"},
+    {file = "redbaron-0.9.2.tar.gz", hash = "sha256:472d0739ca6b2240bb2278ae428604a75472c9c12e86c6321e8c016139c0132f"},
+]
+
+[package.dependencies]
+baron = ">=0.7"
+
+[package.extras]
+notebook = ["pygments"]
+
+[[package]]
 name = "regex"
 version = "2023.5.5"
 description = "Alternative regular expression module, to replace re."
@@ -1336,6 +1378,20 @@ files = [
 ]
 
 [[package]]
+name = "rply"
+version = "0.7.8"
+description = "A pure Python Lex/Yacc that works with RPython"
+optional = false
+python-versions = "*"
+files = [
+    {file = "rply-0.7.8-py2.py3-none-any.whl", hash = "sha256:28ffd11d656c48aeb8c508eb382acd6a0bd906662624b34388751732a27807e7"},
+    {file = "rply-0.7.8.tar.gz", hash = "sha256:2a808ac25a4580a9991fc304d64434e299a8fc75760574492f242cbb5bb301c9"},
+]
+
+[package.dependencies]
+appdirs = "*"
+
+[[package]]
 name = "six"
 version = "1.16.0"
 description = "Python 2 and 3 compatibility utilities"
@@ -1849,4 +1905,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.8.1"
-content-hash = "5500ea86b06a96f5fe45939500936911e622043a67a3a5c3d02473463ff2fd6c"
+content-hash = "fe4715494ed91c691ec1eb914373a612e75751e6685678e438b73193879de98d"
diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml
index 90ff0572..8cdf1197 100644
--- a/continuedev/pyproject.toml
+++ b/continuedev/pyproject.toml
@@ -31,6 +31,7 @@ socksio = "^1.0.0"
 ripgrepy = "^2.0.0"
 bs4 = "^0.0.1"
 replicate = "^0.11.0"
+redbaron = "^0.9.2"
 
 [tool.poetry.scripts]
 typegen = "src.continuedev.models.generate_json_schema:main"
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 7b0661a5..a1b21903 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -5,11 +5,13 @@ import traceback
 from functools import cached_property
 from typing import Callable, Coroutine, Dict, List, Optional
 
+import redbaron
 from aiohttp import ClientPayloadError
 from openai import error as openai_errors
 from pydantic import root_validator
 
 from ..libs.util.create_async_task import create_async_task
+from ..libs.util.edit_config import edit_config_property
 from ..libs.util.logging import logger
 from ..libs.util.paths import getSavedContextGroupsPath
 from ..libs.util.queue import AsyncSubscriptionQueue
@@ -158,6 +160,7 @@ class Autopilot(ContinueBaseModel):
             if self.context_manager is not None
             else [],
             session_info=self.session_info,
+            config=self.continue_sdk.config,
             saved_context_groups=self._saved_context_groups,
         )
         self.full_state = full_state
@@ -542,6 +545,10 @@ class Autopilot(ContinueBaseModel):
         await self.context_manager.select_context_item(id, query)
         await self.update_subscribers()
 
+    async def set_config_attr(self, key_path: List[str], value: redbaron.RedBaron):
+        edit_config_property(key_path, value)
+        await self.update_subscribers()
+
     _saved_context_groups: Dict[str, List[ContextItem]] = {}
 
     def _persist_context_groups(self):
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index f5bf81fb..62e9c690 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -14,6 +14,14 @@ class SlashCommand(BaseModel):
     step: Type[Step]
     params: Optional[Dict] = {}
 
+    def dict(self, *args, **kwargs):
+        return {
+            "name": self.name,
+            "description": self.description,
+            "params": self.params,
+            "step": self.step.__name__,
+        }
+
 
 class CustomCommand(BaseModel):
     name: str
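The `SlashCommand.dict` override above is what makes slash commands serializable: `step` holds a class object (`Type[Step]`), which pydantic cannot render to JSON, so only the class name is emitted. A minimal sketch of the round trip this implies; the `STEP_CLASSES` registry and `slash_command_from_dict` helper are hypothetical, not part of this diff:

```python
from typing import Dict, Type


class Step:  # stand-in for continuedev's Step base class
    pass


class CommentCodeStep(Step):
    pass


# Hypothetical registry for resolving serialized names back to classes.
STEP_CLASSES: Dict[str, Type[Step]] = {"CommentCodeStep": CommentCodeStep}


def slash_command_from_dict(data: dict) -> dict:
    """Reverse of SlashCommand.dict(): swap the class name back for the class."""
    resolved = dict(data)
    resolved["step"] = STEP_CLASSES[data["step"]]
    return resolved


serialized = {
    "name": "comment",
    "description": "Write comments for the highlighted code",
    "params": {},
    "step": "CommentCodeStep",  # what the dict() override emits
}
assert slash_command_from_dict(serialized)["step"] is CommentCodeStep
```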
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index e4ee7668..bf098be9 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -277,6 +277,19 @@ class SessionInfo(ContinueBaseModel):
     date_created: str
 
 
+class ContinueConfig(ContinueBaseModel):
+    system_message: str
+    temperature: float
+
+    class Config:
+        extra = "allow"
+
+    def dict(self, **kwargs):
+        original_dict = super().dict(**kwargs)
+        original_dict.pop("policy", None)
+        return original_dict
+
+
 class FullState(ContinueBaseModel):
     """A full state of the program, including the history"""
 
@@ -287,19 +300,16 @@
     adding_highlighted_code: bool
     selected_context_items: List[ContextItem]
     session_info: Optional[SessionInfo] = None
+    config: ContinueConfig
     saved_context_groups: Dict[str, List[ContextItem]] = {}
 
 
 class ContinueSDK:
-    pass
+    ...
 
 
 class Models:
-    pass
-
-
-class ContinueConfig:
-    pass
+    ...
 
 
 class Policy(ContinueBaseModel):
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index 52a52b1d..e4610d36 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -1,10 +1,24 @@
-from typing import Any, Optional
+from typing import Optional
 
 from pydantic import BaseModel
 
 from ..libs.llm import LLM
 
 
+class ContinueSDK(BaseModel):
+    pass
+
+
+ALL_MODEL_ROLES = [
+    "default",
+    "small",
+    "medium",
+    "large",
+    "edit",
+    "chat",
+]
+
+
 class Models(BaseModel):
     """Main class that holds the current model configuration"""
 
@@ -12,57 +26,47 @@
     small: Optional[LLM] = None
     medium: Optional[LLM] = None
     large: Optional[LLM] = None
+    edit: Optional[LLM] = None
+    chat: Optional[LLM] = None
 
     # TODO namespace these away to not confuse readers,
     # or split Models into ModelsConfig, which gets turned into Models
-    sdk: "ContinueSDK" = None
-    system_message: Any = None
-
-    """
-    Better to have sdk.llm.stream_chat(messages, model="claude-2").
-    Then you also don't care that it's async.
-    And it's easier to add more models.
-    And intermediate shared code is easier to add.
-    And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo"
-    PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model.
-    Easy to reason about, can place anywhere.
-    And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model.
-    This can all happen inside of Models?
-
-    class Prompt:
-        def __init__(self, ...info):
-            '''take whatever info is needed to describe the prompt'''
-
-        def to_string(self, model: str) -> str:
-            '''depending on the model, return the single prompt string'''
-    """
+    sdk: ContinueSDK = None
+
+    def dict(self, **kwargs):
+        original_dict = super().dict(**kwargs)
+        original_dict.pop("sdk", None)
+        return original_dict
+
+    @property
+    def all_models(self):
+        models = [getattr(self, role) for role in ALL_MODEL_ROLES]
+        return [model for model in models if model is not None]
+
+    @property
+    def system_message(self) -> Optional[str]:
+        if self.sdk:
+            return self.sdk.config.system_message
+        return None
+
+    def set_system_message(self, msg: str):
+        for model in self.all_models:
+            model.system_message = msg
 
     async def start(self, sdk: "ContinueSDK"):
         """Start each of the LLMs, or fall back to default"""
         self.sdk = sdk
-        self.system_message = self.sdk.config.system_message
-        await sdk.start_model(self.default)
-        if self.small:
-            await sdk.start_model(self.small)
-        else:
-            self.small = self.default
-
-        if self.medium:
-            await sdk.start_model(self.medium)
-        else:
-            self.medium = self.default
-
-        if self.large:
-            await sdk.start_model(self.large)
-        else:
-            self.large = self.default
+
+        for role in ALL_MODEL_ROLES:
+            model = getattr(self, role)
+            if model is None:
+                setattr(self, role, self.default)
+            else:
+                await sdk.start_model(model)
+
+        self.set_system_message(self.system_message)
 
     async def stop(self, sdk: "ContinueSDK"):
         """Stop each LLM (if it's not the default, which is shared)"""
-        await self.default.stop()
-        if self.small is not self.default:
-            await self.small.stop()
-        if self.medium is not self.default:
-            await self.medium.stop()
-        if self.large is not self.default:
-            await self.large.stop()
+        for model in self.all_models:
+            await model.stop()
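The `Models` rewrite above replaces four hand-written if/else blocks with a data-driven loop over `ALL_MODEL_ROLES`, and adds two new roles (`edit`, `chat`). Any role left unset is aliased to the shared `default` model rather than started separately. A runnable sketch of that fallback behavior, with an illustrative stub in place of a real `LLM`:

```python
import asyncio

# Roles mirrored from the diff; the stub LLM below is illustrative only.
ALL_MODEL_ROLES = ["default", "small", "medium", "large", "edit", "chat"]


class StubLLM:
    """Stand-in for an LLM that just records whether start() ran."""

    def __init__(self, name: str):
        self.name = name
        self.started = False

    async def start(self) -> None:
        self.started = True


class Models:
    def __init__(self, default: StubLLM, **roles: StubLLM):
        self.default = default
        for role in ALL_MODEL_ROLES[1:]:
            setattr(self, role, roles.get(role))

    async def start(self) -> None:
        # Same shape as the loop in the diff: unset roles alias the shared
        # default; explicitly configured roles are started individually.
        for role in ALL_MODEL_ROLES:
            model = getattr(self, role)
            if model is None:
                setattr(self, role, self.default)
            else:
                await model.start()


models = Models(default=StubLLM("gpt-4"), edit=StubLLM("gpt-3.5-turbo"))
asyncio.run(models.start())
assert models.chat is models.default  # unset role falls back to default
assert models.edit.started            # configured role started on its own
```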
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 4af6b8e2..294e2c8b 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -14,6 +14,14 @@ class LLM(ContinueBaseModel, ABC):
 
     class Config:
         arbitrary_types_allowed = True
+        extra = "allow"
+
+    def dict(self, **kwargs):
+        original_dict = super().dict(**kwargs)
+        original_dict.pop("write_log", None)
+        original_dict["name"] = self.name
+        original_dict["class_name"] = self.__class__.__name__
+        return original_dict
 
     @abstractproperty
     def name(self):
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
index 65e5db3a..daffe41f 100644
--- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
+++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
@@ -16,9 +16,16 @@ class MaybeProxyOpenAI(LLM):
 
     llm: Optional[LLM] = None
 
+    def update_llm_properties(self):
+        if self.llm is not None:
+            self.llm.system_message = self.system_message
+
     @property
     def name(self):
-        return self.llm.name
+        if self.llm is not None:
+            return self.llm.name
+        else:
+            return None
 
     @property
     def context_length(self):
@@ -44,11 +51,13 @@
     async def complete(
         self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
     ) -> Coroutine[Any, Any, str]:
+        self.update_llm_properties()
         return await self.llm.complete(prompt, with_history=with_history, **kwargs)
 
     async def stream_complete(
         self, prompt, with_history: List[ChatMessage] = None, **kwargs
     ) -> Generator[Union[Any, List, Dict], None, None]:
+        self.update_llm_properties()
        resp = self.llm.stream_complete(prompt, with_history=with_history, **kwargs)
         async for item in resp:
             yield item
@@ -56,6 +65,7 @@
     async def stream_chat(
         self, messages: List[ChatMessage] = None, **kwargs
     ) -> Generator[Union[Any, List, Dict], None, None]:
+        self.update_llm_properties()
         resp = self.llm.stream_chat(messages=messages, **kwargs)
         async for item in resp:
             yield item
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index c2d86841..276cc290 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -54,7 +54,6 @@ class OpenAI(LLM):
 
     requires_write_log = True
 
-    system_message: Optional[str] = None
     write_log: Optional[Callable[[str], None]] = None
 
     async def start(
diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py
new file mode 100644
index 00000000..5c070bb4
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/edit_config.py
@@ -0,0 +1,75 @@
+import threading
+from typing import Dict, List
+
+import redbaron
+
+config_file_path = "/Users/natesesti/.continue/config.py"
+
+
+def load_red():
+    with open(config_file_path, "r") as file:
+        source_code = file.read()
+
+    red = redbaron.RedBaron(source_code)
+    return red
+
+
+def get_config_node(red):
+    for node in red:
+        if node.type == "assignment" and node.target.value == "config":
+            return node
+    else:
+        raise Exception("Config file appears to be improperly formatted")
+
+
+def edit_property(
+    args: redbaron.RedBaron, key_path: List[str], value: redbaron.RedBaron
+):
+    for i in range(len(args)):
+        node = args[i]
+        if node.type != "call_argument":
+            continue
+
+        if node.target.value == key_path[0]:
+            if len(key_path) > 1:
+                edit_property(node.value.value[1].value, key_path[1:], value)
+            else:
+                args[i].value = value
+            return
+
+
+edit_lock = threading.Lock()
+
+
+def edit_config_property(key_path: List[str], value: redbaron.RedBaron):
+    with edit_lock:
+        red = load_red()
+        config = get_config_node(red)
+        config_args = config.value.value[1].value
+        edit_property(config_args, key_path, value)
+
+        with open(config_file_path, "w") as file:
+            file.write(red.dumps())
+
+
+def create_obj_node(class_name: str, args: Dict[str, str]) -> redbaron.RedBaron:
+    args = [f"{key}={value}" for key, value in args.items()]
+    return redbaron.RedBaron(f"{class_name}({', '.join(args)})")[0]
+
+
+def create_string_node(string: str) -> redbaron.RedBaron:
+    return redbaron.RedBaron(f'"{string}"')[0]
+
+
+def create_float_node(float: float) -> redbaron.RedBaron:
+    return redbaron.RedBaron(f"{float}")[0]
+
+
+# Example:
+# edit_config_property(
+#     [
+#         "models",
+#         "default",
+#     ],
+#     create_obj_node("OpenAI", {"api_key": '""', "model": '"gpt-4"'}),
+# )
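The new `edit_config.py` is the persistence half of the settings UI: it parses the user's `config.py` with RedBaron (a lossless full-syntax-tree library, hence the new `redbaron`/`baron`/`rply`/`appdirs` lock entries), finds the `config = ContinueConfig(...)` assignment, walks the keyword arguments along `key_path`, and splices in a replacement node, so surrounding comments and formatting survive the rewrite. A usage sketch built on the commented example in the file itself; note that values handed to `create_obj_node` are Python source fragments (strings must arrive pre-quoted), and that the committed `config_file_path` is an absolute path on the author's machine, so this only runs where such a file exists:

```python
# Import path as laid out in this repo; adjust to your installation.
from continuedev.src.continuedev.libs.util.edit_config import (
    create_float_node,
    create_obj_node,
    edit_config_property,
)

# Given a config.py along the lines of (made-up example):
#     config = ContinueConfig(
#         temperature=0.5,
#         models=Models(default=MaybeProxyOpenAI(model="gpt-4")),
#     )

# Top-level property: a key_path of length 1 replaces temperature=... in place.
edit_config_property(["temperature"], create_float_node(0.9))

# Nested property: descends into models=Models(...) and swaps its default= argument.
edit_config_property(
    ["models", "default"],
    create_obj_node("OpenAI", {"api_key": '"sk-..."', "model": '"gpt-4"'}),
)
```

The module-level `edit_lock` serializes these rewrites, since the GUI handlers below fire them from async tasks while the file is read, mutated, and rewritten whole.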
diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py
index 1c43f0a0..ad727f06 100644
--- a/continuedev/src/continuedev/models/generate_json_schema.py
+++ b/continuedev/src/continuedev/models/generate_json_schema.py
@@ -2,18 +2,24 @@ import os
 
 from pydantic import schema_json_of
 
+from ..core.config import ContinueConfig
 from ..core.context import ContextItem
 from ..core.main import FullState, History, HistoryNode, SessionInfo
+from ..core.models import Models
+from ..libs.llm import LLM
 from .filesystem import FileEdit, RangeInFile
 from .filesystem_edit import FileEditWithFullContents
-from .main import *
+from .main import Position, Range, Traceback, TracebackFrame
 
 MODELS_TO_GENERATE = (
     [Position, Range, Traceback, TracebackFrame]
     + [RangeInFile, FileEdit]
     + [FileEditWithFullContents]
     + [History, HistoryNode, FullState, SessionInfo]
+    + [ContinueConfig]
     + [ContextItem]
+    + [Models]
+    + [LLM]
 )
 
 RENAMES = {"ExampleClass": "RenamedName"}
@@ -34,7 +40,10 @@ def main():
         try:
             json = schema_json_of(model, indent=2, title=title)
         except Exception as e:
+            import traceback
+
             print(f"Failed to generate json schema for {title}: {e}")
+            traceback.print_exc()
             continue  # pun intended
 
         with open(f"{SCHEMA_DIR}/{title}.json", "w") as f:
diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py
index 7f2ebef1..63548698 100644
--- a/continuedev/src/continuedev/plugins/steps/chat.py
+++ b/continuedev/src/continuedev/plugins/steps/chat.py
@@ -76,7 +76,7 @@ class SimpleChatStep(Step):
 
         messages = self.messages or await sdk.get_chat_context()
 
-        generator = sdk.models.default.stream_chat(
+        generator = sdk.models.chat.stream_chat(
             messages, temperature=sdk.config.temperature
         )
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index 86569dcb..3de76eaf 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -8,11 +8,7 @@ from typing import Any, Coroutine, List, Union
 from pydantic import validator
 
 from ....core.main import ChatMessage, ContinueCustomException, Step
-from ....core.observation import (
-    Observation,
-    TextObservation,
-    UserInputObservation,
-)
+from ....core.observation import Observation, TextObservation, UserInputObservation
 from ....libs.llm.ggml import GGML
 from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
 from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
@@ -246,7 +242,7 @@ class DefaultModelEditCodeStep(Step):
         # We don't know here all of the functions being passed in.
         # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
         # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
-        model_to_use = sdk.models.default
+        model_to_use = sdk.models.edit
         max_tokens = int(model_to_use.context_length / 2)
 
         TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
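Both step changes above are the consumer side of the new model roles: chat traffic goes to `sdk.models.chat` and edits to `sdk.models.edit`, while `Models.start` guarantees those attributes are never `None` (unset roles alias `default`), so configs that predate the new roles keep working. A small sketch of that invariant plus the halved token budget; the stub classes are illustrative:

```python
# Stub of the state Models.start() leaves behind when no edit model is
# configured: the role attribute aliases the shared default model.
class FakeModel:
    context_length = 4096


class FakeModels:
    default = FakeModel()
    edit = default  # aliased by Models.start() when edit is unset


models = FakeModels()
model_to_use = models.edit  # was: models.default
# Half the window goes to the prompt, leaving room for the completion and
# letting prune_chat_messages trim history instead of dropping the prompt.
max_tokens = int(model_to_use.context_length / 2)
assert max_tokens == 2048
```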
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 5589284a..bdcaad47 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -10,6 +10,11 @@ from uvicorn.main import Server
 
 from ..core.main import ContextItem
 from ..libs.util.create_async_task import create_async_task
+from ..libs.util.edit_config import (
+    create_float_node,
+    create_obj_node,
+    create_string_node,
+)
 from ..libs.util.logging import logger
 from ..libs.util.queue import AsyncSubscriptionQueue
 from ..libs.util.telemetry import posthog_logger
@@ -105,6 +110,12 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
             self.load_session(data.get("session_id", None))
         elif message_type == "edit_step_at_index":
             self.edit_step_at_index(data.get("user_input", ""), data["index"])
+        elif message_type == "set_system_message":
+            self.set_system_message(data["message"])
+        elif message_type == "set_temperature":
+            self.set_temperature(float(data["temperature"]))
+        elif message_type == "set_model_for_role":
+            self.set_model_for_role(data["role"], data["model_class"], data["model"])
         elif message_type == "save_context_group":
             self.save_context_group(
                 data["title"], [ContextItem(**item) for item in data["context_items"]]
@@ -195,6 +206,38 @@
 
         posthog_logger.capture_event("load_session", {"session_id": session_id})
 
+    def set_system_message(self, message: str):
+        self.session.autopilot.continue_sdk.config.system_message = message
+        self.session.autopilot.continue_sdk.models.set_system_message(message)
+
+        create_async_task(
+            self.session.autopilot.set_config_attr(
+                ["system_message"], create_string_node(message)
+            ),
+            self.on_error,
+        )
+
+    def set_temperature(self, temperature: float):
+        self.session.autopilot.continue_sdk.config.temperature = temperature
+        create_async_task(
+            self.session.autopilot.set_config_attr(
+                ["temperature"], create_float_node(temperature)
+            ),
+            self.on_error,
+        )
+
+    def set_model_for_role(self, role: str, model_class: str, model: Any):
+        prev_model = self.session.autopilot.continue_sdk.models.__getattr__(role)
+        if prev_model is not None:
+            prev_model.update(model)
+        self.session.autopilot.continue_sdk.models.__setattr__(role, model)
+        create_async_task(
+            self.session.autopilot.set_config_attr(
+                ["models", role], create_obj_node(model_class, {**model})
+            ),
+            self.on_error,
+        )
+
     def save_context_group(self, title: str, context_items: List[ContextItem]):
         create_async_task(
             self.session.autopilot.save_context_group(title, context_items),
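The three new handlers in `gui.py` wire the settings UI to everything above: each one mutates the in-memory config immediately, then persists the change to `config.py` asynchronously through `Autopilot.set_config_attr`. Only the fields read out of `data` appear in this diff (the websocket envelope is handled elsewhere in `gui.py`), so the payloads below sketch just those fields:

```python
# Fields each handler reads from `data` (values are examples).
set_system_message = {"message": "Always answer concisely."}

set_temperature = {"temperature": 0.25}  # handler coerces with float(...)

set_model_for_role = {
    "role": "default",
    "model_class": "OpenAI",
    # Passed to create_obj_node(model_class, {**model}), which splices each
    # value verbatim into config.py source, so strings arrive pre-quoted.
    "model": {"api_key": '"sk-..."', "model": '"gpt-4"'},
}
```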