summaryrefslogtreecommitdiff
path: root/server/continuedev/libs/util/strings.py
blob: f2b6035fa87782508e71883241fcec6de3fb26b0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import re
from typing import Tuple


def dedent_and_get_common_whitespace(s: str) -> Tuple[str, str]:
    """
    Remove the longest leading-whitespace prefix shared by all non-blank lines.

    Blank (whitespace-only) lines are wildcards: they do not constrain the
    common prefix and are never used to seed it.

    Args:
        s: The text to dedent.

    Returns:
        A tuple ``(dedented, prefix)``. ``prefix`` is the common leading
        whitespace that was removed, and ``dedented`` is the lines of ``s``
        with exactly that prefix removed, joined with "\\n". If ``s`` is
        empty, returns ``("", "")``. If there is no common whitespace (or
        ``s`` contains only blank lines), returns ``(s, "")`` with ``s``
        unchanged.
    """
    lines = s.splitlines()
    if len(lines) == 0:
        return "", ""

    # Blank lines are wildcards, so only non-blank lines determine the
    # prefix; never seed the prefix from a blank line.
    content_lines = [line for line in lines if line.strip() != ""]
    if not content_lines:
        return s, ""

    # Longest common whitespace prefix, seeded from the first content line
    first = content_lines[0]
    lcp = first[: len(first) - len(first.lstrip())]
    for line in content_lines[1:]:
        # Shrink lcp to the part it shares with the start of this line
        j = 0
        while j < len(lcp) and j < len(line) and lcp[j] == line[j]:
            j += 1
        lcp = lcp[:j]
        if lcp == "":
            break

    if lcp == "":
        # Nothing in common: leave the text untouched
        return s, ""

    def remove_prefix(line: str) -> str:
        # Remove exactly the common prefix. str.lstrip(lcp) would treat lcp
        # as a set of characters and strip *all* leading whitespace,
        # flattening nested indentation.
        if line.startswith(lcp):
            return line[len(lcp):]
        # Only blank lines can be shorter than (or differ from) the prefix
        return line.lstrip()

    return "\n".join(remove_prefix(line) for line in lines), lcp


def strip_code_block(s: str) -> str:
    """
    Strip a surrounding Markdown code fence or inline code span, if present.

    Recognized forms, checked in order:

    - A fenced block whose closing fence is on its own line, e.g.
      "```\\ncode\\n```" or "```python\\ncode\\n```". The whole opening
      fence line is dropped, including any info string such as "python".
    - Any other text wrapped in triple backticks: "```code```".
    - Text wrapped in single backticks: "`code`".

    Any other input is returned unchanged.
    """
    if s.startswith("```") and s.endswith("\n```"):
        # The code starts on the line after the opening fence; any text on
        # the fence line itself is the (optional) language tag / info string.
        body_start = s.index("\n") + 1
        return s[body_start:-4]
    if s.startswith("```") and s.endswith("```"):
        return s[3:-3]
    if s.startswith("`") and s.endswith("`"):
        return s[1:-1]
    return s


def remove_quotes_and_escapes(output: str) -> str:
    """
    Clean up the output of the completion API, removing unnecessary escapes and quotes.

    Processing steps, in order:

    1. Strip surrounding whitespace.
    2. Replace typographic ("smart") quotes with ASCII ' and ".
    3. Unescape \\", \\', \\n, \\t and \\\\ in a single left-to-right pass,
       so an escaped backslash followed by "n" or "t" stays a literal
       backslash plus letter rather than turning into a newline or tab.
    4. Remove one pair of matching outer quotes (" or '), if present.

    Args:
        output: Raw text returned by the completion API.

    Returns:
        The cleaned-up text.
    """
    output = output.strip()

    # Replace smart quotes
    output = output.translate(
        str.maketrans({"“": '"', "”": '"', "‘": "'", "’": "'"})
    )

    # Remove escapes. One regex pass consumes each backslash escape exactly
    # once; chained str.replace calls would re-scan already-unescaped text
    # (e.g. an escaped backslash followed by "n" would become a newline).
    unescaped = {'"': '"', "'": "'", "n": "\n", "t": "\t", "\\": "\\"}
    output = re.sub(
        r"""\\(["'nt\\])""", lambda m: unescaped[m.group(1)], output
    )

    # Drop one pair of matching outer quotes. The length check keeps a lone
    # quote character from counting as both the opening and closing quote.
    if len(output) >= 2 and output[0] == output[-1] and output[0] in "\"'":
        output = output[1:-1]

    return output