summaryrefslogtreecommitdiff
path: root/continuedev/src
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-01 12:36:57 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-01 12:36:57 -0700
commit36577b8e94809da47a540499132774a0fe2c085d (patch)
treee912172745fedf947b0f393ceaa2d36aaa703626 /continuedev/src
parent95ce61f2655dcbeb4fed019b6a9d8a632bad7adc (diff)
downloadsncontinue-36577b8e94809da47a540499132774a0fe2c085d.tar.gz
sncontinue-36577b8e94809da47a540499132774a0fe2c085d.tar.bz2
sncontinue-36577b8e94809da47a540499132774a0fe2c085d.zip
explicit context pill buttons
Diffstat (limited to 'continuedev/src')
-rw-r--r--continuedev/src/continuedev/core/autopilot.py40
-rw-r--r--continuedev/src/continuedev/core/main.py2
-rw-r--r--continuedev/src/continuedev/core/sdk.py7
-rw-r--r--continuedev/src/continuedev/models/filesystem.py29
-rw-r--r--continuedev/src/continuedev/models/main.py6
-rw-r--r--continuedev/src/continuedev/server/gui.py7
-rw-r--r--continuedev/src/continuedev/server/ide.py10
-rw-r--r--continuedev/src/continuedev/server/ide_protocol.py6
-rw-r--r--continuedev/src/continuedev/steps/core/core.py3
9 files changed, 100 insertions, 10 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 3c7fbdef..b9e61c63 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -2,8 +2,10 @@ from functools import cached_property
import traceback
import time
from typing import Any, Callable, Coroutine, Dict, List
-
+import os
from aiohttp import ClientPayloadError
+
+from ..models.filesystem import RangeInFileWithContents
from ..models.filesystem_edit import FileEditWithFullContents
from ..libs.llm import LLM
from .observation import Observation, InternalErrorObservation
@@ -59,7 +61,13 @@ class Autopilot(ContinueBaseModel):
keep_untouched = (cached_property,)
def get_full_state(self) -> FullState:
- return FullState(history=self.history, active=self._active, user_input_queue=self._main_user_input_queue, default_model=self.continue_sdk.config.default_model)
+ return FullState(
+ history=self.history,
+ active=self._active,
+ user_input_queue=self._main_user_input_queue,
+ default_model=self.continue_sdk.config.default_model,
+ highlighted_ranges=self._highlighted_ranges
+ )
async def get_available_slash_commands(self) -> List[Dict]:
return list(map(lambda x: {"name": x.name, "description": x.description}, self.continue_sdk.config.slash_commands)) or []
@@ -124,6 +132,30 @@ class Autopilot(ContinueBaseModel):
tb_step.step_name, {"output": output, **tb_step.params})
await self._run_singular_step(step)
+ _highlighted_ranges: List[RangeInFileWithContents] = []
+
+ async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]):
+ workspace_path = self.continue_sdk.ide.workspace_directory
+ for rif in range_in_files:
+ rif.filepath = os.path.relpath(rif.filepath, workspace_path)
+
+ new_ranges = []
+ for rif in range_in_files:
+ found_overlap = False
+ for i in range(len(self._highlighted_ranges)):
+ hr = self._highlighted_ranges[i]
+ if hr.filepath == rif.filepath and hr.range.overlaps_with(rif.range):
+ new_ranges.append(rif.union(hr))
+ found_overlap = True
+ self._highlighted_ranges.pop(i)
+ break
+
+ if not found_overlap:
+ new_ranges.append(rif)
+
+ self._highlighted_ranges += new_ranges
+ await self.update_subscribers()
+
_step_depth: int = 0
async def retry_at_index(self, index: int):
@@ -135,6 +167,10 @@ class Autopilot(ContinueBaseModel):
self.history.timeline[index].deleted = True
await self.update_subscribers()
+ async def delete_context_item_at_index(self, index: int):
+ self._highlighted_ranges.pop(index)
+ await self.update_subscribers()
+
async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]:
# Allow config to set disallowed steps
if step.__class__.__name__ in self.continue_sdk.config.disallowed_steps:
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index 4c6f4dc2..2d84801c 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -2,6 +2,7 @@ import json
from textwrap import dedent
from typing import Callable, Coroutine, Dict, Generator, List, Literal, Tuple, Union
+from ..models.filesystem import RangeInFileWithContents
from ..models.main import ContinueBaseModel
from pydantic import validator
from .observation import Observation
@@ -201,6 +202,7 @@ class FullState(ContinueBaseModel):
active: bool
user_input_queue: List[str]
default_model: str
+ highlighted_ranges: List[RangeInFileWithContents]
class ContinueSDK:
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index d95a233f..632f8683 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -179,7 +179,7 @@ class ContinueSDK(AbstractContinueSDK):
async def get_chat_context(self) -> List[ChatMessage]:
history_context = self.history.to_chat_history()
- highlighted_code = await self.ide.getHighlightedCode()
+ highlighted_code = self.__autopilot._highlighted_ranges
preface = "The following code is highlighted"
@@ -190,11 +190,10 @@ class ContinueSDK(AbstractContinueSDK):
if len(files) > 0:
content = await self.ide.readFile(files[0])
highlighted_code = [
- RangeInFile.from_entire_file(files[0], content)]
+ RangeInFileWithContents.from_entire_file(files[0], content)]
for rif in highlighted_code:
- code = await self.ide.readRangeInFile(rif)
- msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{code}\n```",
+ msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{rif.contents}\n```",
role="system", summary=f"{preface}: {rif.filepath}")
# Don't insert after latest user message or function call
diff --git a/continuedev/src/continuedev/models/filesystem.py b/continuedev/src/continuedev/models/filesystem.py
index b709dd21..fc1c3f13 100644
--- a/continuedev/src/continuedev/models/filesystem.py
+++ b/continuedev/src/continuedev/models/filesystem.py
@@ -23,11 +23,40 @@ class RangeInFile(BaseModel):
class RangeInFileWithContents(RangeInFile):
+ """A range in a file with the contents of the range."""
contents: str
def __hash__(self):
return hash((self.filepath, self.range, self.contents))
+ def union(self, other: "RangeInFileWithContents") -> "RangeInFileWithContents":
+ assert self.filepath == other.filepath
+ # Use a placeholder variable for self and swap it with other if other comes before self
+ first = self
+ second = other
+ if other.range.start < self.range.start:
+ first = other
+ second = self
+
+ assert first.filepath == second.filepath
+
+ # Calculate the start and end positions of the overlap
+ overlap_start = max(first.range.start,
+ second.range.start) - first.range.start
+ overlap_end = min(first.range.end, second.range.end) - \
+ first.range.start
+
+ # Calculate the new contents by removing the overlap
+ union_contents = first.contents[:overlap_start] + \
+ second.contents[overlap_start:overlap_end] + \
+ first.contents[overlap_end:]
+
+ return RangeInFileWithContents(
+ filepath=first.filepath,
+ range=first.range.union(second.range),
+ contents=union_contents
+ )
+
@staticmethod
def from_entire_file(filepath: str, content: str) -> "RangeInFileWithContents":
lines = content.splitlines()
diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py
index d5f6e650..101be4ae 100644
--- a/continuedev/src/continuedev/models/main.py
+++ b/continuedev/src/continuedev/models/main.py
@@ -49,6 +49,12 @@ class Range(BaseModel):
start: Position
end: Position
+ def __lt__(self, other: "Range") -> bool:
+ return self.start < other.start or (self.start == other.start and self.end < other.end)
+
+ def __eq__(self, other: "Range") -> bool:
+ return self.start == other.start and self.end == other.end
+
def __hash__(self):
return hash((self.start, self.end))
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index c0178920..9a33fb6c 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -83,6 +83,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.on_clear_history()
elif message_type == "delete_at_index":
self.on_delete_at_index(data["index"])
+ elif message_type == "delete_context_item_at_index":
+ self.on_delete_context_item_at_index(data["index"])
except Exception as e:
print(e)
@@ -127,6 +129,11 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
def on_delete_at_index(self, index: int):
asyncio.create_task(self.session.autopilot.delete_at_index(index))
+ def on_delete_context_item_at_index(self, index: int):
+ asyncio.create_task(
+ self.session.autopilot.delete_context_item_at_index(index)
+ )
+
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index e1f19447..f3deecdb 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -8,7 +8,7 @@ from fastapi import WebSocket, Body, APIRouter
from uvicorn.main import Server
from ..libs.util.queue import AsyncSubscriptionQueue
-from ..models.filesystem import FileSystem, RangeInFile, EditDiff, RealFileSystem
+from ..models.filesystem import FileSystem, RangeInFile, EditDiff, RangeInFileWithContents, RealFileSystem
from ..models.filesystem_edit import AddDirectory, AddFile, DeleteDirectory, DeleteFile, FileSystemEdit, FileEdit, FileEditWithFullContents, RenameDirectory, RenameFile, SequentialFileSystemEdit
from pydantic import BaseModel
from .gui import SessionManager, session_manager
@@ -139,6 +139,9 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
fileEdits = list(
map(lambda d: FileEditWithFullContents.parse_obj(d), data["fileEdits"]))
self.onFileEdits(fileEdits)
+ elif message_type == "highlightedCodePush":
+ self.onHighlightedCodeUpdate(
+ [RangeInFileWithContents(**rif) for rif in data["highlightedCode"]])
elif message_type == "commandOutput":
output = data["output"]
self.onCommandOutput(output)
@@ -229,6 +232,11 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
asyncio.create_task(
session.autopilot.handle_command_output(output))
+ def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]):
+ for _, session in self.session_manager.sessions.items():
+ asyncio.create_task(
+ session.autopilot.handle_highlighted_code(range_in_files))
+
# Request information. Session doesn't matter.
async def getOpenFiles(self) -> List[str]:
resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles")
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
index d2dafa9a..17a09c3d 100644
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ b/continuedev/src/continuedev/server/ide_protocol.py
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod, abstractproperty
from ..models.main import Traceback
from ..models.filesystem_edit import FileEdit, FileSystemEdit, EditDiff
-from ..models.filesystem import RangeInFile
+from ..models.filesystem import RangeInFile, RangeInFileWithContents
class AbstractIdeProtocolServer(ABC):
@@ -91,6 +91,10 @@ class AbstractIdeProtocolServer(ABC):
async def runCommand(self, command: str) -> str:
"""Run a command"""
+ @abstractmethod
+ def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]):
+ """Called when highlighted code is updated"""
+
@abstractproperty
def workspace_directory(self) -> str:
"""Get the workspace directory"""
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index a84263cc..729f5e66 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -311,8 +311,7 @@ class DefaultModelEditCodeStep(Step):
nonlocal current_block_start, current_line_in_file, original_lines, original_lines_below_previous_blocks, current_block_lines, indices_of_last_matched_lines, LINES_TO_MATCH_BEFORE_ENDING_BLOCK, offset_from_blocks
# Highlight the line to show progress
- # - len(current_block_lines)
- line_to_highlight = current_line_in_file
+ line_to_highlight = current_line_in_file - len(current_block_lines)
await sdk.ide.highlightCode(RangeInFile(filepath=rif.filepath, range=Range.from_shorthand(
line_to_highlight, 0, line_to_highlight, 0)), "#FFFFFF22" if len(current_block_lines) == 0 else "#00FF0022")