summaryrefslogtreecommitdiff
path: root/continuedev/src
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-06-19 12:28:32 -0700
committerNate Sesti <sestinj@gmail.com>2023-06-19 12:28:32 -0700
commit5c6ee8c2c712e7390cf76015280ad0424e85852b (patch)
tree1caa4968428fed35b3ffd7803e70899857bb53f2 /continuedev/src
parent3c0d07017e305ca4feca2b884f0e5cee8b04eed3 (diff)
downloadsncontinue-5c6ee8c2c712e7390cf76015280ad0424e85852b.tar.gz
sncontinue-5c6ee8c2c712e7390cf76015280ad0424e85852b.tar.bz2
sncontinue-5c6ee8c2c712e7390cf76015280ad0424e85852b.zip
better prompt
Diffstat (limited to 'continuedev/src')
-rw-r--r--continuedev/src/continuedev/models/filesystem.py2
-rw-r--r--continuedev/src/continuedev/models/main.py6
-rw-r--r--continuedev/src/continuedev/steps/core/core.py59
3 files changed, 49 insertions, 18 deletions
diff --git a/continuedev/src/continuedev/models/filesystem.py b/continuedev/src/continuedev/models/filesystem.py
index ede636c5..b709dd21 100644
--- a/continuedev/src/continuedev/models/filesystem.py
+++ b/continuedev/src/continuedev/models/filesystem.py
@@ -100,7 +100,7 @@ class FileSystem(AbstractModel):
@classmethod
def read_range_in_str(self, s: str, r: Range) -> str:
- lines = s.splitlines()[r.start.line:r.end.line + 1]
+ lines = s.split("\n")[r.start.line:r.end.line + 1]
if len(lines) == 0:
return ""
diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py
index fceba284..c9011b29 100644
--- a/continuedev/src/continuedev/models/main.py
+++ b/continuedev/src/continuedev/models/main.py
@@ -76,6 +76,12 @@ class Range(BaseModel):
def overlaps_with(self, other: "Range") -> bool:
return not (self.end < other.start or self.start > other.end)
+ def to_full_lines(self) -> "Range":
+ return Range(
+ start=Position(line=self.start.line, character=0),
+ end=Position(line=self.end.line + 1, character=0)
+ )
+
@staticmethod
def from_indices(string: str, start_index: int, end_index: int) -> "Range":
return Range(
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 237171a6..99c3b867 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -116,37 +116,45 @@ class DefaultModelEditCodeStep(Step):
name: str = "Editing Code"
hide = False
_prompt: str = dedent("""\
- Take the file prefix and suffix into account, but only rewrite the commit before as specified in the commit message. Here's an example:
+ Take the file prefix and suffix into account, but only rewrite the code_to_edit as specified in the user_request. Here's an example:
<file_prefix>
a = 5
b = 4
+ </file_prefix>
+ <code_to_edit>
+ def sum():
+ return a + b
+ </code_to_edit>
<file_suffix>
def mul(a, b):
return a * b
- <commit_before>
- def sum():
- return a + b
- <commit_msg>
+ </file_suffix>
+ <user_request>
Make a and b parameters of sum
- <commit_after>
+ </user_request>
+ <modified_code_to_edit>
def sum(a, b):
return a + b
- <|endoftext|>
+ </modified_code_to_edit>
- Now complete the real thing. Do NOT rewrite the prefix or suffix. You are only to write the code that goes in "commit_after".
+ Now complete the real thing. Do NOT rewrite anything in the file_prefix or file_suffix tags, but rewrite everything inside the code_to_edit tags in order to fulfill the user request. Do NOT preface your answer or write anything other than code.
<file_prefix>
{file_prefix}
+ </file_prefix>
+ <code_to_edit>
+ {code}
+ </code_to_edit>
<file_suffix>
{file_suffix}
- <commit_before>
- {code}
- <commit_msg>
+ </file_suffix>
+ <user_request>
{user_request}
- <commit_after>
+ </user_request>
+ <modified_code_to_edit>
""")
_prompt_and_completion: str = ""
@@ -162,7 +170,11 @@ class DefaultModelEditCodeStep(Step):
await sdk.update_ui()
rif_with_contents = []
- for range_in_file in self.range_in_files:
+ for range_in_file in map(lambda x: RangeInFile(
+ filepath=x.filepath,
+ # Only consider the range line-by-line. Maybe later don't if it's only a single line.
+ range=x.range.to_full_lines()
+ ), self.range_in_files):
file_contents = await sdk.ide.readRangeInFile(range_in_file)
rif_with_contents.append(
RangeInFileWithContents.from_range_in_file(range_in_file, file_contents))
@@ -225,16 +237,25 @@ class DefaultModelEditCodeStep(Step):
raise Exception("Unknown default model")
- code_before = "".join(
+ code_before = "\n".join(
full_file_contents_lst[cur_start_line:max_start_line])
- code_after = "".join(
- full_file_contents_lst[min_end_line:cur_end_line])
+ code_after = "\n".join(
+ full_file_contents_lst[min_end_line:cur_end_line - 1])
segs = [code_before, code_after]
+ if segs[0].strip() == "":
+ segs[0] = segs[0].strip()
+ if segs[1].strip() == "":
+ segs[1] = segs[1].strip()
prompt = self._prompt.format(
code=rif.contents, user_request=self.user_input, file_prefix=segs[0], file_suffix=segs[1])
+ if segs[0].strip() == "":
+ prompt = prompt.replace("<file_prefix>\n", "")
+ if segs[1].strip() == "":
+ prompt = prompt.replace("\n<file_suffix>", "")
+
lines = []
unfinished_line = ""
i = 0
@@ -266,7 +287,9 @@ class DefaultModelEditCodeStep(Step):
lines.extend(chunk_lines)
for line in chunk_lines:
- if i < len(original_lines) and line == original_lines[i]:
+ if "</modified_code_to_edit>" in line:
+ break
+ elif i < len(original_lines) and line == original_lines[i]:
i += 1
continue
@@ -275,6 +298,8 @@ class DefaultModelEditCodeStep(Step):
# Add the unfinished line
if unfinished_line != "":
+ unfinished_line = unfinished_line.removesuffix(
+ "</modified_code_to_edit>")
if not i < len(original_lines) or not unfinished_line == original_lines[i]:
await add_line(i, unfinished_line)
lines.append(unfinished_line)