summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/poetry.lock58
-rw-r--r--continuedev/pyproject.toml1
-rw-r--r--continuedev/src/continuedev/core/autopilot.py7
-rw-r--r--continuedev/src/continuedev/core/config.py8
-rw-r--r--continuedev/src/continuedev/core/main.py22
-rw-r--r--continuedev/src/continuedev/core/models.py94
-rw-r--r--continuedev/src/continuedev/libs/llm/__init__.py8
-rw-r--r--continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py12
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py1
-rw-r--r--continuedev/src/continuedev/libs/util/edit_config.py75
-rw-r--r--continuedev/src/continuedev/models/generate_json_schema.py11
-rw-r--r--continuedev/src/continuedev/plugins/steps/chat.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/core/core.py8
-rw-r--r--continuedev/src/continuedev/server/gui.py43
14 files changed, 288 insertions, 62 deletions
diff --git a/continuedev/poetry.lock b/continuedev/poetry.lock
index d3140756..aefc7cf9 100644
--- a/continuedev/poetry.lock
+++ b/continuedev/poetry.lock
@@ -173,6 +173,17 @@ test = ["contextlib2", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=
trio = ["trio (>=0.16,<0.22)"]
[[package]]
+name = "appdirs"
+version = "1.4.4"
+description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
+optional = false
+python-versions = "*"
+files = [
+ {file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"},
+ {file = "appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"},
+]
+
+[[package]]
name = "async-timeout"
version = "4.0.2"
description = "Timeout context manager for asyncio programs"
@@ -213,6 +224,20 @@ files = [
]
[[package]]
+name = "baron"
+version = "0.10.1"
+description = "Full Syntax Tree for python to make writing refactoring code a realist task"
+optional = false
+python-versions = "*"
+files = [
+ {file = "baron-0.10.1-py2.py3-none-any.whl", hash = "sha256:befb33f4b9e832c7cd1e3cf0eafa6dd3cb6ed4cb2544245147c019936f4e0a8a"},
+ {file = "baron-0.10.1.tar.gz", hash = "sha256:af822ad44d4eb425c8516df4239ac4fdba9fdb398ef77e4924cd7c9b4045bc2f"},
+]
+
+[package.dependencies]
+rply = "*"
+
+[[package]]
name = "beautifulsoup4"
version = "4.12.2"
description = "Screen-scraping library"
@@ -1189,6 +1214,23 @@ files = [
cli = ["click (>=5.0)"]
[[package]]
+name = "redbaron"
+version = "0.9.2"
+description = "Abstraction on top of baron, a FST for python to make writing refactoring code a realistic task"
+optional = false
+python-versions = "*"
+files = [
+ {file = "redbaron-0.9.2-py2.py3-none-any.whl", hash = "sha256:d01032b6a848b5521a8d6ef72486315c2880f420956870cdd742e2b5a09b9bab"},
+ {file = "redbaron-0.9.2.tar.gz", hash = "sha256:472d0739ca6b2240bb2278ae428604a75472c9c12e86c6321e8c016139c0132f"},
+]
+
+[package.dependencies]
+baron = ">=0.7"
+
+[package.extras]
+notebook = ["pygments"]
+
+[[package]]
name = "regex"
version = "2023.5.5"
description = "Alternative regular expression module, to replace re."
@@ -1336,6 +1378,20 @@ files = [
]
[[package]]
+name = "rply"
+version = "0.7.8"
+description = "A pure Python Lex/Yacc that works with RPython"
+optional = false
+python-versions = "*"
+files = [
+ {file = "rply-0.7.8-py2.py3-none-any.whl", hash = "sha256:28ffd11d656c48aeb8c508eb382acd6a0bd906662624b34388751732a27807e7"},
+ {file = "rply-0.7.8.tar.gz", hash = "sha256:2a808ac25a4580a9991fc304d64434e299a8fc75760574492f242cbb5bb301c9"},
+]
+
+[package.dependencies]
+appdirs = "*"
+
+[[package]]
name = "six"
version = "1.16.0"
description = "Python 2 and 3 compatibility utilities"
@@ -1849,4 +1905,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.8.1"
-content-hash = "5500ea86b06a96f5fe45939500936911e622043a67a3a5c3d02473463ff2fd6c"
+content-hash = "fe4715494ed91c691ec1eb914373a612e75751e6685678e438b73193879de98d"
diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml
index 90ff0572..8cdf1197 100644
--- a/continuedev/pyproject.toml
+++ b/continuedev/pyproject.toml
@@ -31,6 +31,7 @@ socksio = "^1.0.0"
ripgrepy = "^2.0.0"
bs4 = "^0.0.1"
replicate = "^0.11.0"
+redbaron = "^0.9.2"
[tool.poetry.scripts]
typegen = "src.continuedev.models.generate_json_schema:main"
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 7b0661a5..a1b21903 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -5,11 +5,13 @@ import traceback
from functools import cached_property
from typing import Callable, Coroutine, Dict, List, Optional
+import redbaron
from aiohttp import ClientPayloadError
from openai import error as openai_errors
from pydantic import root_validator
from ..libs.util.create_async_task import create_async_task
+from ..libs.util.edit_config import edit_config_property
from ..libs.util.logging import logger
from ..libs.util.paths import getSavedContextGroupsPath
from ..libs.util.queue import AsyncSubscriptionQueue
@@ -158,6 +160,7 @@ class Autopilot(ContinueBaseModel):
if self.context_manager is not None
else [],
session_info=self.session_info,
+ config=self.continue_sdk.config,
saved_context_groups=self._saved_context_groups,
)
self.full_state = full_state
@@ -542,6 +545,10 @@ class Autopilot(ContinueBaseModel):
await self.context_manager.select_context_item(id, query)
await self.update_subscribers()
+ async def set_config_attr(self, key_path: List[str], value: redbaron.RedBaron):
+ edit_config_property(key_path, value)
+ await self.update_subscribers()
+
_saved_context_groups: Dict[str, List[ContextItem]] = {}
def _persist_context_groups(self):
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index f5bf81fb..62e9c690 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -14,6 +14,14 @@ class SlashCommand(BaseModel):
step: Type[Step]
params: Optional[Dict] = {}
+ def dict(self, *args, **kwargs):
+ return {
+ "name": self.name,
+ "description": self.description,
+ "params": self.params,
+ "step": self.step.__name__,
+ }
+
class CustomCommand(BaseModel):
name: str
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index e4ee7668..bf098be9 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -277,6 +277,19 @@ class SessionInfo(ContinueBaseModel):
date_created: str
+class ContinueConfig(ContinueBaseModel):
+ system_message: str
+ temperature: float
+
+ class Config:
+ extra = "allow"
+
+ def dict(self, **kwargs):
+ original_dict = super().dict(**kwargs)
+ original_dict.pop("policy", None)
+ return original_dict
+
+
class FullState(ContinueBaseModel):
"""A full state of the program, including the history"""
@@ -287,19 +300,16 @@ class FullState(ContinueBaseModel):
adding_highlighted_code: bool
selected_context_items: List[ContextItem]
session_info: Optional[SessionInfo] = None
+ config: ContinueConfig
saved_context_groups: Dict[str, List[ContextItem]] = {}
class ContinueSDK:
- pass
+ ...
class Models:
- pass
-
-
-class ContinueConfig:
- pass
+ ...
class Policy(ContinueBaseModel):
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index 52a52b1d..e4610d36 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -1,10 +1,24 @@
-from typing import Any, Optional
+from typing import Optional
from pydantic import BaseModel
from ..libs.llm import LLM
+class ContinueSDK(BaseModel):
+ pass
+
+
+ALL_MODEL_ROLES = [
+ "default",
+ "small",
+ "medium",
+ "large",
+ "edit",
+ "chat",
+]
+
+
class Models(BaseModel):
"""Main class that holds the current model configuration"""
@@ -12,57 +26,47 @@ class Models(BaseModel):
small: Optional[LLM] = None
medium: Optional[LLM] = None
large: Optional[LLM] = None
+ edit: Optional[LLM] = None
+ chat: Optional[LLM] = None
# TODO namespace these away to not confuse readers,
# or split Models into ModelsConfig, which gets turned into Models
- sdk: "ContinueSDK" = None
- system_message: Any = None
-
- """
- Better to have sdk.llm.stream_chat(messages, model="claude-2").
- Then you also don't care that it' async.
- And it's easier to add more models.
- And intermediate shared code is easier to add.
- And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo"
- PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model.
- Easy to reason about, can place anywhere.
- And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model.
- This can all happen inside of Models?
-
- class Prompt:
- def __init__(self, ...info):
- '''take whatever info is needed to describe the prompt'''
-
- def to_string(self, model: str) -> str:
- '''depending on the model, return the single prompt string'''
- """
+ sdk: ContinueSDK = None
+
+ def dict(self, **kwargs):
+ original_dict = super().dict(**kwargs)
+ original_dict.pop("sdk", None)
+ return original_dict
+
+ @property
+ def all_models(self):
+ models = [getattr(self, role) for role in ALL_MODEL_ROLES]
+ return [model for model in models if model is not None]
+
+ @property
+ def system_message(self) -> Optional[str]:
+ if self.sdk:
+ return self.sdk.config.system_message
+ return None
+
+ def set_system_message(self, msg: str):
+ for model in self.all_models:
+ model.system_message = msg
async def start(self, sdk: "ContinueSDK"):
"""Start each of the LLMs, or fall back to default"""
self.sdk = sdk
- self.system_message = self.sdk.config.system_message
- await sdk.start_model(self.default)
- if self.small:
- await sdk.start_model(self.small)
- else:
- self.small = self.default
-
- if self.medium:
- await sdk.start_model(self.medium)
- else:
- self.medium = self.default
-
- if self.large:
- await sdk.start_model(self.large)
- else:
- self.large = self.default
+
+ for role in ALL_MODEL_ROLES:
+ model = getattr(self, role)
+ if model is None:
+ setattr(self, role, self.default)
+ else:
+ await sdk.start_model(model)
+
+ self.set_system_message(self.system_message)
async def stop(self, sdk: "ContinueSDK"):
"""Stop each LLM (if it's not the default, which is shared)"""
- await self.default.stop()
- if self.small is not self.default:
- await self.small.stop()
- if self.medium is not self.default:
- await self.medium.stop()
- if self.large is not self.default:
- await self.large.stop()
+ for model in self.all_models:
+ await model.stop()
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 4af6b8e2..294e2c8b 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -14,6 +14,14 @@ class LLM(ContinueBaseModel, ABC):
class Config:
arbitrary_types_allowed = True
+ extra = "allow"
+
+ def dict(self, **kwargs):
+ original_dict = super().dict(**kwargs)
+ original_dict.pop("write_log", None)
+ original_dict["name"] = self.name
+ original_dict["class_name"] = self.__class__.__name__
+ return original_dict
@abstractproperty
def name(self):
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
index 65e5db3a..daffe41f 100644
--- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
+++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
@@ -16,9 +16,16 @@ class MaybeProxyOpenAI(LLM):
llm: Optional[LLM] = None
+ def update_llm_properties(self):
+ if self.llm is not None:
+ self.llm.system_message = self.system_message
+
@property
def name(self):
- return self.llm.name
+ if self.llm is not None:
+ return self.llm.name
+ else:
+ return None
@property
def context_length(self):
@@ -44,11 +51,13 @@ class MaybeProxyOpenAI(LLM):
async def complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, str]:
+ self.update_llm_properties()
return await self.llm.complete(prompt, with_history=with_history, **kwargs)
async def stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
+ self.update_llm_properties()
resp = self.llm.stream_complete(prompt, with_history=with_history, **kwargs)
async for item in resp:
yield item
@@ -56,6 +65,7 @@ class MaybeProxyOpenAI(LLM):
async def stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
+ self.update_llm_properties()
resp = self.llm.stream_chat(messages=messages, **kwargs)
async for item in resp:
yield item
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index c2d86841..276cc290 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -54,7 +54,6 @@ class OpenAI(LLM):
requires_write_log = True
- system_message: Optional[str] = None
write_log: Optional[Callable[[str], None]] = None
async def start(
diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py
new file mode 100644
index 00000000..5c070bb4
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/edit_config.py
@@ -0,0 +1,75 @@
+import threading
+from typing import Dict, List
+
+import redbaron
+
+config_file_path = "/Users/natesesti/.continue/config.py"
+
+
+def load_red():
+ with open(config_file_path, "r") as file:
+ source_code = file.read()
+
+ red = redbaron.RedBaron(source_code)
+ return red
+
+
+def get_config_node(red):
+ for node in red:
+ if node.type == "assignment" and node.target.value == "config":
+ return node
+ else:
+ raise Exception("Config file appears to be improperly formatted")
+
+
+def edit_property(
+ args: redbaron.RedBaron, key_path: List[str], value: redbaron.RedBaron
+):
+ for i in range(len(args)):
+ node = args[i]
+ if node.type != "call_argument":
+ continue
+
+ if node.target.value == key_path[0]:
+ if len(key_path) > 1:
+ edit_property(node.value.value[1].value, key_path[1:], value)
+ else:
+ args[i].value = value
+ return
+
+
+edit_lock = threading.Lock()
+
+
+def edit_config_property(key_path: List[str], value: redbaron.RedBaron):
+ with edit_lock:
+ red = load_red()
+ config = get_config_node(red)
+ config_args = config.value.value[1].value
+ edit_property(config_args, key_path, value)
+
+ with open(config_file_path, "w") as file:
+ file.write(red.dumps())
+
+
+def create_obj_node(class_name: str, args: Dict[str, str]) -> redbaron.RedBaron:
+ args = [f"{key}={value}" for key, value in args.items()]
+ return redbaron.RedBaron(f"{class_name}({', '.join(args)})")[0]
+
+
+def create_string_node(string: str) -> redbaron.RedBaron:
+ return redbaron.RedBaron(f'"{string}"')[0]
+
+
+def create_float_node(float: float) -> redbaron.RedBaron:
+ return redbaron.RedBaron(f"{float}")[0]
+
+
+# Example:
+# edit_config_property(
+# [
+# "models",
+# "default",
+# ],
+# create_obj_node("OpenAI", {"api_key": '""', "model": '"gpt-4"'}),
+# )
diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py
index 1c43f0a0..ad727f06 100644
--- a/continuedev/src/continuedev/models/generate_json_schema.py
+++ b/continuedev/src/continuedev/models/generate_json_schema.py
@@ -2,18 +2,24 @@ import os
from pydantic import schema_json_of
+from ..core.config import ContinueConfig
from ..core.context import ContextItem
from ..core.main import FullState, History, HistoryNode, SessionInfo
+from ..core.models import Models
+from ..libs.llm import LLM
from .filesystem import FileEdit, RangeInFile
from .filesystem_edit import FileEditWithFullContents
-from .main import *
+from .main import Position, Range, Traceback, TracebackFrame
MODELS_TO_GENERATE = (
[Position, Range, Traceback, TracebackFrame]
+ [RangeInFile, FileEdit]
+ [FileEditWithFullContents]
+ [History, HistoryNode, FullState, SessionInfo]
+ + [ContinueConfig]
+ [ContextItem]
+ + [Models]
+ + [LLM]
)
RENAMES = {"ExampleClass": "RenamedName"}
@@ -34,7 +40,10 @@ def main():
try:
json = schema_json_of(model, indent=2, title=title)
except Exception as e:
+ import traceback
+
print(f"Failed to generate json schema for {title}: {e}")
+ traceback.print_exc()
continue # pun intended
with open(f"{SCHEMA_DIR}/{title}.json", "w") as f:
diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py
index 7f2ebef1..63548698 100644
--- a/continuedev/src/continuedev/plugins/steps/chat.py
+++ b/continuedev/src/continuedev/plugins/steps/chat.py
@@ -76,7 +76,7 @@ class SimpleChatStep(Step):
messages = self.messages or await sdk.get_chat_context()
- generator = sdk.models.default.stream_chat(
+ generator = sdk.models.chat.stream_chat(
messages, temperature=sdk.config.temperature
)
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index 86569dcb..3de76eaf 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -8,11 +8,7 @@ from typing import Any, Coroutine, List, Union
from pydantic import validator
from ....core.main import ChatMessage, ContinueCustomException, Step
-from ....core.observation import (
- Observation,
- TextObservation,
- UserInputObservation,
-)
+from ....core.observation import Observation, TextObservation, UserInputObservation
from ....libs.llm.ggml import GGML
from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
@@ -246,7 +242,7 @@ class DefaultModelEditCodeStep(Step):
# We don't know here all of the functions being passed in.
# We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
# Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
- model_to_use = sdk.models.default
+ model_to_use = sdk.models.edit
max_tokens = int(model_to_use.context_length / 2)
TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 5589284a..bdcaad47 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -10,6 +10,11 @@ from uvicorn.main import Server
from ..core.main import ContextItem
from ..libs.util.create_async_task import create_async_task
+from ..libs.util.edit_config import (
+ create_float_node,
+ create_obj_node,
+ create_string_node,
+)
from ..libs.util.logging import logger
from ..libs.util.queue import AsyncSubscriptionQueue
from ..libs.util.telemetry import posthog_logger
@@ -105,6 +110,12 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.load_session(data.get("session_id", None))
elif message_type == "edit_step_at_index":
self.edit_step_at_index(data.get("user_input", ""), data["index"])
+ elif message_type == "set_system_message":
+ self.set_system_message(data["message"])
+ elif message_type == "set_temperature":
+ self.set_temperature(float(data["temperature"]))
+ elif message_type == "set_model_for_role":
+ self.set_model_for_role(data["role"], data["model_class"], data["model"])
elif message_type == "save_context_group":
self.save_context_group(
data["title"], [ContextItem(**item) for item in data["context_items"]]
@@ -195,6 +206,38 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
posthog_logger.capture_event("load_session", {"session_id": session_id})
+ def set_system_message(self, message: str):
+ self.session.autopilot.continue_sdk.config.system_message = message
+ self.session.autopilot.continue_sdk.models.set_system_message(message)
+
+ create_async_task(
+ self.session.autopilot.set_config_attr(
+ ["system_message"], create_string_node(message)
+ ),
+ self.on_error,
+ )
+
+ def set_temperature(self, temperature: float):
+ self.session.autopilot.continue_sdk.config.temperature = temperature
+ create_async_task(
+ self.session.autopilot.set_config_attr(
+ ["temperature"], create_float_node(temperature)
+ ),
+ self.on_error,
+ )
+
+ def set_model_for_role(self, role: str, model_class: str, model: Any):
+ prev_model = self.session.autopilot.continue_sdk.models.__getattr__(role)
+ if prev_model is not None:
+ prev_model.update(model)
+ self.session.autopilot.continue_sdk.models.__setattr__(role, model)
+ create_async_task(
+ self.session.autopilot.set_config_attr(
+ ["models", role], create_obj_node(model_class, {**model})
+ ),
+ self.on_error,
+ )
+
def save_context_group(self, title: str, context_items: List[ContextItem]):
create_async_task(
self.session.autopilot.save_context_group(title, context_items),