author    | Nate Sesti <sestinj@gmail.com> | 2023-07-16 22:12:44 -0700
committer | Nate Sesti <sestinj@gmail.com> | 2023-07-16 22:12:44 -0700
commit    | 73e1cfbefbf450ab6564aba653e0132843223c7a (patch)
tree      | e7b8aaec84a400e6b1d1c23ab1e703204b20a4d9 /continuedev/src
parent    | c5d05cec0cafa541c6b00153433864f95beeb56c (diff)
download  | sncontinue-73e1cfbefbf450ab6564aba653e0132843223c7a.tar.gz
          | sncontinue-73e1cfbefbf450ab6564aba653e0132843223c7a.tar.bz2
          | sncontinue-73e1cfbefbf450ab6564aba653e0132843223c7a.zip
templated system messages
Diffstat (limited to 'continuedev/src')
-rw-r--r-- | continuedev/src/continuedev/core/config.py               |  1
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py                  | 10
-rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py             |  6
-rw-r--r-- | continuedev/src/continuedev/libs/llm/hf_inference_api.py |  3
-rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py           |  6
-rw-r--r-- | continuedev/src/continuedev/libs/llm/proxy_server.py     |  6
-rw-r--r-- | continuedev/src/continuedev/libs/util/count_tokens.py    |  7
-rw-r--r-- | continuedev/src/continuedev/libs/util/templating.py      | 39
8 files changed, 62 insertions, 16 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 957609c5..91a47c8e 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -85,6 +85,7 @@ class ContinueConfig(BaseModel):
     slash_commands: Optional[List[SlashCommand]] = DEFAULT_SLASH_COMMANDS
     on_traceback: Optional[List[OnTracebackSteps]] = [
         OnTracebackSteps(step_name="DefaultOnTracebackStep")]
+    system_message: Optional[str] = None

     # Want to force these to be the slash commands for now
     @validator('slash_commands', pre=True)
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index eb60109c..ac57c122 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -34,10 +34,12 @@ MODEL_PROVIDER_TO_ENV_VAR = {
 class Models:
     provider_keys: Dict[ModelProvider, str] = {}
     model_providers: List[ModelProvider]
+    system_message: str

     def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
         self.sdk = sdk
         self.model_providers = model_providers
+        self.system_message = sdk.config.system_message

     @classmethod
     async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
@@ -53,12 +55,12 @@ class Models:
     def __load_openai_model(self, model: str) -> OpenAI:
         api_key = self.provider_keys["openai"]
         if api_key == "":
-            return ProxyServer(self.sdk.ide.unique_id, model)
-        return OpenAI(api_key=api_key, default_model=model)
+            return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message)
+        return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message)

     def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
         api_key = self.provider_keys["hf_inference_api"]
-        return HuggingFaceInferenceAPI(api_key=api_key, model=model)
+        return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)

     @cached_property
     def starcoder(self):
@@ -82,7 +84,7 @@ class Models:

     @cached_property
     def ggml(self):
