diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-06-19 12:28:32 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-06-19 12:28:32 -0700 |
commit | 5c6ee8c2c712e7390cf76015280ad0424e85852b (patch) | |
tree | 1caa4968428fed35b3ffd7803e70899857bb53f2 /continuedev | |
parent | 3c0d07017e305ca4feca2b884f0e5cee8b04eed3 (diff) | |
download | sncontinue-5c6ee8c2c712e7390cf76015280ad0424e85852b.tar.gz sncontinue-5c6ee8c2c712e7390cf76015280ad0424e85852b.tar.bz2 sncontinue-5c6ee8c2c712e7390cf76015280ad0424e85852b.zip |
better prompt
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/src/continuedev/models/filesystem.py | 2 | ||||
-rw-r--r-- | continuedev/src/continuedev/models/main.py | 6 | ||||
-rw-r--r-- | continuedev/src/continuedev/steps/core/core.py | 59 |
3 files changed, 49 insertions, 18 deletions
diff --git a/continuedev/src/continuedev/models/filesystem.py b/continuedev/src/continuedev/models/filesystem.py index ede636c5..b709dd21 100644 --- a/continuedev/src/continuedev/models/filesystem.py +++ b/continuedev/src/continuedev/models/filesystem.py @@ -100,7 +100,7 @@ class FileSystem(AbstractModel): @classmethod def read_range_in_str(self, s: str, r: Range) -> str: - lines = s.splitlines()[r.start.line:r.end.line + 1] + lines = s.split("\n")[r.start.line:r.end.line + 1] if len(lines) == 0: return "" diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py index fceba284..c9011b29 100644 --- a/continuedev/src/continuedev/models/main.py +++ b/continuedev/src/continuedev/models/main.py @@ -76,6 +76,12 @@ class Range(BaseModel): def overlaps_with(self, other: "Range") -> bool: return not (self.end < other.start or self.start > other.end) + def to_full_lines(self) -> "Range": + return Range( + start=Position(line=self.start.line, character=0), + end=Position(line=self.end.line + 1, character=0) + ) + @staticmethod def from_indices(string: str, start_index: int, end_index: int) -> "Range": return Range( diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py index 237171a6..99c3b867 100644 --- a/continuedev/src/continuedev/steps/core/core.py +++ b/continuedev/src/continuedev/steps/core/core.py @@ -116,37 +116,45 @@ class DefaultModelEditCodeStep(Step): name: str = "Editing Code" hide = False _prompt: str = dedent("""\ - Take the file prefix and suffix into account, but only rewrite the commit before as specified in the commit message. Here's an example: + Take the file prefix and suffix into account, but only rewrite the code_to_edit as specified in the user_request. Here's an example: <file_prefix> a = 5 b = 4 + </file_prefix> + <code_to_edit> + def sum(): + return a + b + </code_to_edit> <file_suffix> def mul(a, b): return a * b - <commit_before> - def sum(): - return a + b - <commit_msg> + </file_suffix> + <user_request> Make a and b parameters of sum - <commit_after> + </user_request> + <modified_code_to_edit> def sum(a, b): return a + b - <|endoftext|> + </modified_code_to_edit> - Now complete the real thing. Do NOT rewrite the prefix or suffix. You are only to write the code that goes in "commit_after". + Now complete the real thing. Do NOT rewrite anything in the file_prefix or file_suffix tags, but rewrite everything inside the code_to_edit tags in order to fulfill the user request. Do NOT preface your answer or write anything other than code. <file_prefix> {file_prefix} + </file_prefix> + <code_to_edit> + {code} + </code_to_edit> <file_suffix> {file_suffix} - <commit_before> - {code} - <commit_msg> + </file_suffix> + <user_request> {user_request} - <commit_after> + </user_request> + <modified_code_to_edit> """) _prompt_and_completion: str = "" @@ -162,7 +170,11 @@ class DefaultModelEditCodeStep(Step): await sdk.update_ui() rif_with_contents = [] - for range_in_file in self.range_in_files: + for range_in_file in map(lambda x: RangeInFile( + filepath=x.filepath, + # Only consider the range line-by-line. Maybe later don't if it's only a single line. + range=x.range.to_full_lines() + ), self.range_in_files): file_contents = await sdk.ide.readRangeInFile(range_in_file) rif_with_contents.append( RangeInFileWithContents.from_range_in_file(range_in_file, file_contents)) @@ -225,16 +237,25 @@ class DefaultModelEditCodeStep(Step): raise Exception("Unknown default model") - code_before = "".join( + code_before = "\n".join( full_file_contents_lst[cur_start_line:max_start_line]) - code_after = "".join( - full_file_contents_lst[min_end_line:cur_end_line]) + code_after = "\n".join( + full_file_contents_lst[min_end_line:cur_end_line - 1]) segs = [code_before, code_after] + if segs[0].strip() == "": + segs[0] = segs[0].strip() + if segs[1].strip() == "": + segs[1] = segs[1].strip() prompt = self._prompt.format( code=rif.contents, user_request=self.user_input, file_prefix=segs[0], file_suffix=segs[1]) + if segs[0].strip() == "": + prompt = prompt.replace("<file_prefix>\n", "") + if segs[1].strip() == "": + prompt = prompt.replace("\n<file_suffix>", "") + lines = [] unfinished_line = "" i = 0 @@ -266,7 +287,9 @@ class DefaultModelEditCodeStep(Step): lines.extend(chunk_lines) for line in chunk_lines: - if i < len(original_lines) and line == original_lines[i]: + if "</modified_code_to_edit>" in line: + break + elif i < len(original_lines) and line == original_lines[i]: i += 1 continue @@ -275,6 +298,8 @@ class DefaultModelEditCodeStep(Step): # Add the unfinished line if unfinished_line != "": + unfinished_line = unfinished_line.removesuffix( + "</modified_code_to_edit>") if not i < len(original_lines) or not unfinished_line == original_lines[i]: await add_line(i, unfinished_line) lines.append(unfinished_line) |