author    Nate Sesti <sestinj@gmail.com>  2023-07-16 22:12:44 -0700
committer Nate Sesti <sestinj@gmail.com>  2023-07-16 22:12:44 -0700
commit    73e1cfbefbf450ab6564aba653e0132843223c7a (patch)
tree      e7b8aaec84a400e6b1d1c23ab1e703204b20a4d9 /continuedev/src
parent    c5d05cec0cafa541c6b00153433864f95beeb56c (diff)
templated system messages
Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/core/config.py                1
-rw-r--r--  continuedev/src/continuedev/core/sdk.py                  10
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py              6
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_inference_api.py  3
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py            6
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py      6
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py     7
-rw-r--r--  continuedev/src/continuedev/libs/util/templating.py      39
8 files changed, 62 insertions(+), 16 deletions(-)
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 957609c5..91a47c8e 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -85,6 +85,7 @@ class ContinueConfig(BaseModel):
slash_commands: Optional[List[SlashCommand]] = DEFAULT_SLASH_COMMANDS
on_traceback: Optional[List[OnTracebackSteps]] = [
OnTracebackSteps(step_name="DefaultOnTracebackStep")]
+ system_message: Optional[str] = None
# Want to force these to be the slash commands for now
@validator('slash_commands', pre=True)
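Illustration (not part of the commit): with the new field, a user config can carry a mustache-templated system message. The path below is hypothetical, and the {{ ... }} variable is expanded to that file's contents by the templating helper added later in this diff.

    # Sketch only: assumes ContinueConfig's remaining fields keep their defaults
    # and that continuedev/src is on the Python path.
    from continuedev.core.config import ContinueConfig

    config = ContinueConfig(
        system_message="You are a careful pair programmer. Follow {{ /Users/me/project/STYLE.md }}."
    )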
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index eb60109c..ac57c122 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -34,10 +34,12 @@ MODEL_PROVIDER_TO_ENV_VAR = {
class Models:
provider_keys: Dict[ModelProvider, str] = {}
model_providers: List[ModelProvider]
+ system_message: str
def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
self.sdk = sdk
self.model_providers = model_providers
+ self.system_message = sdk.config.system_message
@classmethod
async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
@@ -53,12 +55,12 @@ class Models:
def __load_openai_model(self, model: str) -> OpenAI:
api_key = self.provider_keys["openai"]
if api_key == "":
- return ProxyServer(self.sdk.ide.unique_id, model)
- return OpenAI(api_key=api_key, default_model=model)
+ return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message)
+ return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message)
def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
api_key = self.provider_keys["hf_inference_api"]
- return HuggingFaceInferenceAPI(api_key=api_key, model=model)
+ return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)
@cached_property
def starcoder(self):
@@ -82,7 +84,7 @@ class Models:
@cached_property
def ggml(self):
- return GGML()
+ return GGML(system_message=self.system_message)
def __model_from_name(self, model_name: str):
if model_name == "starcoder":
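Illustration (not part of the commit): the wiring above copies config.system_message onto Models and forwards it to every LLM it constructs. A rough sketch, run inside an async step, with sdk a ContinueSDK whose config sets system_message:

    # Sketch only: Models.create also loads provider API keys, omitted here.
    models = await Models.create(sdk)
    assert models.ggml.system_message == sdk.config.system_message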
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index d3589b70..6007fdb4 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -33,7 +33,7 @@ class GGML(LLM):
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
+ self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
async with aiohttp.ClientSession() as session:
async with session.post(f"{SERVER_URL}/v1/completions", json={
@@ -50,7 +50,7 @@ class GGML(LLM):
async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.name, messages, args["max_tokens"], None, functions=args.get("functions", None))
+ self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
args["stream"] = True
async with aiohttp.ClientSession() as session:
@@ -77,7 +77,7 @@ class GGML(LLM):
async with aiohttp.ClientSession() as session:
async with session.post(f"{SERVER_URL}/v1/completions", json={
- "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
+ "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
**args
}) as resp:
try:
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 803ba122..7e11fbbe 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -11,9 +11,10 @@ class HuggingFaceInferenceAPI(LLM):
api_key: str
model: str
- def __init__(self, api_key: str, model: str):
+ def __init__(self, api_key: str, model: str, system_message: str = None):
self.api_key = api_key
self.model = model
+ self.system_message = system_message # TODO: Nothing being done with this
def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
"""Return the completion of the text with the given temperature."""
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index f0877d90..d973f19e 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -37,7 +37,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
async for chunk in await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None),
+ args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
**args,
):
if "content" in chunk.choices[0].delta:
@@ -58,7 +58,7 @@ class OpenAI(LLM):
async for chunk in await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], messages, args["max_tokens"], functions=args.get("functions", None)),
+ args["model"], messages, args["max_tokens"], functions=args.get("functions", None), system_message=self.system_message),
**args,
):
yield chunk.choices[0].delta
@@ -69,7 +69,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
resp = (await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None),
+ args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
**args,
)).choices[0].message.content
else:
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index eab6e441..3ec492f3 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -38,7 +38,7 @@ class ProxyServer(LLM):
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/complete", json={
- "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
+ "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
"unique_id": self.unique_id,
**args
}) as resp:
@@ -50,7 +50,7 @@ class ProxyServer(LLM):
async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None))
+ self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/stream_chat", json={
@@ -74,7 +74,7 @@ class ProxyServer(LLM):
async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
+ self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/stream_complete", json={
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index e1baeca1..1ca98fe6 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -1,6 +1,7 @@
import json
from typing import Dict, List, Union
from ...core.main import ChatMessage
+from .templating import render_system_message
import tiktoken
aliases = {
@@ -85,13 +86,15 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int,
for function in functions:
prompt_tokens += count_tokens(model, json.dumps(function))
+ rendered_system_message = render_system_message(system_message)
+
msgs = prune_chat_history(model,
- msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, system_message))
+ msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, rendered_system_message))
history = []
if system_message:
history.append({
"role": "system",
- "content": system_message
+ "content": rendered_system_message
})
history += [msg.to_dict(with_functions=functions is not None)
for msg in msgs]
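Illustration (not part of the commit): the call below is a sketch whose argument order is inferred from the call sites in this diff, and whose import path assumes continuedev/src is on the Python path. The system message is now rendered before token counting and placed at the head of the compiled list.

    from continuedev.libs.util.count_tokens import compile_chat_messages  # assumed import path

    messages = compile_chat_messages(
        "gpt-3.5-turbo",                # model
        [],                             # prior ChatMessage history
        1024,                           # max_tokens reserved for the reply
        "Summarize the open file",      # prompt
        functions=None,
        system_message="Follow the rules in {{ /abs/path/RULES.md }}",  # hypothetical path
    )
    # messages[0] is roughly {"role": "system", "content": ...}, where the
    # {{ ... }} variable has been replaced by the file's contents ('' if the path is missing).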
diff --git a/continuedev/src/continuedev/libs/util/templating.py b/continuedev/src/continuedev/libs/util/templating.py
new file mode 100644
index 00000000..ebfc2e31
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/templating.py
@@ -0,0 +1,39 @@
+import os
+import chevron
+
+
+def get_vars_in_template(template):
+ """
+ Get the variables in a template
+ """
+ return [token[1] for token in chevron.tokenizer.tokenize(template) if token[0] == 'variable']
+
+
+def escape_var(var: str) -> str:
+ """
+ Escape a variable so it can be used in a template
+ """
+ return var.replace(os.path.sep, '').replace('.', '')
+
+
+def render_system_message(system_message: str) -> str:
+ """
+ Render system message with mustache syntax.
+ Right now it only supports rendering absolute file paths as their contents.
+ """
+ vars = get_vars_in_template(system_message)
+
+ args = {}
+ for var in vars:
+ if var.startswith(os.path.sep):
+ # Escape vars which are filenames, because mustache doesn't allow / in variable names
+ escaped_var = escape_var(var)
+ system_message = system_message.replace(
+ var, escaped_var)
+
+ if os.path.exists(var):
+ args[escaped_var] = open(var, 'r').read()
+ else:
+ args[escaped_var] = ''
+
+ return chevron.render(system_message, args)
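Illustration (not part of the commit): a minimal use of the new helper. It assumes /tmp/style.md exists and contains the single line "Prefer pure functions.", and that continuedev/src is on the Python path. Note the space after the opening braces, so chevron tokenizes the absolute path as a variable rather than a section-end tag.

    from continuedev.libs.util.templating import render_system_message

    rendered = render_system_message("Follow these rules:\n{{ /tmp/style.md }}")
    print(rendered)
    # Follow these rules:
    # Prefer pure functions.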