| author | Nate Sesti <sestinj@gmail.com> 2023-07-16 22:12:44 -0700 |
|---|---|
| committer | Nate Sesti <sestinj@gmail.com> 2023-07-16 22:12:44 -0700 |
| commit | 73e1cfbefbf450ab6564aba653e0132843223c7a (patch) |
| tree | e7b8aaec84a400e6b1d1c23ab1e703204b20a4d9 /continuedev |
| parent | c5d05cec0cafa541c6b00153433864f95beeb56c (diff) |
| download | sncontinue-73e1cfbefbf450ab6564aba653e0132843223c7a.tar.gz sncontinue-73e1cfbefbf450ab6564aba653e0132843223c7a.tar.bz2 sncontinue-73e1cfbefbf450ab6564aba653e0132843223c7a.zip |
templated system messages
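This change lets the new `ContinueConfig.system_message` option be written as a mustache template: any `{{ ... }}` variable that is an absolute file path is replaced with that file's contents by the `render_system_message` helper added below (missing files render as an empty string). A minimal sketch of what a user-side config might look like; the `system_message` field comes from this commit, while the import path, the other defaults, and the referenced file are assumptions:

```python
# Hypothetical usage of the new system_message config field.
# Assumptions: this import path is how user config files see the package,
# and the remaining ContinueConfig fields keep their defaults.
from continuedev.core.config import ContinueConfig

config = ContinueConfig(
    # The {{...}} variable is an absolute path (hypothetical here); at request
    # time render_system_message swaps it for the file's contents, or '' if
    # the file does not exist.
    system_message="Follow the conventions described in:\n{{/absolute/path/to/CONVENTIONS.md}}",
)
```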
Diffstat (limited to 'continuedev')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | continuedev/poetry.lock | 65 |
| -rw-r--r-- | continuedev/pyproject.toml | 3 |
| -rw-r--r-- | continuedev/src/continuedev/core/config.py | 1 |
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 10 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py | 6 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/hf_inference_api.py | 3 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 6 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/proxy_server.py | 6 |
| -rw-r--r-- | continuedev/src/continuedev/libs/util/templating.py | 39 |
10 files changed, 127 insertions, 19 deletions
```diff
diff --git a/continuedev/poetry.lock b/continuedev/poetry.lock
index a49a570f..625aabc9 100644
--- a/continuedev/poetry.lock
+++ b/continuedev/poetry.lock
@@ -298,6 +298,18 @@ files = [
 ]
 
 [[package]]
+name = "chevron"
+version = "0.14.0"
+description = "Mustache templating language renderer"
+category = "main"
+optional = false
+python-versions = "*"
+files = [
+    {file = "chevron-0.14.0-py3-none-any.whl", hash = "sha256:fbf996a709f8da2e745ef763f482ce2d311aa817d287593a5b990d6d6e4f0443"},
+    {file = "chevron-0.14.0.tar.gz", hash = "sha256:87613aafdf6d77b6a90ff073165a61ae5086e21ad49057aa0e53681601800ebf"},
+]
+
+[[package]]
 name = "click"
 version = "8.1.3"
 description = "Composable command line interface toolkit"
@@ -601,6 +613,25 @@ files = [
 ]
 
 [[package]]
+name = "importlib-resources"
+version = "6.0.0"
+description = "Read resources from Python packages"
+category = "main"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "importlib_resources-6.0.0-py3-none-any.whl", hash = "sha256:d952faee11004c045f785bb5636e8f885bed30dc3c940d5d42798a2a4541c185"},
+    {file = "importlib_resources-6.0.0.tar.gz", hash = "sha256:4cf94875a8368bd89531a756df9a9ebe1f150e0f885030b461237bc7f2d905f2"},
+]
+
+[package.dependencies]
+zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
+
+[[package]]
 name = "jsonref"
 version = "1.1.0"
 description = "jsonref is a library for automatic dereferencing of JSON Reference objects for Python."
@@ -626,6 +657,8 @@ files = [
 
 [package.dependencies]
 attrs = ">=17.4.0"
+importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""}
+pkgutil-resolve-name = {version = ">=1.3.10", markers = "python_version < \"3.9\""}
 pyrsistent = ">=0.14.0,<0.17.0 || >0.17.0,<0.17.1 || >0.17.1,<0.17.2 || >0.17.2"
 
 [package.extras]
@@ -1025,6 +1058,18 @@ test = ["hypothesis (>=6.34.2)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.17.0)"
 xml = ["lxml (>=4.6.3)"]
 
 [[package]]
+name = "pkgutil-resolve-name"
+version = "1.3.10"
+description = "Resolve a name to an object."
+category = "main"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "pkgutil_resolve_name-1.3.10-py3-none-any.whl", hash = "sha256:ca27cc078d25c5ad71a9de0a7a330146c4e014c2462d9af19c6b828280649c5e"},
+    {file = "pkgutil_resolve_name-1.3.10.tar.gz", hash = "sha256:357d6c9e6a755653cfd78893817c0853af365dd51ec97f3d358a819373bbd174"},
+]
+
+[[package]]
 name = "posthog"
 version = "3.0.1"
 description = "Integrate PostHog into any python application."
@@ -1818,7 +1863,23 @@ files = [
 idna = ">=2.0"
 multidict = ">=4.0"
 
+[[package]]
+name = "zipp"
+version = "3.16.2"
+description = "Backport of pathlib-compatible object wrapper for zip files"
+category = "main"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "zipp-3.16.2-py3-none-any.whl", hash = "sha256:679e51dd4403591b2d6838a48de3d283f3d188412a9782faadf845f298736ba0"},
+    {file = "zipp-3.16.2.tar.gz", hash = "sha256:ebc15946aa78bd63458992fc81ec3b6f7b1e92d51c35e6de1c3804e73b799147"},
+]
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
+
 [metadata]
 lock-version = "2.0"
-python-versions = "^3.9"
-content-hash = "3ba2a7278fda36a059d76e227be94b0cb5e2efc9396b47a9642b916680214d9f"
+python-versions = "^3.8.1"
+content-hash = "82510deb9f4afb5bc38db0dfd88ad88005fa0b6221c24e8c1700c006360f3f88"
diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml
index 6727e29a..3077de1c 100644
--- a/continuedev/pyproject.toml
+++ b/continuedev/pyproject.toml
@@ -6,7 +6,7 @@ authors = ["Nate Sesti <sestinj@gmail.com>"]
 readme = "README.md"
 
 [tool.poetry.dependencies]
-python = "^3.8"
+python = "^3.8.1"
 diff-match-patch = "^20230430"
 fastapi = "^0.95.1"
 typer = "^0.7.0"
@@ -24,6 +24,7 @@ tiktoken = "^0.4.0"
 jsonref = "^1.1.0"
 jsonschema = "^4.17.3"
 directory-tree = "^0.0.3.1"
+chevron = "^0.14.0"
 
 [tool.poetry.scripts]
 typegen = "src.continuedev.models.generate_json_schema:main"
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 957609c5..91a47c8e 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -85,6 +85,7 @@ class ContinueConfig(BaseModel):
     slash_commands: Optional[List[SlashCommand]] = DEFAULT_SLASH_COMMANDS
     on_traceback: Optional[List[OnTracebackSteps]] = [
         OnTracebackSteps(step_name="DefaultOnTracebackStep")]
+    system_message: Optional[str] = None
 
     # Want to force these to be the slash commands for now
     @validator('slash_commands', pre=True)
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index eb60109c..ac57c122 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -34,10 +34,12 @@ MODEL_PROVIDER_TO_ENV_VAR = {
 class Models:
     provider_keys: Dict[ModelProvider, str] = {}
     model_providers: List[ModelProvider]
+    system_message: str
 
     def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
         self.sdk = sdk
         self.model_providers = model_providers
+        self.system_message = sdk.config.system_message
 
     @classmethod
     async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
@@ -53,12 +55,12 @@ class Models:
     def __load_openai_model(self, model: str) -> OpenAI:
         api_key = self.provider_keys["openai"]
         if api_key == "":
-            return ProxyServer(self.sdk.ide.unique_id, model)
-        return OpenAI(api_key=api_key, default_model=model)
+            return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message)
+        return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message)
 
     def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
         api_key = self.provider_keys["hf_inference_api"]
-        return HuggingFaceInferenceAPI(api_key=api_key, model=model)
+        return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)
 
     @cached_property
     def starcoder(self):
@@ -82,7 +84,7 @@ class Models:
 
     @cached_property
     def ggml(self):
-        return GGML()
+        return GGML(system_message=self.system_message)
 
     def __model_from_name(self, model_name: str):
         if model_name == "starcoder":
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index d3589b70..6007fdb4 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -33,7 +33,7 @@ class GGML(LLM):
         args = {**self.default_args, **kwargs}
 
         messages = compile_chat_messages(
-            self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
+            self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
 
         async with aiohttp.ClientSession() as session:
             async with session.post(f"{SERVER_URL}/v1/completions", json={
@@ -50,7 +50,7 @@ class GGML(LLM):
     async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.name, messages, args["max_tokens"], None, functions=args.get("functions", None))
+            self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
         args["stream"] = True
 
         async with aiohttp.ClientSession() as session:
@@ -77,7 +77,7 @@ class GGML(LLM):
 
         async with aiohttp.ClientSession() as session:
             async with session.post(f"{SERVER_URL}/v1/completions", json={
-                "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
+                "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
                 **args
             }) as resp:
                 try:
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 803ba122..7e11fbbe 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -11,9 +11,10 @@ class HuggingFaceInferenceAPI(LLM):
     api_key: str
     model: str
 
-    def __init__(self, api_key: str, model: str):
+    def __init__(self, api_key: str, model: str, system_message: str = None):
         self.api_key = api_key
         self.model = model
+        self.system_message = system_message  # TODO: Nothing being done with this
 
     def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
         """Return the completion of the text with the given temperature."""
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index f0877d90..d973f19e 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -37,7 +37,7 @@ class OpenAI(LLM):
         if args["model"] in CHAT_MODELS:
             async for chunk in await openai.ChatCompletion.acreate(
                 messages=compile_chat_messages(
-                    args["model"], with_history, args["max_tokens"], prompt, functions=None),
+                    args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
                 **args,
             ):
                 if "content" in chunk.choices[0].delta:
@@ -58,7 +58,7 @@ class OpenAI(LLM):
 
         async for chunk in await openai.ChatCompletion.acreate(
             messages=compile_chat_messages(
-                args["model"], messages, args["max_tokens"], functions=args.get("functions", None)),
+                args["model"], messages, args["max_tokens"], functions=args.get("functions", None), system_message=self.system_message),
             **args,
         ):
             yield chunk.choices[0].delta
@@ -69,7 +69,7 @@ class OpenAI(LLM):
         if args["model"] in CHAT_MODELS:
             resp = (await openai.ChatCompletion.acreate(
                 messages=compile_chat_messages(
-                    args["model"], with_history, args["max_tokens"], prompt, functions=None),
+                    args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
                 **args,
             )).choices[0].message.content
         else:
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index eab6e441..3ec492f3 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -38,7 +38,7 @@ class ProxyServer(LLM):
 
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/complete", json={
-                "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
+                "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
                 "unique_id": self.unique_id,
                 **args
             }) as resp:
@@ -50,7 +50,7 @@ class ProxyServer(LLM):
     async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None))
+            self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
 
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/stream_chat", json={
@@ -74,7 +74,7 @@ class ProxyServer(LLM):
     async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
+            self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
 
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/stream_complete", json={
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index e1baeca1..1ca98fe6 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -1,6 +1,7 @@
 import json
 from typing import Dict, List, Union
 from ...core.main import ChatMessage
+from .templating import render_system_message
 import tiktoken
 
 aliases = {
@@ -85,13 +86,15 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int,
         for function in functions:
             prompt_tokens += count_tokens(model, json.dumps(function))
 
+    rendered_system_message = render_system_message(system_message)
+
     msgs = prune_chat_history(model,
-                              msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, system_message))
+                              msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, rendered_system_message))
     history = []
     if system_message:
         history.append({
             "role": "system",
-            "content": system_message
+            "content": rendered_system_message
         })
     history += [msg.to_dict(with_functions=functions is not None)
                 for msg in msgs]
diff --git a/continuedev/src/continuedev/libs/util/templating.py b/continuedev/src/continuedev/libs/util/templating.py
new file mode 100644
index 00000000..ebfc2e31
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/templating.py
@@ -0,0 +1,39 @@
+import os
+import chevron
+
+
+def get_vars_in_template(template):
+    """
+    Get the variables in a template
+    """
+    return [token[1] for token in chevron.tokenizer.tokenize(template) if token[0] == 'variable']
+
+
+def escape_var(var: str) -> str:
+    """
+    Escape a variable so it can be used in a template
+    """
+    return var.replace(os.path.sep, '').replace('.', '')
+
+
+def render_system_message(system_message: str) -> str:
+    """
+    Render system message with mustache syntax.
+    Right now it only supports rendering absolute file paths as their contents.
+    """
+    vars = get_vars_in_template(system_message)
+
+    args = {}
+    for var in vars:
+        if var.startswith(os.path.sep):
+            # Escape vars which are filenames, because mustache doesn't allow / in variable names
+            escaped_var = escape_var(var)
+            system_message = system_message.replace(
+                var, escaped_var)
+
+            if os.path.exists(var):
+                args[escaped_var] = open(var, 'r').read()
+            else:
+                args[escaped_var] = ''
+
+    return chevron.render(system_message, args)
```
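For reference, the two `chevron` calls that the new `templating.py` relies on are its tokenizer and its renderer: `get_vars_in_template` keeps the `('variable', name)` tokens, and `chevron.render` fills them in from a dict. A small standalone sketch of those two calls (the template text and values are illustrative, not from this commit):

```python
import chevron

template = "You are pair-programming in the {{repo_name}} repository."

# chevron.tokenizer.tokenize yields (type, value) tuples; get_vars_in_template
# keeps only the ones tagged 'variable'.
variables = [token[1] for token in chevron.tokenizer.tokenize(template)
             if token[0] == 'variable']
print(variables)  # ['repo_name']

# chevron.render substitutes variables from a dict; render_system_message calls
# this after escaping file-path variables and reading their contents.
print(chevron.render(template, {"repo_name": "continue"}))
```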
