summaryrefslogtreecommitdiff
path: root/server/continuedev/plugins/steps/refactor.py
blob: 56e9e09e959effdf0d6e9fcf952434625d138e90 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import asyncio
import logging
from typing import List, Optional

from ripgrepy import Ripgrepy

from ...core.main import Step
from ...core.models import Models
from ...core.sdk import ContinueSDK
from ...libs.llm.prompts.edit import simplified_edit_prompt
from ...libs.util.ripgrep import get_rg_path
from ...libs.util.strings import remove_quotes_and_escapes, strip_code_block
from ...libs.util.templating import render_prompt_template
from ...models.filesystem import RangeInFile
from ...models.filesystem_edit import FileEdit
from ...models.main import PositionInFile, Range


class RefactorReferencesStep(Step):
    """Apply ``user_input`` as an edit instruction to every reference of a symbol.

    References are looked up through the language server
    (``sdk.lsp.find_references``) and then handed to ``ParallelEditStep``,
    which rewrites the code around each reference.
    """

    name: str = "Refactor references of a symbol"
    # Natural-language instruction forwarded to the edit model for each reference.
    user_input: str
    # File + position of the symbol whose references should be refactored.
    symbol_location: PositionInFile

    async def describe(self, models: Models):
        # The previous message read `function_name`, `new_function_name` and
        # `filepath`, none of which are fields of this step, so describing it
        # raised AttributeError. Build the description from the real fields.
        return (
            f"Refactored references of the symbol in "
            f"`{self.symbol_location.filepath}`: {self.user_input}"
        )

    async def run(self, sdk: ContinueSDK):
        # Poll until a language server is attached and reports ready.
        # NOTE(review): there is no timeout, so this waits indefinitely if the
        # LSP never becomes ready.
        while sdk.lsp is None or not sdk.lsp.ready:
            await asyncio.sleep(0.1)

        # NOTE(review): the meaning of the trailing `False` is defined by
        # `find_references` (presumably "include declaration") — confirm.
        references = await sdk.lsp.find_references(
            self.symbol_location.position, self.symbol_location.filepath, False
        )
        await sdk.run_step(
            ParallelEditStep(user_input=self.user_input, range_in_files=references)
        )