-        return GGML()
+        return GGML(system_message=self.system_message)

     def __model_from_name(self, model_name: str):
         if model_name == "starcoder":
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index d3589b70..6007fdb4 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -33,7 +33,7 @@ class GGML(LLM):
         args = {**self.default_args, **kwargs}

         messages = compile_chat_messages(
-            self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
+            self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)

         async with aiohttp.ClientSession() as session:
             async with session.post(f"{SERVER_URL}/v1/completions", json={
@@ -50,7 +50,7 @@ class GGML(LLM):
     async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.name, messages, args["max_tokens"], None, functions=args.get("functions", None))
+            self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
         args["stream"] = True

         async with aiohttp.ClientSession() as session:
@@ -77,7 +77,7 @@ class GGML(LLM):

         async with aiohttp.ClientSession() as session:
             async with session.post(f"{SERVER_URL}/v1/completions", json={
-                "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
+                "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
                 **args
             }) as resp:
                 try:
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 803ba122..7e11fbbe 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -11,9 +11,10 @@ class HuggingFaceInferenceAPI(LLM):
     api_key: str
     model: str

-    def __init__(self, api_key: str, model: str):
+    def __init__(self, api_key: str, model: str, system_message: str = None):
         self.api_key = api_key
         self.model = model
+        self.system_message = system_message  # TODO: Nothing being done with this

     def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
         """Return the completion of the text with the given temperature."""
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index f0877d90..d973f19e 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -37,7 +37,7 @@ class OpenAI(LLM):
         if args["model"] in CHAT_MODELS:
             async for chunk in await openai.ChatCompletion.acreate(
                 messages=compile_chat_messages(
-                    args["model"], with_history, args["max_tokens"], prompt, functions=None),
+                    args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
                 **args,
             ):
                 if "content" in chunk.choices[0].delta:
@@ -58,7 +58,7 @@ class OpenAI(LLM):

         async for chunk in await openai.ChatCompletion.acreate(
             messages=compile_chat_messages(
-                args["model"], messages, args["max_tokens"], functions=args.get("functions", None)),
+                args["model"], messages, args["max_tokens"], functions=args.get("functions", None), system_message=self.system_message),
             **args,
         ):
             yield chunk.choices[0].delta
@@ -69,7 +69,7 @@ class OpenAI(LLM):
         if args["model"] in CHAT_MODELS:
             resp = (await openai.ChatCompletion.acreate(
                 messages=compile_chat_messages(
-                    args["model"], with_history, args["max_tokens"], prompt, functions=None),
+                    args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
                 **args,
             )).choices[0].message.content
         else:
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index eab6e441..3ec492f3 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -38,7 +38,7 @@ class ProxyServer(LLM):

         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/complete", json={
-                "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
+                "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
                 "unique_id": self.unique_id,
                 **args
             }) as resp:
@@ -50,7 +50,7 @@ class ProxyServer(LLM):
     async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None))
+            self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)

         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/stream_chat", json={
@@ -74,7 +74,7 @@ class ProxyServer(LLM):
     async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
+            self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)

         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/stream_complete", json={
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index e1baeca1..1ca98fe6 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -1,6 +1,7 @@
 import json
 from typing import Dict, List, Union
 from ...core.main import ChatMessage
+from .templating import render_system_message
 import tiktoken

 aliases = {
@@ -85,13 +86,15 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int,
         for function in functions:
             prompt_tokens += count_tokens(model, json.dumps(function))

+    rendered_system_message = render_system_message(system_message)
+
     msgs = prune_chat_history(model,
-                              msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, system_message))
+                              msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, rendered_system_message))
     history = []
     if system_message:
         history.append({
             "role": "system",
-            "content": system_message
+            "content": rendered_system_message
         })
     history += [msg.to_dict(with_functions=functions is not None)
                 for msg in msgs]
diff --git a/continuedev/src/continuedev/libs/util/templating.py b/continuedev/src/continuedev/libs/util/templating.py
new file mode 100644
index 00000000..ebfc2e31
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/templating.py
@@ -0,0 +1,39 @@
+import os
+import chevron
+
+
+def get_vars_in_template(template):
+    """
+    Get the variables in a template
+    """
+    return [token[1] for token in chevron.tokenizer.tokenize(template) if token[0] == 'variable']
+
+
+def escape_var(var: str) -> str:
+    """
+    Escape a variable so it can be used in a template
+    """
+    return var.replace(os.path.sep, '').replace('.', '')
+
+
+def render_system_message(system_message: str) -> str:
+    """
+    Render system message with mustache syntax.
+    Right now it only supports rendering absolute file paths as their contents.
+    """
+    vars = get_vars_in_template(system_message)
+
+    args = {}
+    for var in vars:
+        if var.startswith(os.path.sep):
+            # Escape vars which are filenames, because mustache doesn't allow / in variable names
+            escaped_var = escape_var(var)
+            system_message = system_message.replace(
+                var, escaped_var)
+
+            if os.path.exists(var):
+                args[escaped_var] = open(var, 'r').read()
+            else:
+                args[escaped_var] = ''
+
+    return chevron.render(system_message, args)