summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--continuedev/src/continuedev/steps/core/core.py79
1 files changed, 41 insertions, 38 deletions
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 99c3b867..a780cedd 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -116,31 +116,39 @@ class DefaultModelEditCodeStep(Step):
name: str = "Editing Code"
hide = False
_prompt: str = dedent("""\
- Take the file prefix and suffix into account, but only rewrite the code_to_edit as specified in the user_request. Here's an example:
+ Take the file prefix and suffix into account, but only rewrite the code_to_edit as specified in the user_request. The code you write in modified_code_to_edit will replace the code between the code_to_edit tags. Do NOT preface your answer or write anything other than code.
+
+ Example:
<file_prefix>
- a = 5
- b = 4
+ class Database:
+ def __init__(self):
+ self._data = {{}}
+
+ def get(self, key):
+ return self._data[key]
</file_prefix>
<code_to_edit>
- def sum():
- return a + b
+ def set(self, key, value):
+ self._data[key] = value
</code_to_edit>
<file_suffix>
- def mul(a, b):
- return a * b
+ def clear_all():
+ self._data = {{}}
</file_suffix>
<user_request>
- Make a and b parameters of sum
+ Raise an error if the key already exists.
</user_request>
<modified_code_to_edit>
- def sum(a, b):
- return a + b
+ def set(self, key, value):
+ if key in self._data:
+ raise KeyError(f"Key {{key}} already exists")
+ self._data[key] = value
</modified_code_to_edit>
- Now complete the real thing. Do NOT rewrite anything in the file_prefix or file_suffix tags, but rewrite everything inside the code_to_edit tags in order to fulfill the user request. Do NOT preface your answer or write anything other than code.
+ Main task:
<file_prefix>
{file_prefix}
@@ -260,9 +268,13 @@ class DefaultModelEditCodeStep(Step):
unfinished_line = ""
i = 0
original_lines = rif.contents.split("\n")
- lines_to_highlight = []
async def add_line(i: int, line: str):
+ if i == 0:
+ # First line indentation, because the model will assume that it is replacing in this way
+ line = original_lines[0].replace(
+ original_lines[0].strip(), "") + line
+
range = Range.from_shorthand(
rif.range.start.line + i, rif.range.start.character if i == 0 else 0, rif.range.start.line + i + 1, 0)
await sdk.ide.applyFileSystemEdit(FileEdit(
@@ -270,11 +282,6 @@ class DefaultModelEditCodeStep(Step):
range=range,
replacement=line + "\n"
))
- lines_to_highlight.append(rif.range.start.line + i)
- # await sdk.ide.highlightCode(RangeInFile(
- # filepath=rif.filepath,
- # range=range
- # ))
async for chunk in model_to_use.stream_chat(prompt, with_history=await sdk.get_chat_context()):
chunk_lines = chunk.split("\n")
@@ -289,6 +296,8 @@ class DefaultModelEditCodeStep(Step):
for line in chunk_lines:
if "</modified_code_to_edit>" in line:
break
+ elif "<modified_code_to_edit>" in line or "<file_prefix>" in line or "</file_prefix>" in line or "<file_suffix>" in line or "</file_suffix>" in line or "<user_request>" in line or "</user_request>" in line or "<code_to_edit>" in line or "</code_to_edit>" in line:
+ continue
elif i < len(original_lines) and line == original_lines[i]:
i += 1
continue
@@ -299,7 +308,7 @@ class DefaultModelEditCodeStep(Step):
# Add the unfinished line
if unfinished_line != "":
unfinished_line = unfinished_line.removesuffix(
- "</modified_code_to_edit>")
+ "</modified_code_to_edit>").removesuffix("</code_to_edit>")
if not i < len(original_lines) or not unfinished_line == original_lines[i]:
await add_line(i, unfinished_line)
lines.append(unfinished_line)
@@ -308,37 +317,31 @@ class DefaultModelEditCodeStep(Step):
# Remove the leftover original lines
while i < len(original_lines):
range = Range.from_shorthand(
- rif.range.start.line + i, rif.range.start.character, rif.range.start.line + i, len(original_lines[i]))
+ rif.range.start.line + i, rif.range.start.character, rif.range.start.line + i, len(original_lines[i]) + 1)
await sdk.ide.applyFileSystemEdit(FileEdit(
filepath=rif.filepath,
range=range,
replacement=""
))
- # await sdk.ide.highlightCode(RangeInFile(
- # filepath=rif.filepath,
- # range=range
- # ))
i += 1
completion = "\n".join(lines)
- # eot_token = "<|endoftext|>"
- # completion = completion.removesuffix(eot_token)
-
- # # Remove tags and If it accidentally includes prefix or suffix, remove it
- # if completion.strip().startswith("```"):
- # completion = completion.strip().removeprefix("```").removesuffix("```")
- # completion = completion.replace("<file_prefix>", "").replace("<file_suffix>", "").replace(
- # "<commit_before>", "").replace("<commit_msg>", "").replace("<commit_after>", "")
- # completion = completion.removeprefix(segs[0])
- # completion = completion.removesuffix(segs[1])
self._prompt_and_completion += prompt + completion
- # await sdk.ide.applyFileSystemEdit(FileEdit(
- # filepath=rif.filepath,
- # range=rif.range,
- # replacement=completion
- # ))
+ diff = list(difflib.ndiff(rif.contents.splitlines(
+ keepends=True), completion.splitlines(keepends=True)))
+
+ lines_to_highlight = set()
+ index = 0
+ for line in diff:
+ if line.startswith("-"):
+ pass
+ elif line.startswith("+"):
+ lines_to_highlight.add(index + rif.range.start.line)
+ index += 1
+ elif line.startswith(" "):
+ index += 1
current_hl_start = None
last_hl = None