diff options
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/models/filesystem.py | 2 | ||||
| -rw-r--r-- | continuedev/src/continuedev/models/main.py | 6 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/core/core.py | 59 | 
3 files changed, 49 insertions, 18 deletions
| diff --git a/continuedev/src/continuedev/models/filesystem.py b/continuedev/src/continuedev/models/filesystem.py index ede636c5..b709dd21 100644 --- a/continuedev/src/continuedev/models/filesystem.py +++ b/continuedev/src/continuedev/models/filesystem.py @@ -100,7 +100,7 @@ class FileSystem(AbstractModel):      @classmethod      def read_range_in_str(self, s: str, r: Range) -> str: -        lines = s.splitlines()[r.start.line:r.end.line + 1] +        lines = s.split("\n")[r.start.line:r.end.line + 1]          if len(lines) == 0:              return "" diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py index fceba284..c9011b29 100644 --- a/continuedev/src/continuedev/models/main.py +++ b/continuedev/src/continuedev/models/main.py @@ -76,6 +76,12 @@ class Range(BaseModel):      def overlaps_with(self, other: "Range") -> bool:          return not (self.end < other.start or self.start > other.end) +    def to_full_lines(self) -> "Range": +        return Range( +            start=Position(line=self.start.line, character=0), +            end=Position(line=self.end.line + 1, character=0) +        ) +      @staticmethod      def from_indices(string: str, start_index: int, end_index: int) -> "Range":          return Range( diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py index 237171a6..99c3b867 100644 --- a/continuedev/src/continuedev/steps/core/core.py +++ b/continuedev/src/continuedev/steps/core/core.py @@ -116,37 +116,45 @@ class DefaultModelEditCodeStep(Step):      name: str = "Editing Code"      hide = False      _prompt: str = dedent("""\ -        Take the file prefix and suffix into account, but only rewrite the commit before as specified in the commit message. Here's an example: +        Take the file prefix and suffix into account, but only rewrite the code_to_edit as specified in the user_request. Here's an example:          <file_prefix>          a = 5          b = 4 +        </file_prefix> +        <code_to_edit> +        def sum(): +            return a + b +        </code_to_edit>          <file_suffix>          def mul(a, b):              return a * b -        <commit_before> -        def sum(): -            return a + b -        <commit_msg> +        </file_suffix> +        <user_request>          Make a and b parameters of sum -        <commit_after> +        </user_request> +        <modified_code_to_edit>          def sum(a, b):              return a + b -        <|endoftext|> +        </modified_code_to_edit> -        Now complete the real thing. Do NOT rewrite the prefix or suffix. You are only to write the code that goes in "commit_after". +        Now complete the real thing. Do NOT rewrite anything in the file_prefix or file_suffix tags, but rewrite everything inside the code_to_edit tags in order to fulfill the user request. Do NOT preface your answer or write anything other than code.          <file_prefix>          {file_prefix} +        </file_prefix> +        <code_to_edit> +        {code} +        </code_to_edit>          <file_suffix>          {file_suffix} -        <commit_before> -        {code} -        <commit_msg> +        </file_suffix> +        <user_request>          {user_request} -        <commit_after> +        </user_request> +        <modified_code_to_edit>          """)      _prompt_and_completion: str = "" @@ -162,7 +170,11 @@ class DefaultModelEditCodeStep(Step):          await sdk.update_ui()          rif_with_contents = [] -        for range_in_file in self.range_in_files: +        for range_in_file in map(lambda x: RangeInFile( +            filepath=x.filepath, +            # Only consider the range line-by-line. Maybe later don't if it's only a single line. +            range=x.range.to_full_lines() +        ), self.range_in_files):              file_contents = await sdk.ide.readRangeInFile(range_in_file)              rif_with_contents.append(                  RangeInFileWithContents.from_range_in_file(range_in_file, file_contents)) @@ -225,16 +237,25 @@ class DefaultModelEditCodeStep(Step):                  raise Exception("Unknown default model") -            code_before = "".join( +            code_before = "\n".join(                  full_file_contents_lst[cur_start_line:max_start_line]) -            code_after = "".join( -                full_file_contents_lst[min_end_line:cur_end_line]) +            code_after = "\n".join( +                full_file_contents_lst[min_end_line:cur_end_line - 1])              segs = [code_before, code_after] +            if segs[0].strip() == "": +                segs[0] = segs[0].strip() +            if segs[1].strip() == "": +                segs[1] = segs[1].strip()              prompt = self._prompt.format(                  code=rif.contents, user_request=self.user_input, file_prefix=segs[0], file_suffix=segs[1]) +            if segs[0].strip() == "": +                prompt = prompt.replace("<file_prefix>\n", "") +            if segs[1].strip() == "": +                prompt = prompt.replace("\n<file_suffix>", "") +              lines = []              unfinished_line = ""              i = 0 @@ -266,7 +287,9 @@ class DefaultModelEditCodeStep(Step):                  lines.extend(chunk_lines)                  for line in chunk_lines: -                    if i < len(original_lines) and line == original_lines[i]: +                    if "</modified_code_to_edit>" in line: +                        break +                    elif i < len(original_lines) and line == original_lines[i]:                          i += 1                          continue @@ -275,6 +298,8 @@ class DefaultModelEditCodeStep(Step):              # Add the unfinished line              if unfinished_line != "": +                unfinished_line = unfinished_line.removesuffix( +                    "</modified_code_to_edit>")                  if not i < len(original_lines) or not unfinished_line == original_lines[i]:                      await add_line(i, unfinished_line)                  lines.append(unfinished_line) | 
