diff options
Diffstat (limited to 'server/continuedev/libs/llm/prompts/chat.py')
-rw-r--r-- | server/continuedev/libs/llm/prompts/chat.py | 174 |
1 files changed, 174 insertions, 0 deletions
"""Prompt templates that flatten a chat history into a single prompt string.

Contents of ``server/continuedev/libs/llm/prompts/chat.py``.

Every template takes ``msgs``: a list of dicts carrying a ``"role"`` key
(``"system"``, ``"user"`` or ``"assistant"``) and a ``"content"`` key, and
returns the text that is sent to the model.  No template mutates the list it
is given.
"""

from textwrap import dedent
from typing import Callable, Dict, List

try:
    from anthropic import AI_PROMPT, HUMAN_PROMPT
except ImportError:
    # Only these two string constants are needed from the SDK; fall back to
    # the values it defines so the other templates stay usable without it.
    HUMAN_PROMPT = "\n\nHuman:"
    AI_PROMPT = "\n\nAssistant:"


def anthropic_template_messages(messages: List[Dict[str, str]]) -> str:
    """Render ``messages`` as alternating Human/Assistant turns.

    System and user messages are both emitted as Human turns.  The result
    always ends with an open Assistant turn for the model to complete.
    """
    parts = []

    # Anthropic prompt must start with a Human turn
    if messages and messages[0]["role"] not in ("user", "system"):
        parts.append(f"{HUMAN_PROMPT} Hello.")

    for msg in messages:
        speaker = HUMAN_PROMPT if msg["role"] in ("user", "system") else AI_PROMPT
        parts.append(f"{speaker} {msg['content']} ")

    parts.append(AI_PROMPT)
    return "".join(parts)


def template_alpaca_messages(msgs: List[Dict[str, str]]) -> str:
    """Render ``msgs`` in the ``### Instruction:`` / ``### Response:`` layout.

    A leading system message is emitted verbatim as a preamble.  User messages
    become instructions; every other role becomes a response.  The result ends
    with an open ``### Response:`` header.  An empty ``msgs`` yields just that
    header.
    """
    parts = []
    turns = msgs

    if msgs and msgs[0]["role"] == "system":
        parts.append(f"{msgs[0]['content']}\n")
        # Slice instead of pop(0) so the caller's list keeps its system message.
        turns = msgs[1:]

    for msg in turns:
        header = "### Instruction:\n" if msg["role"] == "user" else "### Response:\n"
        parts.append(f"{header}{msg['content']}\n")

    parts.append("### Response:\n")
    return "".join(parts)


def raw_input_template(msgs: List[Dict[str, str]]) -> str:
    """Return the content of the last message unchanged.

    Raises:
        IndexError: if ``msgs`` is empty.
    """
    return msgs[-1]["content"]


SQL_CODER_DEFAULT_SCHEMA = """\
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER -- Current quantity in stock
);

CREATE TABLE customers (
  customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
  name VARCHAR(50), -- Name of the customer
  address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER, -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id
"""


def _sqlcoder_template_messages(
    msgs: List[Dict[str, str]], schema: str = SQL_CODER_DEFAULT_SCHEMA
) -> str:
    """Build the text-to-SQL prompt for the last message in ``msgs``.

    Only the final message is used (as the question); ``schema`` is embedded
    verbatim.  The prompt ends inside an opened SQL code fence so the model
    continues with the query itself.

    Raises:
        IndexError: if ``msgs`` is empty.
    """
    question = msgs[-1]["content"]
    return f"""\
Your task is to convert a question into a SQL query, given a Postgres database schema.
Adhere to these rules:
- **Deliberately go through the question and database schema word by word** to appropriately answer the question
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float

### Input:
Generate a SQL query that answers the question `{question}`.
This query will run on a database whose schema is represented in this string:
{schema}

### Response:
Based on your instructions, here is the SQL query I have generated to answer the question `{question}`:
```sql
"""


def sqlcoder_template_messages(
    schema: str = SQL_CODER_DEFAULT_SCHEMA,
) -> Callable[[List[Dict[str, str]]], str]:
    """Return a template function bound to ``schema``.

    The placeholder ``"<MY_DATABASE_SCHEMA>"`` and the empty string both select
    ``SQL_CODER_DEFAULT_SCHEMA``.
    """
    if schema in ("<MY_DATABASE_SCHEMA>", ""):
        schema = SQL_CODER_DEFAULT_SCHEMA

    def fn(msgs: List[Dict[str, str]]) -> str:
        return _sqlcoder_template_messages(msgs, schema=schema)

    # Give the closure a stable name instead of the generic "fn".
    fn.__name__ = "sqlcoder_template_messages"
    return fn


def llama2_template_messages(msgs: List[Dict[str, str]]) -> str:
    """Render ``msgs`` using ``[INST] ... [/INST]`` and ``<<SYS>>`` markers.

    * A leading assistant message is dropped.
    * A leading system message with blank content is dropped.
    * A non-blank leading system message is wrapped in ``<<SYS>>`` tags and
      merged into the first ``[INST]`` block together with the message that
      follows it (if any).
    * Each remaining user message gets its own ``[INST]`` block; any other
      role is appended as plain text followed by a space.

    Returns ``""`` when nothing is left to render.
    """
    history = msgs

    if history and history[0]["role"] == "assistant":
        # These models aren't trained to handle assistant message coming first,
        # and typically these are just introduction messages from Continue.
        # Slice rather than pop(0) so the caller's list is left intact.
        history = history[1:]

    if not history:
        return ""

    has_system = history[0]["role"] == "system"
    if has_system and history[0]["content"].strip() == "":
        has_system = False
        history = history[1:]

    parts = []
    remaining = history

    if has_system:
        # Built directly rather than with dedent(): dedenting after
        # interpolation breaks as soon as the system content spans several
        # lines, leaving source indentation in the prompt.
        system_message = f"<<SYS>>\n{history[0]['content']}\n<</SYS>>\n\n"
        first_turn = history[1]["content"] if len(history) > 1 else ""
        parts.append(f"[INST] {system_message}{first_turn} [/INST]")
        remaining = history[2:]

    for msg in remaining:
        if msg["role"] == "user":
            parts.append(f"[INST] {msg['content']} [/INST]")
        else:
            parts.append(msg["content"] + " ")

    return "".join(parts)


def code_llama_template_messages(msgs: List[Dict[str, str]]) -> str:
    """Wrap the last message in a single ``[INST]`` block.

    Raises:
        IndexError: if ``msgs`` is empty.
    """
    return f"[INST] {msgs[-1]['content']}\n[/INST]"


def extra_space_template_messages(msgs: List[Dict[str, str]]) -> str:
    """Return the last message's content prefixed with one space.

    Raises:
        IndexError: if ``msgs`` is empty.
    """
    return f" {msgs[-1]['content']}"


def code_llama_python_template_messages(msgs: List[Dict[str, str]]) -> str:
    """Wrap the last message in the Python-expert ``[INST]`` prompt.

    The task text is inserted verbatim, including any line breaks.

    Raises:
        IndexError: if ``msgs`` is empty.
    """
    task = msgs[-1]["content"]
    # Joined explicitly instead of dedent(f"..."): a multi-line task would
    # otherwise stop dedent from stripping the template's own indentation.
    return "\n".join(
        (
            "[INST]",
            "You are an expert Python programmer and personal assistant, "
            f"here is your task: {task}",
            "Your answer should start with a [PYTHON] tag and end with a [/PYTHON] tag.",
            "[/INST]",
        )
    )