class RefactorBySearchStep(Step):
    """Apply ``user_input`` to every ripgrep match of ``pattern`` in the workspace.

    NOTE(review): ``get_range_for_result`` is not implemented yet, so any
    search that produces results cannot be turned into edit ranges and this
    step raises ``NotImplementedError``.
    """

    name: str = "Refactor by search"

    # Pattern handed to ripgrep.
    pattern: str
    # Natural-language edit instruction forwarded to ParallelEditStep.
    user_input: str

    rg_path: Optional[str] = None
    "Optional path to ripgrep executable"

    def get_range_for_result(self, result) -> RangeInFile:
        """Convert one ripgrep result into the ``RangeInFile`` to edit.

        Not implemented. The previous body was ``pass``, which silently
        returned ``None``; ``run`` would then pass a list of ``None`` values to
        ``ParallelEditStep`` as its ranges. Fail explicitly instead so the
        missing piece is obvious at the point of failure.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "RefactorBySearchStep.get_range_for_result is not implemented"
        )

    async def run(self, sdk: ContinueSDK):
        # Search the whole workspace, using an explicit ripgrep binary if one
        # was configured and the bundled/default one otherwise.
        rg = Ripgrepy(
            self.pattern,
            sdk.ide.workspace_directory,
            rg_path=self.rg_path or get_rg_path(),
        )

        # NOTE(review): confirm which rg flags `.I()` and `.context(2)` map to
        # in ripgrepy, and that the returned object is iterable per result.
        results = rg.I().context(2).run()
        range_in_files = [self.get_range_for_result(result) for result in results]

        await sdk.run_step(
            ParallelEditStep(user_input=self.user_input, range_in_files=range_in_files)
        )


class ParallelEditStep(Step):
    """Apply one edit instruction to many ranges, editing files concurrently.

    Different files are edited concurrently (``asyncio.gather``). Within a
    single file, ranges are edited one after another from the bottom of the
    file upward, so an edit never shifts the line numbers of a range that is
    still waiting to be edited.
    """

    name: str = "Edit multiple ranges in parallel"
    # Natural-language instruction sent to the edit model for every range.
    user_input: str
    # Ranges to edit; may span several files.
    range_in_files: List[RangeInFile]

    hide: bool = True

    async def _expand_range(
        self, sdk: ContinueSDK, range_in_file: RangeInFile
    ) -> Range:
        """Return the range that will actually be rewritten for ``range_in_file``.

        Uses the LSP's enclosing folding range when it starts on the same line
        as the target; otherwise (or when no LSP is attached) falls back to the
        full lines from the target's start line through its end line.
        """
        # TODO: Can use folding info to get a more intuitively shaped range
        folding_range = None
        if sdk.lsp is not None:
            folding_range = await sdk.lsp.get_enclosing_folding_range(range_in_file)

        if (
            folding_range is not None
            and folding_range.range.start.line == range_in_file.range.start.line
        ):
            return folding_range.range

        # Whole lines: column 0 of the start line up to column 0 of the line
        # after the end line.
        return Range.from_shorthand(
            range_in_file.range.start.line, 0, range_in_file.range.end.line + 1, 0
        )

    async def single_edit(self, sdk: ContinueSDK, range_in_file: RangeInFile):
        """Rewrite the code around ``range_in_file`` according to ``user_input``.

        Reads the expanded range from the IDE, asks ``sdk.models.edit`` for a
        replacement, and applies it to the IDE as a ``FileEdit``.
        """
        logger = logging.getLogger(__name__)

        expanded_range = await self._expand_range(sdk, range_in_file)
        new_rif = RangeInFile(filepath=range_in_file.filepath, range=expanded_range)
        code_to_edit = await sdk.ide.readRangeInFile(range_in_file=new_rif)

        # NOTE: the code is sent without dedenting; an earlier
        # dedent/re-indent approach was left disabled.
        prompt = render_prompt_template(
            simplified_edit_prompt,
            history=[],
            other_data={
                "code_to_edit": code_to_edit,
                "user_input": self.user_input,
            },
        )
        # Debug-level logging instead of print: prompts and completions can be
        # large and should not be dumped to stdout on every edit.
        logger.debug("Parallel edit prompt:\n%s", prompt)

        new_code = await sdk.models.edit.complete(prompt=prompt)
        new_code = strip_code_block(remove_quotes_and_escapes(new_code)) + "\n"
        logger.debug("Parallel edit result:\n%s", new_code)

        await sdk.ide.applyFileSystemEdit(
            FileEdit(
                filepath=range_in_file.filepath,
                range=expanded_range,
                replacement=new_code,
            )
        )

    async def edit_file(self, sdk: ContinueSDK, filepath: str):
        """Sequentially apply every range that belongs to ``filepath``."""
        # Bottom-up order so each edit leaves the line numbers of the
        # not-yet-edited ranges above it intact.
        ranges_in_file = sorted(
            (rif for rif in self.range_in_files if rif.filepath == filepath),
            key=lambda rif: rif.range.start.line,
            reverse=True,
        )

        # Several references on the same line(s) (e.g. `f(f(x))`) map to the
        # same line-based edit region. Editing each one would send
        # already-rewritten code back to the model with the same instruction,
        # so edit each line span only once.
        edited_spans = set()
        for rif in ranges_in_file:
            span = (rif.range.start.line, rif.range.end.line)
            if span in edited_spans:
                continue
            edited_spans.add(span)
            await self.single_edit(sdk=sdk, range_in_file=rif)

    async def run(self, sdk: ContinueSDK):
        # One sequential editing task per file; different files run concurrently.
        filepaths = {rif.filepath for rif in self.range_in_files}
        await asyncio.gather(
            *(self.edit_file(sdk=sdk, filepath=filepath) for filepath in filepaths)
        )