From 36577b8e94809da47a540499132774a0fe2c085d Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sat, 1 Jul 2023 12:36:57 -0700 Subject: explicit context pill buttons --- continuedev/src/continuedev/core/autopilot.py | 40 ++++++++++++++++++++-- continuedev/src/continuedev/core/main.py | 2 ++ continuedev/src/continuedev/core/sdk.py | 7 ++-- continuedev/src/continuedev/models/filesystem.py | 29 ++++++++++++++++ continuedev/src/continuedev/models/main.py | 6 ++++ continuedev/src/continuedev/server/gui.py | 7 ++++ continuedev/src/continuedev/server/ide.py | 10 +++++- continuedev/src/continuedev/server/ide_protocol.py | 6 +++- continuedev/src/continuedev/steps/core/core.py | 3 +- 9 files changed, 100 insertions(+), 10 deletions(-) (limited to 'continuedev') diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 3c7fbdef..b9e61c63 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -2,8 +2,10 @@ from functools import cached_property import traceback import time from typing import Any, Callable, Coroutine, Dict, List - +import os from aiohttp import ClientPayloadError + +from ..models.filesystem import RangeInFileWithContents from ..models.filesystem_edit import FileEditWithFullContents from ..libs.llm import LLM from .observation import Observation, InternalErrorObservation @@ -59,7 +61,13 @@ class Autopilot(ContinueBaseModel): keep_untouched = (cached_property,) def get_full_state(self) -> FullState: - return FullState(history=self.history, active=self._active, user_input_queue=self._main_user_input_queue, default_model=self.continue_sdk.config.default_model) + return FullState( + history=self.history, + active=self._active, + user_input_queue=self._main_user_input_queue, + default_model=self.continue_sdk.config.default_model, + highlighted_ranges=self._highlighted_ranges + ) async def get_available_slash_commands(self) -> List[Dict]: return list(map(lambda x: {"name": x.name, "description": x.description}, self.continue_sdk.config.slash_commands)) or [] @@ -124,6 +132,30 @@ class Autopilot(ContinueBaseModel): tb_step.step_name, {"output": output, **tb_step.params}) await self._run_singular_step(step) + _highlighted_ranges: List[RangeInFileWithContents] = [] + + async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]): + workspace_path = self.continue_sdk.ide.workspace_directory + for rif in range_in_files: + rif.filepath = os.path.relpath(rif.filepath, workspace_path) + + new_ranges = [] + for rif in range_in_files: + found_overlap = False + for i in range(len(self._highlighted_ranges)): + hr = self._highlighted_ranges[i] + if hr.filepath == rif.filepath and hr.range.overlaps_with(rif.range): + new_ranges.append(rif.union(hr)) + found_overlap = True + self._highlighted_ranges.pop(i) + break + + if not found_overlap: + new_ranges.append(rif) + + self._highlighted_ranges += new_ranges + await self.update_subscribers() + _step_depth: int = 0 async def retry_at_index(self, index: int): @@ -135,6 +167,10 @@ class Autopilot(ContinueBaseModel): self.history.timeline[index].deleted = True await self.update_subscribers() + async def delete_context_item_at_index(self, index: int): + self._highlighted_ranges.pop(index) + await self.update_subscribers() + async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]: # Allow config to set disallowed steps if step.__class__.__name__ in self.continue_sdk.config.disallowed_steps: diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 4c6f4dc2..2d84801c 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -2,6 +2,7 @@ import json from textwrap import dedent from typing import Callable, Coroutine, Dict, Generator, List, Literal, Tuple, Union +from ..models.filesystem import RangeInFileWithContents from ..models.main import ContinueBaseModel from pydantic import validator from .observation import Observation @@ -201,6 +202,7 @@ class FullState(ContinueBaseModel): active: bool user_input_queue: List[str] default_model: str + highlighted_ranges: List[RangeInFileWithContents] class ContinueSDK: diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index d95a233f..632f8683 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -179,7 +179,7 @@ class ContinueSDK(AbstractContinueSDK): async def get_chat_context(self) -> List[ChatMessage]: history_context = self.history.to_chat_history() - highlighted_code = await self.ide.getHighlightedCode() + highlighted_code = self.__autopilot._highlighted_ranges preface = "The following code is highlighted" @@ -190,11 +190,10 @@ class ContinueSDK(AbstractContinueSDK): if len(files) > 0: content = await self.ide.readFile(files[0]) highlighted_code = [ - RangeInFile.from_entire_file(files[0], content)] + RangeInFileWithContents.from_entire_file(files[0], content)] for rif in highlighted_code: - code = await self.ide.readRangeInFile(rif) - msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{code}\n```", + msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{rif.contents}\n```", role="system", summary=f"{preface}: {rif.filepath}") # Don't insert after latest user message or function call diff --git a/continuedev/src/continuedev/models/filesystem.py b/continuedev/src/continuedev/models/filesystem.py index b709dd21..fc1c3f13 100644 --- a/continuedev/src/continuedev/models/filesystem.py +++ b/continuedev/src/continuedev/models/filesystem.py @@ -23,11 +23,40 @@ class RangeInFile(BaseModel): class RangeInFileWithContents(RangeInFile): + """A range in a file with the contents of the range.""" contents: str def __hash__(self): return hash((self.filepath, self.range, self.contents)) + def union(self, other: "RangeInFileWithContents") -> "RangeInFileWithContents": + assert self.filepath == other.filepath + # Use a placeholder variable for self and swap it with other if other comes before self + first = self + second = other + if other.range.start < self.range.start: + first = other + second = self + + assert first.filepath == second.filepath + + # Calculate the start and end positions of the overlap + overlap_start = max(first.range.start, + second.range.start) - first.range.start + overlap_end = min(first.range.end, second.range.end) - \ + first.range.start + + # Calculate the new contents by removing the overlap + union_contents = first.contents[:overlap_start] + \ + second.contents[overlap_start:overlap_end] + \ + first.contents[overlap_end:] + + return RangeInFileWithContents( + filepath=first.filepath, + range=first.range.union(second.range), + contents=union_contents + ) + @staticmethod def from_entire_file(filepath: str, content: str) -> "RangeInFileWithContents": lines = content.splitlines() diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py index d5f6e650..101be4ae 100644 --- a/continuedev/src/continuedev/models/main.py +++ b/continuedev/src/continuedev/models/main.py @@ -49,6 +49,12 @@ class Range(BaseModel): start: Position end: Position + def __lt__(self, other: "Range") -> bool: + return self.start < other.start or (self.start == other.start and self.end < other.end) + + def __eq__(self, other: "Range") -> bool: + return self.start == other.start and self.end == other.end + def __hash__(self): return hash((self.start, self.end)) diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index c0178920..9a33fb6c 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -83,6 +83,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.on_clear_history() elif message_type == "delete_at_index": self.on_delete_at_index(data["index"]) + elif message_type == "delete_context_item_at_index": + self.on_delete_context_item_at_index(data["index"]) except Exception as e: print(e) @@ -127,6 +129,11 @@ class GUIProtocolServer(AbstractGUIProtocolServer): def on_delete_at_index(self, index: int): asyncio.create_task(self.session.autopilot.delete_at_index(index)) + def on_delete_context_item_at_index(self, index: int): + asyncio.create_task( + self.session.autopilot.delete_context_item_at_index(index) + ) + @router.websocket("/ws") async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index e1f19447..f3deecdb 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -8,7 +8,7 @@ from fastapi import WebSocket, Body, APIRouter from uvicorn.main import Server from ..libs.util.queue import AsyncSubscriptionQueue -from ..models.filesystem import FileSystem, RangeInFile, EditDiff, RealFileSystem +from ..models.filesystem import FileSystem, RangeInFile, EditDiff, RangeInFileWithContents, RealFileSystem from ..models.filesystem_edit import AddDirectory, AddFile, DeleteDirectory, DeleteFile, FileSystemEdit, FileEdit, FileEditWithFullContents, RenameDirectory, RenameFile, SequentialFileSystemEdit from pydantic import BaseModel from .gui import SessionManager, session_manager @@ -139,6 +139,9 @@ class IdeProtocolServer(AbstractIdeProtocolServer): fileEdits = list( map(lambda d: FileEditWithFullContents.parse_obj(d), data["fileEdits"])) self.onFileEdits(fileEdits) + elif message_type == "highlightedCodePush": + self.onHighlightedCodeUpdate( + [RangeInFileWithContents(**rif) for rif in data["highlightedCode"]]) elif message_type == "commandOutput": output = data["output"] self.onCommandOutput(output) @@ -229,6 +232,11 @@ class IdeProtocolServer(AbstractIdeProtocolServer): asyncio.create_task( session.autopilot.handle_command_output(output)) + def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]): + for _, session in self.session_manager.sessions.items(): + asyncio.create_task( + session.autopilot.handle_highlighted_code(range_in_files)) + # Request information. Session doesn't matter. async def getOpenFiles(self) -> List[str]: resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles") diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index d2dafa9a..17a09c3d 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod, abstractproperty from ..models.main import Traceback from ..models.filesystem_edit import FileEdit, FileSystemEdit, EditDiff -from ..models.filesystem import RangeInFile +from ..models.filesystem import RangeInFile, RangeInFileWithContents class AbstractIdeProtocolServer(ABC): @@ -91,6 +91,10 @@ class AbstractIdeProtocolServer(ABC): async def runCommand(self, command: str) -> str: """Run a command""" + @abstractmethod + def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]): + """Called when highlighted code is updated""" + @abstractproperty def workspace_directory(self) -> str: """Get the workspace directory""" diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py index a84263cc..729f5e66 100644 --- a/continuedev/src/continuedev/steps/core/core.py +++ b/continuedev/src/continuedev/steps/core/core.py @@ -311,8 +311,7 @@ class DefaultModelEditCodeStep(Step): nonlocal current_block_start, current_line_in_file, original_lines, original_lines_below_previous_blocks, current_block_lines, indices_of_last_matched_lines, LINES_TO_MATCH_BEFORE_ENDING_BLOCK, offset_from_blocks # Highlight the line to show progress - # - len(current_block_lines) - line_to_highlight = current_line_in_file + line_to_highlight = current_line_in_file - len(current_block_lines) await sdk.ide.highlightCode(RangeInFile(filepath=rif.filepath, range=Range.from_shorthand( line_to_highlight, 0, line_to_highlight, 0)), "#FFFFFF22" if len(current_block_lines) == 0 else "#00FF0022") -- cgit v1.2.3-70-g09d2 From a606c13ca75f0c9177b3d04f20dcf7211d81f083 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sun, 2 Jul 2023 20:14:27 -0700 Subject: finishing up explicit context --- continuedev/src/continuedev/core/autopilot.py | 17 ++--- continuedev/src/continuedev/models/filesystem.py | 16 ++--- continuedev/src/continuedev/models/main.py | 5 ++ continuedev/src/continuedev/steps/chat.py | 2 +- extension/react-app/src/components/ComboBox.tsx | 83 ++++++++++++---------- .../react-app/src/components/ContinueButton.tsx | 1 + extension/react-app/src/components/PillButton.tsx | 66 +++++++++++++++++ 7 files changed, 131 insertions(+), 59 deletions(-) create mode 100644 extension/react-app/src/components/PillButton.tsx (limited to 'continuedev') diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index b9e61c63..1a77ca64 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -139,21 +139,22 @@ class Autopilot(ContinueBaseModel): for rif in range_in_files: rif.filepath = os.path.relpath(rif.filepath, workspace_path) + old_ranges = self._highlighted_ranges + range_in_files new_ranges = [] - for rif in range_in_files: + + while len(old_ranges) > 0: + old_range = old_ranges.pop(0) found_overlap = False - for i in range(len(self._highlighted_ranges)): - hr = self._highlighted_ranges[i] - if hr.filepath == rif.filepath and hr.range.overlaps_with(rif.range): - new_ranges.append(rif.union(hr)) + for i in range(len(new_ranges)): + if old_range.filepath == new_ranges[i].filepath and old_range.range.overlaps_with(new_ranges[i].range): + new_ranges[i] = old_range.union(new_ranges[i]) found_overlap = True - self._highlighted_ranges.pop(i) break if not found_overlap: - new_ranges.append(rif) + new_ranges.append(old_range) - self._highlighted_ranges += new_ranges + self._highlighted_ranges = new_ranges await self.update_subscribers() _step_depth: int = 0 diff --git a/continuedev/src/continuedev/models/filesystem.py b/continuedev/src/continuedev/models/filesystem.py index fc1c3f13..df0b15d7 100644 --- a/continuedev/src/continuedev/models/filesystem.py +++ b/continuedev/src/continuedev/models/filesystem.py @@ -40,21 +40,15 @@ class RangeInFileWithContents(RangeInFile): assert first.filepath == second.filepath - # Calculate the start and end positions of the overlap - overlap_start = max(first.range.start, - second.range.start) - first.range.start - overlap_end = min(first.range.end, second.range.end) - \ - first.range.start - - # Calculate the new contents by removing the overlap - union_contents = first.contents[:overlap_start] + \ - second.contents[overlap_start:overlap_end] + \ - first.contents[overlap_end:] + # Calculate union of contents + num_overlapping_lines = first.range.end.line - second.range.start.line + 1 + union_lines = first.contents.splitlines()[:-num_overlapping_lines] + \ + second.contents.splitlines() return RangeInFileWithContents( filepath=first.filepath, range=first.range.union(second.range), - contents=union_contents + contents="\n".join(union_lines) ) @staticmethod diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py index 101be4ae..fa736772 100644 --- a/continuedev/src/continuedev/models/main.py +++ b/continuedev/src/continuedev/models/main.py @@ -43,6 +43,11 @@ class Position(BaseModel): def from_end_of_file(contents: str) -> "Position": return Position.from_index(contents, len(contents)) + def to_index(self, string: str) -> int: + """Convert line and character to index in string""" + lines = string.splitlines() + return sum(map(len, lines[:self.line])) + self.character + class Range(BaseModel): """A range in a file. 0-indexed.""" diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py index 8494563b..b10ec3d7 100644 --- a/continuedev/src/continuedev/steps/chat.py +++ b/continuedev/src/continuedev/steps/chat.py @@ -106,7 +106,7 @@ class RunTerminalCommandStep(Step): class ViewDirectoryTreeStep(Step): name: str = "View Directory Tree" - description: str = "View the directory tree to learn which folder and files exist." + description: str = "View the directory tree to learn which folder and files exist. You should always do this before adding new files." async def describe(self, models: Models) -> Coroutine[Any, Any, Coroutine[str, None, None]]: return f"Viewed the directory tree." diff --git a/extension/react-app/src/components/ComboBox.tsx b/extension/react-app/src/components/ComboBox.tsx index 34027a42..f299c3a2 100644 --- a/extension/react-app/src/components/ComboBox.tsx +++ b/extension/react-app/src/components/ComboBox.tsx @@ -9,6 +9,7 @@ import { } from "."; import CodeBlock from "./CodeBlock"; import { RangeInFile } from "../../../src/client"; +import PillButton from "./PillButton"; const mainInputFontSize = 16; @@ -22,27 +23,9 @@ const ContextDropdown = styled.div` border-bottom-left-radius: ${defaultBorderRadius}; /* border: 1px solid white; */ border-top: none; - margin-left: 8px; - margin-right: 8px; - margin-top: -12px; + margin: 8px; outline: 1px solid orange; -`; - -const PillButton = styled.button` - display: flex; - justify-content: space-between; - align-items: center; - border: none; - color: white; - background-color: gray; - border-radius: 50px; - padding: 5px 10px; - margin: 5px 0; - cursor: pointer; - - &:hover { - background-color: ${buttonColor}; - } + z-index: 5; `; const MainTextInput = styled.textarea` @@ -118,7 +101,9 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => { // The position of the current command you are typing now, so the one that will be appended to history once you press enter const [positionInHistory, setPositionInHistory] = React.useState(0); const [items, setItems] = React.useState(props.items); - const [showContextDropdown, setShowContextDropdown] = React.useState(false); + const [hoveringButton, setHoveringButton] = React.useState(false); + const [hoveringContextDropdown, setHoveringContextDropdown] = + React.useState(false); const [highlightedCodeSections, setHighlightedCodeSections] = React.useState( props.highlightedCodeSections || [ { @@ -184,7 +169,7 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => { 300 ).toString()}px`; - setShowContextDropdown(target.value.endsWith("@")); + // setShowContextDropdown(target.value.endsWith("@")); }, onKeyDown: (event) => { if (event.key === "Enter" && event.shiftKey) { @@ -256,22 +241,11 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => { ))} - -
+
{highlightedCodeSections.map((section, idx) => ( { - console.log("delete context item", idx); + title={section.filepath} + onDelete={() => { if (props.deleteContextItem) { props.deleteContextItem(idx); } @@ -281,11 +255,42 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => { return newSections; }); }} - > - {section.filepath} - + onHover={(val: boolean) => { + if (val) { + setHoveringButton(val); + } else { + setTimeout(() => { + setHoveringButton(val); + }, 100); + } + }} + /> ))} + + + Highlight code to include as context.{" "} + {highlightedCodeSections.length > 0 && + "Otherwise using entire currently open file."} +
+ { + setHoveringContextDropdown(true); + }} + onMouseLeave={() => { + setHoveringContextDropdown(false); + }} + hidden={!hoveringContextDropdown && !hoveringButton} + > + {highlightedCodeSections.map((section, idx) => ( + <> +

{section.filepath}

+ + {section.contents} + + + ))} +
); }); diff --git a/extension/react-app/src/components/ContinueButton.tsx b/extension/react-app/src/components/ContinueButton.tsx index ef6719b7..5295799a 100644 --- a/extension/react-app/src/components/ContinueButton.tsx +++ b/extension/react-app/src/components/ContinueButton.tsx @@ -6,6 +6,7 @@ import { RootStore } from "../redux/store"; let StyledButton = styled(Button)` margin: auto; + margin-top: 8px; display: grid; grid-template-columns: 30px 1fr; align-items: center; diff --git a/extension/react-app/src/components/PillButton.tsx b/extension/react-app/src/components/PillButton.tsx new file mode 100644 index 00000000..33451db5 --- /dev/null +++ b/extension/react-app/src/components/PillButton.tsx @@ -0,0 +1,66 @@ +import { useState } from "react"; +import styled from "styled-components"; +import { defaultBorderRadius } from "."; + +const Button = styled.button` + border: none; + color: white; + background-color: transparent; + border: 1px solid white; + border-radius: ${defaultBorderRadius}; + padding: 3px 6px; + + &:hover { + background-color: white; + color: black; + } +`; + +interface PillButtonProps { + onHover?: (arg0: boolean) => void; + onDelete?: () => void; + title: string; +} + +const PillButton = (props: PillButtonProps) => { + const [isHovered, setIsHovered] = useState(false); + return ( + + ); +}; + +export default PillButton; -- cgit v1.2.3-70-g09d2