summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/models/main.py
blob: 081ec4af328a669e68c4a320e83cf9080d5b339d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from abc import ABC
from typing import List, Union
from pydantic import BaseModel, root_validator
from functools import total_ordering


@total_ordering
class Position(BaseModel):
    """A 0-indexed (line, character) location in a text document.

    Positions are hashable and totally ordered: first by ``line``, then by
    ``character``.
    """

    line: int
    character: int

    def __hash__(self):
        return hash((self.line, self.character))

    def __eq__(self, other: "Position") -> bool:
        # Let Python fall back to the other operand (or identity) instead of
        # raising AttributeError when compared with a non-Position.
        if not isinstance(other, Position):
            return NotImplemented
        return (self.line, self.character) == (other.line, other.character)

    def __lt__(self, other: "Position") -> bool:
        if not isinstance(other, Position):
            return NotImplemented
        # Lexicographic tuple comparison: line first, then character.
        return (self.line, self.character) < (other.line, other.character)

    @staticmethod
    def from_index(string: str, index: int) -> "Position":
        """Convert a character offset into ``string`` to a Position.

        ``line`` is the number of ``"\\n"`` characters before ``index``;
        ``character`` is the distance from the start of that line. Only
        ``"\\n"`` is treated as a line break.
        """
        line = string.count("\n", 0, index)
        # rfind returns -1 when no newline precedes index, which makes the
        # first line's start offset 0 without a special case.
        line_start = string.rfind("\n", 0, index) + 1
        return Position(line=line, character=index - line_start)


class Range(BaseModel):
    """A range in a file. 0-indexed.

    ``start`` and ``end`` are line/character Positions. ``from_snippet_in_file``
    and ``from_entire_file`` set ``end`` to the position just past the last
    covered character.
    """

    start: Position
    end: Position

    def __hash__(self):
        return hash((self.start, self.end))

    def union(self, other: "Range") -> "Range":
        """Return the smallest range that covers both ``self`` and ``other``."""
        return Range(
            start=min(self.start, other.start),
            end=max(self.end, other.end),
        )

    def is_empty(self) -> bool:
        """Return True when the range starts and ends at the same position."""
        return self.start == self.end

    def overlaps_with(self, other: "Range") -> bool:
        """Return True if the two ranges intersect.

        Boundaries are compared inclusively, so ranges that merely touch
        (``self.end == other.start``) count as overlapping.
        """
        return not (self.end < other.start or self.start > other.end)

    @staticmethod
    def from_indices(string: str, start_index: int, end_index: int) -> "Range":
        """Build a range from two character offsets into ``string``."""
        return Range(
            start=Position.from_index(string, start_index),
            end=Position.from_index(string, end_index)
        )

    @staticmethod
    def from_shorthand(start_line: int, start_char: int, end_line: int, end_char: int) -> "Range":
        """Build a range directly from line/character numbers."""
        return Range(
            start=Position(line=start_line, character=start_char),
            end=Position(line=end_line, character=end_char),
        )

    @staticmethod
    def from_entire_file(content: str) -> "Range":
        """Return the range spanning all of ``content``.

        The end is the position just past the final character: its line is
        the number of ``"\\n"`` characters and its character is the length of
        the text after the last ``"\\n"``. A trailing newline therefore ends
        the range at the start of the following (empty) line rather than
        producing a negative character. Empty ``content`` yields (0, 0)-(0, 0).
        """
        end_line = content.count("\n")
        # rfind returns -1 when there is no newline, so the text after it is
        # the whole string.
        end_char = len(content) - (content.rfind("\n") + 1)
        return Range.from_shorthand(0, 0, end_line, end_char)

    @staticmethod
    def from_snippet_in_file(content: str, snippet: str) -> "Range":
        """Return the range of the first occurrence of ``snippet`` in ``content``.

        Raises:
            ValueError: if ``snippet`` does not occur in ``content``
                (propagated from ``str.index``).
        """
        start_index = content.index(snippet)
        end_index = start_index + len(snippet)
        return Range.from_indices(content, start_index, end_index)


class AbstractModel(ABC, BaseModel):
    """Pydantic model base class that is also an ``abc.ABC``."""

    @root_validator(pre=True)
    def check_is_subclass(cls, values):
        """Pre-validation hook run for AbstractModel and its subclasses.

        Pydantic uses a root validator's return value as the model's input
        data, so ``values`` must be returned; returning ``None`` breaks
        validation of every subclass.

        Raises:
            TypeError: if ``cls`` is not an AbstractModel subclass.
        """
        # NOTE(review): the validator only runs for AbstractModel and its
        # subclasses, so this condition can never be true. If the intent was
        # to forbid instantiating AbstractModel directly, ``cls is
        # AbstractModel`` is probably the intended test — confirm.
        if not issubclass(cls, AbstractModel):
            raise TypeError(
                "AbstractModel subclasses must be subclasses of AbstractModel")
        return values


class TracebackFrame(BaseModel):
    """A single frame of a traceback."""

    filepath: str
    lineno: int
    function: str
    code: Union[str, None]  # source text of the frame's line, if known

    def __eq__(self, other):
        """Frames are equal when file, line number and function all match.

        ``code`` is deliberately not compared.
        """
        # Return NotImplemented (rather than raising AttributeError) for
        # objects that are not frames.
        if not isinstance(other, TracebackFrame):
            return NotImplemented
        return (
            self.filepath == other.filepath
            and self.lineno == other.lineno
            and self.function == other.function
        )

    def __hash__(self):
        # Hash exactly the fields __eq__ compares, so equal frames hash equal.
        return hash((self.filepath, self.lineno, self.function))


class Traceback(BaseModel):
    """A parsed exception: its frames plus the exception's type and message."""

    frames: List[TracebackFrame]
    message: str
    error_type: str
    full_traceback: Union[str, None]

    @classmethod
    def from_tbutil_parsed_exc(cls, tbutil_parsed_exc):
        """Build a Traceback from a parsed-exception object.

        The argument must expose ``frames`` (an iterable of mappings with
        ``"filepath"``, ``"lineno"``, ``"funcname"`` and ``"source_line"``
        keys), ``exc_msg``, ``exc_type`` and a ``to_string()`` method.
        NOTE(review): presumably a boltons ``tbutils.ParsedException`` —
        confirm against callers.
        """
        parsed = tbutil_parsed_exc
        converted = []
        for raw_frame in parsed.frames:
            converted.append(
                TracebackFrame(
                    filepath=raw_frame["filepath"],
                    lineno=raw_frame["lineno"],
                    function=raw_frame["funcname"],
                    code=raw_frame["source_line"],
                )
            )
        return cls(
            frames=converted,
            message=parsed.exc_msg,
            error_type=parsed.exc_type,
            full_traceback=parsed.to_string(),
        )