summaryrefslogtreecommitdiff
path: root/continuedev/src
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev/src')
-rw-r--r--continuedev/src/continuedev/core/main.py10
-rw-r--r--continuedev/src/continuedev/plugins/steps/core/core.py5
-rw-r--r--continuedev/src/continuedev/plugins/steps/main.py80
3 files changed, 78 insertions, 17 deletions
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index ace1ad60..63a3e6a9 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -1,5 +1,5 @@
import json
-from typing import Coroutine, Dict, List, Literal, Optional, Union
+from typing import Any, Coroutine, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, validator
from pydantic.schema import schema
@@ -401,13 +401,13 @@ class Validator(Step):
class Context:
- key_value: Dict[str, str] = {}
+ key_value: Dict[str, Any] = {}
- def set(self, key: str, value: str):
+ def set(self, key: str, value: Any):
self.key_value[key] = value
- def get(self, key: str) -> str:
- return self.key_value[key]
+ def get(self, key: str) -> Any:
+ return self.key_value.get(key, None)
class ContinueCustomException(Exception):
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index bf5eb144..1d7ffdd7 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -812,7 +812,6 @@ Please output the code to be inserted at the cursor in order to fulfill the user
rif_dict[rif.filepath] = rif.contents
for rif in rif_with_contents:
- await sdk.ide.setFileOpen(rif.filepath)
await sdk.ide.setSuggestionsLocked(rif.filepath, True)
await self.stream_rif(rif, sdk)
await sdk.ide.setSuggestionsLocked(rif.filepath, False)
@@ -842,6 +841,10 @@ Please output the code to be inserted at the cursor in order to fulfill the user
self.description += chunk
await sdk.update_ui()
+ sdk.context.set("last_edit_user_input", self.user_input)
+ sdk.context.set("last_edit_diff", changes)
+ sdk.context.set("last_edit_range", self.range_in_files[-1].range)
+
class EditFileStep(Step):
filepath: str
diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py
index ca15aaab..cd3b30e0 100644
--- a/continuedev/src/continuedev/plugins/steps/main.py
+++ b/continuedev/src/continuedev/plugins/steps/main.py
@@ -1,4 +1,5 @@
import os
+import urllib.parse
from textwrap import dedent
from typing import Coroutine, List, Optional, Union
@@ -235,6 +236,58 @@ class StarCoderEditHighlightedCodeStep(Step):
await sdk.ide.setFileOpen(rif.filepath)
+class EditAlreadyEditedRangeStep(Step):
+ hide = True
+ model: Optional[LLM] = None
+ range_in_file: RangeInFile
+
+ user_input: str
+
+ _prompt = dedent(
+ """\
+ You were previously asked to edit this code. The request was:
+
+ "{prev_user_input}"
+
+ And you generated this diff:
+
+ {diff}
+
+ Could you please re-edit this code to follow these secondary instructions?
+
+ "{user_input}"
+ """
+ )
+
+ async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
+ if os.path.basename(self.range_in_file.filepath) in os.listdir(
+ os.path.expanduser(os.path.join("~", ".continue", "diffs"))
+ ):
+ decoded_basename = urllib.parse.unquote(
+ os.path.basename(self.range_in_file.filepath)
+ )
+ self.range_in_file.filepath = decoded_basename
+
+ self.range_in_file.range = sdk.context.get("last_edit_range")
+
+ if self.range_in_file.range.start == self.range_in_file.range.end:
+ self.range_in_file.range = Range.from_entire_file(
+ await sdk.ide.readFile(self.range_in_file.filepath)
+ )
+
+ await sdk.run_step(
+ DefaultModelEditCodeStep(
+ model=self.model,
+ user_input=self._prompt.format(
+ prev_user_input=sdk.context.get("last_edit_user_input"),
+ diff=sdk.context.get("last_edit_diff"),
+ user_input=self.user_input,
+ ),
+ range_in_files=[self.range_in_file],
+ )
+ )
+
+
class EditHighlightedCodeStep(Step):
user_input: str = Field(
...,
@@ -258,13 +311,6 @@ class EditHighlightedCodeStep(Step):
highlighted_code = await sdk.ide.getHighlightedCode()
if highlighted_code is not None:
for rif in highlighted_code:
- if os.path.dirname(rif.filepath) == os.path.expanduser(
- os.path.join("~", ".continue", "diffs")
- ):
- raise ContinueCustomException(
- message="Please accept or reject the change before making another edit in this file.",
- title="Accept/Reject First",
- )
if rif.range.start == rif.range.end:
range_in_files.append(
RangeInFileWithContents.from_range_in_file(rif, "")
@@ -289,10 +335,22 @@ class EditHighlightedCodeStep(Step):
)
for range_in_file in range_in_files:
- if os.path.dirname(range_in_file.filepath) == os.path.expanduser(
- os.path.join("~", ".continue", "diffs")
- ):
- self.description = "Please accept or reject the change before making another edit in this file."
+ # Check whether re-editing
+ if (
+ os.path.dirname(range_in_file.filepath)
+ == os.path.expanduser(os.path.join("~", ".continue", "diffs"))
+ or urllib.parse.quote_plus(range_in_file.filepath)
+ in os.listdir(
+ os.path.expanduser(os.path.join("~", ".continue", "diffs"))
+ )
+ ) and sdk.context.get("last_edit_user_input") is not None:
+ await sdk.run_step(
+ EditAlreadyEditedRangeStep(
+ range_in_file=range_in_file,
+ user_input=self.user_input,
+ model=self.model,
+ )
+ )
return
args = {