summaryrefslogtreecommitdiff
path: root/server/continuedev/plugins/steps/refactor.py
diff options
context:
space:
mode:
Diffstat (limited to 'server/continuedev/plugins/steps/refactor.py')
-rw-r--r--server/continuedev/plugins/steps/refactor.py136
1 files changed, 136 insertions, 0 deletions
diff --git a/server/continuedev/plugins/steps/refactor.py b/server/continuedev/plugins/steps/refactor.py
new file mode 100644
index 00000000..56e9e09e
--- /dev/null
+++ b/server/continuedev/plugins/steps/refactor.py
@@ -0,0 +1,136 @@
+import asyncio
+from typing import List, Optional
+
+from ripgrepy import Ripgrepy
+
+from ...core.main import Step
+from ...core.models import Models
+from ...core.sdk import ContinueSDK
+from ...libs.llm.prompts.edit import simplified_edit_prompt
+from ...libs.util.ripgrep import get_rg_path
+from ...libs.util.strings import remove_quotes_and_escapes, strip_code_block
+from ...libs.util.templating import render_prompt_template
+from ...models.filesystem import RangeInFile
+from ...models.filesystem_edit import FileEdit
+from ...models.main import PositionInFile, Range
+
+
+class RefactorReferencesStep(Step):
+ name: str = "Refactor references of a symbol"
+ user_input: str
+ symbol_location: PositionInFile
+
+ async def describe(self, models: Models):
+ return f"Renamed all instances of `{self.function_name}` to `{self.new_function_name}` in `{self.filepath}`"
+
+ async def run(self, sdk: ContinueSDK):
+ while sdk.lsp is None or not sdk.lsp.ready:
+ await asyncio.sleep(0.1)
+
+ references = await sdk.lsp.find_references(
+ self.symbol_location.position, self.symbol_location.filepath, False
+ )
+ await sdk.run_step(
+ ParallelEditStep(user_input=self.user_input, range_in_files=references)
+ )
+
+
+class RefactorBySearchStep(Step):
+ name: str = "Refactor by search"
+
+ pattern: str
+ user_input: str
+
+ rg_path: Optional[str] = None
+ "Optional path to ripgrep executable"
+
+ def get_range_for_result(self, result) -> RangeInFile:
+ pass
+
+ async def run(self, sdk: ContinueSDK):
+ rg = Ripgrepy(
+ self.pattern,
+ sdk.ide.workspace_directory,
+ rg_path=self.rg_path or get_rg_path(),
+ )
+
+ results = rg.I().context(2).run()
+ range_in_files = [self.get_range_for_result(result) for result in results]
+
+ await sdk.run_step(
+ ParallelEditStep(user_input=self.user_input, range_in_files=range_in_files)
+ )
+
+
+class ParallelEditStep(Step):
+ name: str = "Edit multiple ranges in parallel"
+ user_input: str
+ range_in_files: List[RangeInFile]
+
+ hide: bool = True
+
+ async def single_edit(self, sdk: ContinueSDK, range_in_file: RangeInFile):
+ # TODO: Can use folding info to get a more intuitively shaped range
+ expanded_range = await sdk.lsp.get_enclosing_folding_range(range_in_file)
+ if (
+ expanded_range is None
+ or expanded_range.range.start.line != range_in_file.range.start.line
+ ):
+ expanded_range = Range.from_shorthand(
+ range_in_file.range.start.line, 0, range_in_file.range.end.line + 1, 0
+ )
+ else:
+ expanded_range = expanded_range.range
+
+ new_rif = RangeInFile(
+ filepath=range_in_file.filepath,
+ range=expanded_range,
+ )
+ code_to_edit = await sdk.ide.readRangeInFile(range_in_file=new_rif)
+
+ # code_to_edit, common_whitespace = dedent_and_get_common_whitespace(code_to_edit)
+
+ prompt = render_prompt_template(
+ simplified_edit_prompt,
+ history=[],
+ other_data={
+ "code_to_edit": code_to_edit,
+ "user_input": self.user_input,
+ },
+ )
+ print(prompt + "\n\n-------------------\n\n")
+
+ new_code = await sdk.models.edit.complete(prompt=prompt)
+ new_code = strip_code_block(remove_quotes_and_escapes(new_code)) + "\n"
+ # new_code = (
+ # "\n".join([common_whitespace + line for line in new_code.split("\n")])
+ # + "\n"
+ # )
+
+ print(new_code + "\n\n-------------------\n\n")
+
+ await sdk.ide.applyFileSystemEdit(
+ FileEdit(
+ filepath=range_in_file.filepath,
+ range=expanded_range,
+ replacement=new_code,
+ )
+ )
+
+ async def edit_file(self, sdk: ContinueSDK, filepath: str):
+ ranges_in_file = [
+ range_in_file
+ for range_in_file in self.range_in_files
+ if range_in_file.filepath == filepath
+ ]
+ # Sort in reverse order so that we don't mess up the ranges
+ ranges_in_file.sort(key=lambda x: x.range.start.line, reverse=True)
+ for i in range(len(ranges_in_file)):
+ await self.single_edit(sdk=sdk, range_in_file=ranges_in_file[i])
+
+ async def run(self, sdk: ContinueSDK):
+ tasks = []
+ for filepath in set([rif.filepath for rif in self.range_in_files]):
+ tasks.append(self.edit_file(sdk=sdk, filepath=filepath))
+
+ await asyncio.gather(*tasks)