diff options
Diffstat (limited to 'continuedev')
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 2 | ||||
| -rw-r--r-- | continuedev/src/continuedev/models/main.py | 14 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/core/core.py | 70 | 
3 files changed, 49 insertions, 37 deletions
| diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 39c0b69f..6f620ba0 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -34,7 +34,7 @@ class OpenAI(LLM):          return tiktoken.encoding_for_model(self.default_model)      def count_tokens(self, text: str): -        return len(self.__encoding_for_model.encode(text)) +        return len(self.__encoding_for_model.encode(text, disallowed_special=()))      def __prune_chat_history(self, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):          tokens = tokens_for_completion diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py index 02c44aae..fceba284 100644 --- a/continuedev/src/continuedev/models/main.py +++ b/continuedev/src/continuedev/models/main.py @@ -1,5 +1,5 @@  from abc import ABC -from typing import List, Union +from typing import List, Union, Tuple  from pydantic import BaseModel, root_validator  from functools import total_ordering @@ -61,6 +61,18 @@ class Range(BaseModel):      def is_empty(self) -> bool:          return self.start == self.end +    def indices_in_string(self, string: str) -> Tuple[int, int]: +        """Get the start and end indicees of this range in the string""" +        lines = string.splitlines() +        if len(lines) == 0: +            return (0, 0) + +        start_index = sum( +            [len(line) + 1 for line in lines[:self.start.line]]) + self.start.character +        end_index = sum( +            [len(line) + 1 for line in lines[:self.end.line]]) + self.end.character +        return (start_index, end_index) +      def overlaps_with(self, other: "Range") -> bool:          return not (self.end < other.start or self.start > other.end) diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py index ff498b9b..f1fb229e 100644 --- a/continuedev/src/continuedev/steps/core/core.py +++ b/continuedev/src/continuedev/steps/core/core.py @@ -5,12 +5,13 @@ from textwrap import dedent  from typing import Coroutine, List, Union  from ...models.main import Range -from ...libs.util.calculate_diff import calculate_diff2, apply_edit_to_str +from ...libs.util.calculate_diff import calculate_diff2, apply_edit_to_str, line_by_line_diff  from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder  from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit  from ...models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents  from ...core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation  from ...core.main import Step, SequentialStep +import difflib  class ContinueSDK: @@ -135,7 +136,7 @@ class DefaultModelEditCodeStep(Step):              return a + b          <|endoftext|> -        Now complete the real thing. Do NOT rewrite the prefix or suffix. +        Now complete the real thing. Do NOT rewrite the prefix or suffix. You are only to write the code that goes in "commit_after".          <file_prefix>          {file_prefix} @@ -159,6 +160,7 @@ class DefaultModelEditCodeStep(Step):      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:          self.name = self.user_input +        await sdk.update_ui()          rif_with_contents = []          for range_in_file in self.range_in_files: @@ -171,48 +173,46 @@ class DefaultModelEditCodeStep(Step):              rif_dict[rif.filepath] = rif.contents          for rif in rif_with_contents: +            await sdk.ide.setFileOpen(rif.filepath) +              full_file_contents = await sdk.ide.readFile(rif.filepath) -            segs = full_file_contents.split(rif.contents) +            start_index, end_index = rif.range.indices_in_string( +                full_file_contents) +            segs = [full_file_contents[:start_index], +                    full_file_contents[end_index:]] +              prompt = self._prompt.format(                  code=rif.contents, user_request=self.user_input, file_prefix=segs[0], file_suffix=segs[1]) -            completion = str(sdk.models.default.complete(prompt)) +            completion = str(sdk.models.default.complete(prompt, with_history=await sdk.get_chat_context()))              eot_token = "<|endoftext|>"              completion = completion.removesuffix(eot_token)              self._prompt_and_completion += prompt + completion -            # Calculate diff, open file, apply edits, and highlight changed lines -            edits = calculate_diff2( -                rif.filepath, rif.contents, completion.removesuffix("\n")) - -            await sdk.ide.setFileOpen(rif.filepath) - -            lines_to_highlight = set() -            for edit in edits: -                edit.range.start.line += rif.range.start.line -                edit.range.start.character += rif.range.start.character -                edit.range.end.line += rif.range.start.line -                edit.range.end.character += rif.range.start.character if edit.range.end.line == 0 else 0 - -                for line in range(edit.range.start.line, edit.range.end.line + 1 + len(edit.replacement.splitlines()) - (edit.range.end.line - edit.range.start.line + 1)): -                    lines_to_highlight.add(line) - -                await sdk.ide.applyFileSystemEdit(edit) - -            current_start = None -            last_line = None -            for line in sorted(list(lines_to_highlight)): -                if current_start is None: -                    current_start = line -                elif line != last_line + 1: -                    await sdk.ide.highlightCode(RangeInFile(filepath=edit.filepath, range=Range.from_shorthand(current_start, 0, last_line, 0))) -                    current_start = line - -                last_line = line - -            if current_start is not None: -                await sdk.ide.highlightCode(RangeInFile(filepath=edit.filepath, range=Range.from_shorthand(current_start, 0, last_line, 0))) +            diff = list(difflib.ndiff(rif.contents.splitlines( +                keepends=True), completion.splitlines(keepends=True))) + +            lines_to_highlight = [] +            index = 0 +            for line in diff: +                if line.startswith("-"): +                    pass +                elif line.startswith("+"): +                    lines_to_highlight.append(index + rif.range.start.line) +                    index += 1 +                elif line.startswith(" "): +                    index += 1 + + +            await sdk.ide.applyFileSystemEdit(FileEdit( +                filepath=rif.filepath, +                range=rif.range, +                replacement=completion +            )) + +            for line in lines_to_highlight: +                await sdk.ide.highlightCode(RangeInFile(filepath=rif.filepath, range=Range.from_shorthand(line, 0, line, 0)))              await sdk.ide.saveFile(rif.filepath) | 
