summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py2
-rw-r--r--continuedev/src/continuedev/models/main.py14
-rw-r--r--continuedev/src/continuedev/steps/core/core.py70
-rw-r--r--extension/src/continueIdeClient.ts22
4 files changed, 64 insertions, 44 deletions
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 39c0b69f..6f620ba0 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -34,7 +34,7 @@ class OpenAI(LLM):
return tiktoken.encoding_for_model(self.default_model)
def count_tokens(self, text: str):
- return len(self.__encoding_for_model.encode(text))
+ return len(self.__encoding_for_model.encode(text, disallowed_special=()))
def __prune_chat_history(self, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
tokens = tokens_for_completion
diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py
index 02c44aae..fceba284 100644
--- a/continuedev/src/continuedev/models/main.py
+++ b/continuedev/src/continuedev/models/main.py
@@ -1,5 +1,5 @@
from abc import ABC
-from typing import List, Union
+from typing import List, Union, Tuple
from pydantic import BaseModel, root_validator
from functools import total_ordering
@@ -61,6 +61,18 @@ class Range(BaseModel):
def is_empty(self) -> bool:
return self.start == self.end
+ def indices_in_string(self, string: str) -> Tuple[int, int]:
+ """Get the start and end indicees of this range in the string"""
+ lines = string.splitlines()
+ if len(lines) == 0:
+ return (0, 0)
+
+ start_index = sum(
+ [len(line) + 1 for line in lines[:self.start.line]]) + self.start.character
+ end_index = sum(
+ [len(line) + 1 for line in lines[:self.end.line]]) + self.end.character
+ return (start_index, end_index)
+
def overlaps_with(self, other: "Range") -> bool:
return not (self.end < other.start or self.start > other.end)
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index ff498b9b..f1fb229e 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -5,12 +5,13 @@ from textwrap import dedent
from typing import Coroutine, List, Union
from ...models.main import Range
-from ...libs.util.calculate_diff import calculate_diff2, apply_edit_to_str
+from ...libs.util.calculate_diff import calculate_diff2, apply_edit_to_str, line_by_line_diff
from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder
from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit
from ...models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
from ...core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
from ...core.main import Step, SequentialStep
+import difflib
class ContinueSDK:
@@ -135,7 +136,7 @@ class DefaultModelEditCodeStep(Step):
return a + b
<|endoftext|>
- Now complete the real thing. Do NOT rewrite the prefix or suffix.
+ Now complete the real thing. Do NOT rewrite the prefix or suffix. You are only to write the code that goes in "commit_after".
<file_prefix>
{file_prefix}
@@ -159,6 +160,7 @@ class DefaultModelEditCodeStep(Step):
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
self.name = self.user_input
+ await sdk.update_ui()
rif_with_contents = []
for range_in_file in self.range_in_files:
@@ -171,48 +173,46 @@ class DefaultModelEditCodeStep(Step):
rif_dict[rif.filepath] = rif.contents
for rif in rif_with_contents:
+ await sdk.ide.setFileOpen(rif.filepath)
+
full_file_contents = await sdk.ide.readFile(rif.filepath)
- segs = full_file_contents.split(rif.contents)
+ start_index, end_index = rif.range.indices_in_string(
+ full_file_contents)
+ segs = [full_file_contents[:start_index],
+ full_file_contents[end_index:]]
+
prompt = self._prompt.format(
code=rif.contents, user_request=self.user_input, file_prefix=segs[0], file_suffix=segs[1])
- completion = str(sdk.models.default.complete(prompt))
+ completion = str(sdk.models.default.complete(prompt, with_history=await sdk.get_chat_context()))
eot_token = "<|endoftext|>"
completion = completion.removesuffix(eot_token)
self._prompt_and_completion += prompt + completion
- # Calculate diff, open file, apply edits, and highlight changed lines
- edits = calculate_diff2(
- rif.filepath, rif.contents, completion.removesuffix("\n"))
-
- await sdk.ide.setFileOpen(rif.filepath)
-
- lines_to_highlight = set()
- for edit in edits:
- edit.range.start.line += rif.range.start.line
- edit.range.start.character += rif.range.start.character
- edit.range.end.line += rif.range.start.line
- edit.range.end.character += rif.range.start.character if edit.range.end.line == 0 else 0
-
- for line in range(edit.range.start.line, edit.range.end.line + 1 + len(edit.replacement.splitlines()) - (edit.range.end.line - edit.range.start.line + 1)):
- lines_to_highlight.add(line)
-
- await sdk.ide.applyFileSystemEdit(edit)
-
- current_start = None
- last_line = None
- for line in sorted(list(lines_to_highlight)):
- if current_start is None:
- current_start = line
- elif line != last_line + 1:
- await sdk.ide.highlightCode(RangeInFile(filepath=edit.filepath, range=Range.from_shorthand(current_start, 0, last_line, 0)))
- current_start = line
-
- last_line = line
-
- if current_start is not None:
- await sdk.ide.highlightCode(RangeInFile(filepath=edit.filepath, range=Range.from_shorthand(current_start, 0, last_line, 0)))
+ diff = list(difflib.ndiff(rif.contents.splitlines(
+ keepends=True), completion.splitlines(keepends=True)))
+
+ lines_to_highlight = []
+ index = 0
+ for line in diff:
+ if line.startswith("-"):
+ pass
+ elif line.startswith("+"):
+ lines_to_highlight.append(index + rif.range.start.line)
+ index += 1
+ elif line.startswith(" "):
+ index += 1
+
+
+ await sdk.ide.applyFileSystemEdit(FileEdit(
+ filepath=rif.filepath,
+ range=rif.range,
+ replacement=completion
+ ))
+
+ for line in lines_to_highlight:
+ await sdk.ide.highlightCode(RangeInFile(filepath=rif.filepath, range=Range.from_shorthand(line, 0, line, 0)))
await sdk.ide.saveFile(rif.filepath)
diff --git a/extension/src/continueIdeClient.ts b/extension/src/continueIdeClient.ts
index 92af6b10..782219dc 100644
--- a/extension/src/continueIdeClient.ts
+++ b/extension/src/continueIdeClient.ts
@@ -140,12 +140,20 @@ class IdeProtocolClient {
vscode.ViewColumn.One
);
if (editor) {
- editor.setDecorations(
- vscode.window.createTextEditorDecorationType({
- backgroundColor: color,
- isWholeLine: true,
- }),
- [range]
+ const decorationType = vscode.window.createTextEditorDecorationType({
+ backgroundColor: color,
+ isWholeLine: true,
+ });
+ editor.setDecorations(decorationType, [range]);
+
+ // Listen for changes to cursor position
+ const cursorDisposable = vscode.window.onDidChangeTextEditorSelection(
+ (event) => {
+ if (event.textEditor.document.uri.fsPath === rangeInFile.filepath) {
+ cursorDisposable.dispose();
+ editor.setDecorations(decorationType, []);
+ }
+ }
);
}
}
@@ -282,7 +290,7 @@ class IdeProtocolClient {
edit.range.start.line,
edit.range.start.character,
edit.range.end.line,
- edit.range.end.character
+ edit.range.end.character + 1
);
editor.edit((editBuilder) => {