summaryrefslogtreecommitdiff
path: root/server/continuedev/libs/llm/prompts/chat.py
blob: 49010229cd83d315f8746a9db3518b01535d8f8f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from textwrap import dedent
from typing import Dict, List


def template_alpaca_messages(msgs: List[Dict[str, str]]) -> str:
    """Render a chat history as an Alpaca-style instruction prompt.

    An optional leading system message is emitted verbatim on its own line.
    Every following message is prefixed with ``### Instruction:`` when its role
    is ``"user"`` and ``### Response:`` otherwise. A trailing
    ``### Response:`` header is appended so the model continues from there.

    Args:
        msgs: Chat messages, each a dict with ``"role"`` and ``"content"``.
            The list is not modified.

    Returns:
        The rendered prompt string. For an empty history this is just the
        trailing ``### Response:\\n`` header.
    """
    parts: List[str] = []
    start = 0

    # Slice past the system message instead of popping it, so the caller's
    # list is left untouched.
    if msgs and msgs[0]["role"] == "system":
        parts.append(f"{msgs[0]['content']}\n")
        start = 1

    for msg in msgs[start:]:
        parts.append("### Instruction:\n" if msg["role"] == "user" else "### Response:\n")
        parts.append(f"{msg['content']}\n")

    # Open a response section for the model to fill in.
    parts.append("### Response:\n")

    return "".join(parts)


def raw_input_template(msgs: List[Dict[str, str]]) -> str:
    """Pass the most recent message's content through with no template applied.

    Args:
        msgs: Chat messages, each a dict with a ``"content"`` key.

    Returns:
        The ``"content"`` of the last message.
    """
    latest_message = msgs[len(msgs) - 1]
    return latest_message["content"]


# Example Postgres schema (products / customers / salespeople / sales /
# product_suppliers plus join hints) inserted into the SQLCoder prompt.
# It is the default `schema` for `_sqlcoder_template_messages`, and
# `sqlcoder_template_messages` falls back to it when no real schema is given.
SQL_CODER_DEFAULT_SCHEMA = """\
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

CREATE TABLE customers (
   customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
   name VARCHAR(50), -- Name of the customer
   address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER,  -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id
"""


def _sqlcoder_template_messages(
    msgs: List[Dict[str, str]], schema: str = SQL_CODER_DEFAULT_SCHEMA
) -> str:
    """Build a SQLCoder text-to-SQL prompt for the latest message.

    The last message's content is the natural-language question. It appears
    twice in the prompt, once in the input section and once in the pre-filled
    response. The prompt ends with an opened ```sql fence so the model
    continues with the query itself.

    Args:
        msgs: Chat messages; only the last one's ``"content"`` is used.
        schema: DDL text describing the target Postgres database.

    Returns:
        The complete prompt string.
    """
    prompt_template = """\
Your task is to convert a question into a SQL query, given a Postgres database schema.
Adhere to these rules:
- **Deliberately go through the question and database schema word by word** to appropriately answer the question
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float

### Input:
Generate a SQL query that answers the question `{question}`.
This query will run on a database whose schema is represented in this string:
{schema}

### Response:
Based on your instructions, here is the SQL query I have generated to answer the question `{question}`:
```sql
"""
    user_question = msgs[-1]["content"]
    return prompt_template.format(question=user_question, schema=schema)


def sqlcoder_template_messages(schema: str = SQL_CODER_DEFAULT_SCHEMA):
    """Return a prompt-template function for SQLCoder bound to ``schema``.

    If no usable schema is supplied, ``SQL_CODER_DEFAULT_SCHEMA`` is used
    instead. A schema counts as missing when it is ``None``, empty or
    whitespace-only, or equal to the literal ``"<MY_DATABASE_SCHEMA>"``
    (ignoring surrounding whitespace).

    Args:
        schema: DDL text for the user's database.

    Returns:
        A callable ``fn(msgs) -> str`` that renders the SQLCoder prompt for
        the last message. Its ``__name__`` is ``"sqlcoder_template_messages"``.
    """
    # "<MY_DATABASE_SCHEMA>" is presumably an unfilled config placeholder.
    # Also treat None and blank strings as "not provided", so the prompt never
    # renders "None" or an empty schema.
    if schema is None or schema.strip() in ("", "<MY_DATABASE_SCHEMA>"):
        schema = SQL_CODER_DEFAULT_SCHEMA

    def fn(msgs: List[Dict[str, str]]) -> str:
        return _sqlcoder_template_messages(msgs, schema=schema)

    # Give the closure a stable name instead of the generic "fn".
    fn.__name__ = "sqlcoder_template_messages"
    return fn


