summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/plugins/steps/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev/src/continuedev/plugins/steps/main.py')
-rw-r--r--continuedev/src/continuedev/plugins/steps/main.py80
1 files changed, 69 insertions, 11 deletions
diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py
index ca15aaab..cd3b30e0 100644
--- a/continuedev/src/continuedev/plugins/steps/main.py
+++ b/continuedev/src/continuedev/plugins/steps/main.py
@@ -1,4 +1,5 @@
import os
+import urllib.parse
from textwrap import dedent
from typing import Coroutine, List, Optional, Union
@@ -235,6 +236,58 @@ class StarCoderEditHighlightedCodeStep(Step):
await sdk.ide.setFileOpen(rif.filepath)
+class EditAlreadyEditedRangeStep(Step):
+ hide = True
+ model: Optional[LLM] = None
+ range_in_file: RangeInFile
+
+ user_input: str
+
+ _prompt = dedent(
+ """\
+ You were previously asked to edit this code. The request was:
+
+ "{prev_user_input}"
+
+ And you generated this diff:
+
+ {diff}
+
+ Could you please re-edit this code to follow these secondary instructions?
+
+ "{user_input}"
+ """
+ )
+
+ async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
+ if os.path.basename(self.range_in_file.filepath) in os.listdir(
+ os.path.expanduser(os.path.join("~", ".continue", "diffs"))
+ ):
+ decoded_basename = urllib.parse.unquote(
+ os.path.basename(self.range_in_file.filepath)
+ )
+ self.range_in_file.filepath = decoded_basename
+
+ self.range_in_file.range = sdk.context.get("last_edit_range")
+
+ if self.range_in_file.range.start == self.range_in_file.range.end:
+ self.range_in_file.range = Range.from_entire_file(
+ await sdk.ide.readFile(self.range_in_file.filepath)
+ )
+
+ await sdk.run_step(
+ DefaultModelEditCodeStep(
+ model=self.model,
+ user_input=self._prompt.format(
+ prev_user_input=sdk.context.get("last_edit_user_input"),
+ diff=sdk.context.get("last_edit_diff"),
+ user_input=self.user_input,
+ ),
+ range_in_files=[self.range_in_file],
+ )
+ )
+
+
class EditHighlightedCodeStep(Step):
user_input: str = Field(
...,
@@ -258,13 +311,6 @@ class EditHighlightedCodeStep(Step):
highlighted_code = await sdk.ide.getHighlightedCode()
if highlighted_code is not None:
for rif in highlighted_code:
- if os.path.dirname(rif.filepath) == os.path.expanduser(
- os.path.join("~", ".continue", "diffs")
- ):
- raise ContinueCustomException(
- message="Please accept or reject the change before making another edit in this file.",
- title="Accept/Reject First",
- )
if rif.range.start == rif.range.end:
range_in_files.append(
RangeInFileWithContents.from_range_in_file(rif, "")
@@ -289,10 +335,22 @@ class EditHighlightedCodeStep(Step):
)
for range_in_file in range_in_files:
- if os.path.dirname(range_in_file.filepath) == os.path.expanduser(
- os.path.join("~", ".continue", "diffs")
- ):
- self.description = "Please accept or reject the change before making another edit in this file."
+ # Check whether re-editing
+ if (
+ os.path.dirname(range_in_file.filepath)
+ == os.path.expanduser(os.path.join("~", ".continue", "diffs"))
+ or urllib.parse.quote_plus(range_in_file.filepath)
+ in os.listdir(
+ os.path.expanduser(os.path.join("~", ".continue", "diffs"))
+ )
+ ) and sdk.context.get("last_edit_user_input") is not None:
+ await sdk.run_step(
+ EditAlreadyEditedRangeStep(
+ range_in_file=range_in_file,
+ user_input=self.user_input,
+ model=self.model,
+ )
+ )
return
args = {