diff options
-rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 2 | ||||
-rw-r--r-- | continuedev/src/continuedev/models/main.py | 14 | ||||
-rw-r--r-- | continuedev/src/continuedev/steps/core/core.py | 70 | ||||
-rw-r--r-- | extension/src/continueIdeClient.ts | 22 |
4 files changed, 64 insertions, 44 deletions
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 39c0b69f..6f620ba0 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -34,7 +34,7 @@ class OpenAI(LLM): return tiktoken.encoding_for_model(self.default_model) def count_tokens(self, text: str): - return len(self.__encoding_for_model.encode(text)) + return len(self.__encoding_for_model.encode(text, disallowed_special=())) def __prune_chat_history(self, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int): tokens = tokens_for_completion diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py index 02c44aae..fceba284 100644 --- a/continuedev/src/continuedev/models/main.py +++ b/continuedev/src/continuedev/models/main.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import List, Union +from typing import List, Union, Tuple from pydantic import BaseModel, root_validator from functools import total_ordering @@ -61,6 +61,18 @@ class Range(BaseModel): def is_empty(self) -> bool: return self.start == self.end + def indices_in_string(self, string: str) -> Tuple[int, int]: + """Get the start and end indicees of this range in the string""" + lines = string.splitlines() + if len(lines) == 0: + return (0, 0) + + start_index = sum( + [len(line) + 1 for line in lines[:self.start.line]]) + self.start.character + end_index = sum( + [len(line) + 1 for line in lines[:self.end.line]]) + self.end.character + return (start_index, end_index) + def overlaps_with(self, other: "Range") -> bool: return not (self.end < other.start or self.start > other.end) diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py index ff498b9b..f1fb229e 100644 --- a/continuedev/src/continuedev/steps/core/core.py +++ b/continuedev/src/continuedev/steps/core/core.py @@ -5,12 +5,13 @@ from textwrap import dedent from typing import Coroutine, List, Union from ...models.main import Range -from ...libs.util.calculate_diff import calculate_diff2, apply_edit_to_str +from ...libs.util.calculate_diff import calculate_diff2, apply_edit_to_str, line_by_line_diff from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit from ...models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents from ...core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation from ...core.main import Step, SequentialStep +import difflib class ContinueSDK: @@ -135,7 +136,7 @@ class DefaultModelEditCodeStep(Step): return a + b <|endoftext|> - Now complete the real thing. Do NOT rewrite the prefix or suffix. + Now complete the real thing. Do NOT rewrite the prefix or suffix. You are only to write the code that goes in "commit_after". <file_prefix> {file_prefix} @@ -159,6 +160,7 @@ class DefaultModelEditCodeStep(Step): async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: self.name = self.user_input + await sdk.update_ui() rif_with_contents = [] for range_in_file in self.range_in_files: @@ -171,48 +173,46 @@ class DefaultModelEditCodeStep(Step): rif_dict[rif.filepath] = rif.contents for rif in rif_with_contents: + await sdk.ide.setFileOpen(rif.filepath) + full_file_contents = await sdk.ide.readFile(rif.filepath) - segs = full_file_contents.split(rif.contents) + start_index, end_index = rif.range.indices_in_string( + full_file_contents) + segs = [full_file_contents[:start_index], + full_file_contents[end_index:]] + prompt = self._prompt.format( code=rif.contents, user_request=self.user_input, file_prefix=segs[0], file_suffix=segs[1]) - completion = str(sdk.models.default.complete(prompt)) + completion = str(sdk.models.default.complete(prompt, with_history=await sdk.get_chat_context())) eot_token = "<|endoftext|>" completion = completion.removesuffix(eot_token) self._prompt_and_completion += prompt + completion - # Calculate diff, open file, apply edits, and highlight changed lines - edits = calculate_diff2( - rif.filepath, rif.contents, completion.removesuffix("\n")) - - await sdk.ide.setFileOpen(rif.filepath) - - lines_to_highlight = set() - for edit in edits: - edit.range.start.line += rif.range.start.line - edit.range.start.character += rif.range.start.character - edit.range.end.line += rif.range.start.line - edit.range.end.character += rif.range.start.character if edit.range.end.line == 0 else 0 - - for line in range(edit.range.start.line, edit.range.end.line + 1 + len(edit.replacement.splitlines()) - (edit.range.end.line - edit.range.start.line + 1)): - lines_to_highlight.add(line) - - await sdk.ide.applyFileSystemEdit(edit) - - current_start = None - last_line = None - for line in sorted(list(lines_to_highlight)): - if current_start is None: - current_start = line - elif line != last_line + 1: - await sdk.ide.highlightCode(RangeInFile(filepath=edit.filepath, range=Range.from_shorthand(current_start, 0, last_line, 0))) - current_start = line - - last_line = line - - if current_start is not None: - await sdk.ide.highlightCode(RangeInFile(filepath=edit.filepath, range=Range.from_shorthand(current_start, 0, last_line, 0))) + diff = list(difflib.ndiff(rif.contents.splitlines( + keepends=True), completion.splitlines(keepends=True))) + + lines_to_highlight = [] + index = 0 + for line in diff: + if line.startswith("-"): + pass + elif line.startswith("+"): + lines_to_highlight.append(index + rif.range.start.line) + index += 1 + elif line.startswith(" "): + index += 1 + + + await sdk.ide.applyFileSystemEdit(FileEdit( + filepath=rif.filepath, + range=rif.range, + replacement=completion + )) + + for line in lines_to_highlight: + await sdk.ide.highlightCode(RangeInFile(filepath=rif.filepath, range=Range.from_shorthand(line, 0, line, 0))) await sdk.ide.saveFile(rif.filepath) diff --git a/extension/src/continueIdeClient.ts b/extension/src/continueIdeClient.ts index 92af6b10..782219dc 100644 --- a/extension/src/continueIdeClient.ts +++ b/extension/src/continueIdeClient.ts @@ -140,12 +140,20 @@ class IdeProtocolClient { vscode.ViewColumn.One ); if (editor) { - editor.setDecorations( - vscode.window.createTextEditorDecorationType({ - backgroundColor: color, - isWholeLine: true, - }), - [range] + const decorationType = vscode.window.createTextEditorDecorationType({ + backgroundColor: color, + isWholeLine: true, + }); + editor.setDecorations(decorationType, [range]); + + // Listen for changes to cursor position + const cursorDisposable = vscode.window.onDidChangeTextEditorSelection( + (event) => { + if (event.textEditor.document.uri.fsPath === rangeInFile.filepath) { + cursorDisposable.dispose(); + editor.setDecorations(decorationType, []); + } + } ); } } @@ -282,7 +290,7 @@ class IdeProtocolClient { edit.range.start.line, edit.range.start.character, edit.range.end.line, - edit.range.end.character + edit.range.end.character + 1 ); editor.edit((editBuilder) => { |