def llama2_template_messages(msgs: List[Dict[str, str]]) -> str:
    """Render a chat history in the Llama 2 chat ``[INST]`` / ``<<SYS>>`` format.

    Rules:
      * A leading assistant message is dropped.
      * A leading system message whose content is blank is dropped.
      * A non-blank system message is wrapped in ``<<SYS>> ... <</SYS>>``.
        It goes inside the first ``[INST]`` block, ahead of the next
        message's content.
      * Every later user message becomes ``[INST] content [/INST]``.
      * Any other message is appended as ``content`` followed by one space.

    NOTE(review): the reference Llama 2 format also separates turns with
    BOS/EOS tokens, which this template does not emit. Confirm that is intended.

    Args:
        msgs: Chat messages, each a dict with ``"role"`` and ``"content"``.
            The list is not modified.

    Returns:
        The rendered prompt, or ``""`` when there is nothing to render.
    """
    if msgs and msgs[0]["role"] == "assistant":
        # These models aren't trained to handle assistant message coming first,
        # and typically these are just introduction messages from Continue.
        # Slice rather than pop so the caller's list is not mutated.
        msgs = msgs[1:]

    # Checked after dropping the assistant message, which may empty the list.
    if not msgs:
        return ""

    has_system = msgs[0]["role"] == "system"

    if has_system and msgs[0]["content"].strip() == "":
        has_system = False
        msgs = msgs[1:]

    prompt = ""

    if has_system:
        # Build the block directly instead of dedent()-ing an interpolated
        # f-string. With multi-line system content, dedent finds no common
        # indent and leaves the markers indented.
        system_message = f"<<SYS>>\n{msgs[0]['content']}\n<</SYS>>\n\n"
        if len(msgs) == 1:
            return f"[INST] {system_message} [/INST]"
        prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"

    # The system message and the message merged with it are already rendered.
    first_turn = 2 if has_system else 0
    for msg in msgs[first_turn:]:
        if msg["role"] == "user":
            prompt += f"[INST] {msg['content']} [/INST]"
        else:
            prompt += msg["content"] + " "

    return prompt


def code_llama_template_messages(msgs: List[Dict[str, str]]) -> str:
    """Wrap only the latest message in Code Llama instruction markers.

    Earlier messages are ignored. The output is ``"[INST] <content>\\n[/INST]"``.
    """
    latest_content = msgs[-1]["content"]
    return "[INST] {}\n[/INST]".format(latest_content)


def extra_space_template_messages(msgs: List[Dict[str, str]]) -> str:
    """Return the latest message's content with a single leading space prepended."""
    latest_content = msgs[-1]["content"]
    return " {}".format(latest_content)


def code_llama_python_template_messages(msgs: List[Dict[str, str]]) -> str:
    """Build a Code Llama (Python) instruction prompt for the latest message.

    The last message's content is embedded as the task. The model is asked to
    wrap its answer in ``[PYTHON]`` ... ``[/PYTHON]`` tags.

    Args:
        msgs: Chat messages; only the last one's ``"content"`` is used.

    Returns:
        The prompt, starting with ``[INST]`` and ending with ``[/INST]``.
    """
    task = msgs[-1]["content"]
    # Build the prompt from flush-left literals instead of dedent()-ing an
    # interpolated f-string. With a multi-line task, dedent would see
    # unindented continuation lines and leave the template lines indented.
    return (
        "[INST]\n"
        "You are an expert Python programmer and personal assistant, "
        f"here is your task: {task}\n"
        "Your answer should start with a [PYTHON] tag and end with a [/PYTHON] tag.\n"
        "[/INST]"
    )