From c98f860460767fe14f8fbf139150b1bd1ee2ff12 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sun, 20 Aug 2023 20:02:07 -0700 Subject: feat: :sparkles: saved context groups --- continuedev/src/continuedev/core/autopilot.py | 54 ++ continuedev/src/continuedev/core/context.py | 18 + continuedev/src/continuedev/core/main.py | 1 + continuedev/src/continuedev/libs/util/paths.py | 9 + .../plugins/context_providers/highlighted_code.py | 19 + continuedev/src/continuedev/server/gui.py | 18 + extension/react-app/src/components/ComboBox.tsx | 195 +++++- extension/react-app/src/components/Layout.tsx | 4 +- extension/react-app/src/components/TextDialog.tsx | 42 +- .../src/hooks/AbstractContinueGUIClientProtocol.ts | 6 +- .../src/hooks/ContinueGUIClientProtocol.ts | 13 +- extension/schema/FullState.d.ts | 4 + schema/json/ContextItem.json | 150 ++--- schema/json/FileEdit.json | 124 ++-- schema/json/FileEditWithFullContents.json | 158 ++--- schema/json/FullState.json | 653 +++++++++++---------- schema/json/History.json | 320 +++++----- schema/json/HistoryNode.json | 276 ++++----- schema/json/Position.json | 46 +- schema/json/Range.json | 80 +-- schema/json/RangeInFile.json | 114 ++-- schema/json/SessionInfo.json | 56 +- schema/json/Traceback.json | 124 ++-- schema/json/TracebackFrame.json | 64 +- 24 files changed, 1407 insertions(+), 1141 deletions(-) diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index ded120d2..2e255198 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -1,3 +1,5 @@ +import json +import os import time import traceback from functools import cached_property @@ -9,6 +11,7 @@ from pydantic import root_validator from ..libs.util.create_async_task import create_async_task from ..libs.util.logging import logger +from ..libs.util.paths import getSavedContextGroupsPath from ..libs.util.queue import AsyncSubscriptionQueue from ..libs.util.strings import remove_quotes_and_escapes from ..libs.util.telemetry import posthog_logger @@ -30,6 +33,7 @@ from ..server.ide_protocol import AbstractIdeProtocolServer from .context import ContextManager from .main import ( Context, + ContextItem, ContinueCustomException, FullState, History, @@ -111,6 +115,21 @@ class Autopilot(ContinueBaseModel): self.history = full_state.history self.session_info = full_state.session_info + # Load saved context groups + context_groups_file = getSavedContextGroupsPath() + try: + with open(context_groups_file, "r") as f: + json_ob = json.load(f) + for title, context_group in json_ob.items(): + self._saved_context_groups[title] = [ + ContextItem(**item) for item in context_group + ] + except Exception as e: + logger.warning( + f"Failed to load saved_context_groups.json: {e}. Reverting to empty list." + ) + self._saved_context_groups = {} + self.started = True class Config: @@ -139,6 +158,7 @@ class Autopilot(ContinueBaseModel): if self.context_manager is not None else [], session_info=self.session_info, + saved_context_groups=self._saved_context_groups, ) self.full_state = full_state return full_state @@ -521,3 +541,37 @@ class Autopilot(ContinueBaseModel): async def select_context_item(self, id: str, query: str): await self.context_manager.select_context_item(id, query) await self.update_subscribers() + + _saved_context_groups: Dict[str, List[ContextItem]] = {} + + async def save_context_group(self, title: str, context_items: List[ContextItem]): + self._saved_context_groups[title] = context_items + await self.update_subscribers() + + # Update saved context groups + context_groups_file = getSavedContextGroupsPath() + if os.path.exists(context_groups_file): + with open(context_groups_file, "w") as f: + dict_to_save = { + title: [item.dict() for item in context_items] + for title, context_items in self._saved_context_groups.items() + } + json.dump(dict_to_save, f) + + posthog_logger.capture_event( + "save_context_group", {"title": title, "length": len(context_items)} + ) + + async def select_context_group(self, id: str): + if id not in self._saved_context_groups: + logger.warning(f"Context group {id} not found") + return + context_group = self._saved_context_groups[id] + await self.context_manager.clear_context() + for item in context_group: + await self.context_manager.manually_add_context_item(item) + await self.update_subscribers() + + posthog_logger.capture_event( + "select_context_group", {"title": id, "length": len(context_group)} + ) diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index e5d6e13b..7172883f 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -121,6 +121,13 @@ class ContextProvider(BaseModel): if new_item := await self.get_item(id, query): self.selected_items.append(new_item) + async def manually_add_context_item(self, context_item: ContextItem): + for item in self.selected_items: + if item.description.id.item_id == context_item.description.id.item_id: + return + + self.selected_items.append(context_item) + class ContextManager: """ @@ -278,6 +285,17 @@ class ContextManager: for provider in self.context_providers.values(): await self.context_providers[provider.title].clear_context() + async def manually_add_context_item(self, item: ContextItem): + """ + Adds the given ContextItem to the list of ContextItems. + """ + if item.description.id.provider_title not in self.provider_titles: + return + + await self.context_providers[ + item.description.id.provider_title + ].manually_add_context_item(item) + """ Should define "ArgsTransformer" and "PromptTransformer" classes for the different LLMs. A standard way for them to ingest the diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 53440dae..e4ee7668 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -287,6 +287,7 @@ class FullState(ContinueBaseModel): adding_highlighted_code: bool selected_context_items: List[ContextItem] session_info: Optional[SessionInfo] = None + saved_context_groups: Dict[str, List[ContextItem]] = {} class ContinueSDK: diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py index 93ab16db..a411c5c3 100644 --- a/continuedev/src/continuedev/libs/util/paths.py +++ b/continuedev/src/continuedev/libs/util/paths.py @@ -75,3 +75,12 @@ def getLogFilePath(): path = os.path.join(getGlobalFolderPath(), "continue.log") os.makedirs(os.path.dirname(path), exist_ok=True) return path + + +def getSavedContextGroupsPath(): + path = os.path.join(getGlobalFolderPath(), "saved_context_groups.json") + os.makedirs(os.path.dirname(path), exist_ok=True) + if not os.path.exists(path): + with open(path, "w") as f: + f.write("\{\}") + return path diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py index ed293124..504764b9 100644 --- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py +++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py @@ -11,6 +11,7 @@ from ...core.context import ( ) from ...core.main import ChatMessage from ...models.filesystem import RangeInFileWithContents +from ...models.main import Range class HighlightedRangeContextItem(BaseModel): @@ -257,3 +258,21 @@ class HighlightedCodeContextProvider(ContextProvider): self, id: ContextItemId, query: str, prev: List[ContextItem] = None ) -> List[ContextItem]: raise NotImplementedError() + + async def manually_add_context_item(self, context_item: ContextItem): + full_file_content = await self.ide.readFile( + context_item.description.description + ) + self.highlighted_ranges.append( + HighlightedRangeContextItem( + rif=RangeInFileWithContents( + filepath=context_item.description.description, + range=Range.from_lines_snippet_in_file( + content=full_file_content, + snippet=context_item.content, + ), + contents=context_item.content, + ), + item=context_item, + ) + ) diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 7497e777..a4c45a06 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -8,6 +8,7 @@ from pydantic import BaseModel from starlette.websockets import WebSocketDisconnect, WebSocketState from uvicorn.main import Server +from ..core.main import ContextItem from ..libs.util.create_async_task import create_async_task from ..libs.util.logging import logger from ..libs.util.queue import AsyncSubscriptionQueue @@ -104,6 +105,12 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.load_session(data.get("session_id", None)) elif message_type == "edit_step_at_index": self.edit_step_at_index(data.get("user_input", ""), data["index"]) + elif message_type == "save_context_group": + self.save_context_group( + data["title"], [ContextItem(**item) for item in data["context_items"]] + ) + elif message_type == "select_context_group": + self.select_context_group(data["id"]) def on_main_input(self, input: str): # Do something with user input @@ -186,6 +193,17 @@ class GUIProtocolServer(AbstractGUIProtocolServer): posthog_logger.capture_event("load_session", {"session_id": session_id}) + def save_context_group(self, title: str, context_items: List[ContextItem]): + create_async_task( + self.session.autopilot.save_context_group(title, context_items), + self.on_error, + ) + + def select_context_group(self, id: str): + create_async_task( + self.session.autopilot.select_context_group(id), self.on_error + ) + @router.websocket("/ws") async def websocket_endpoint( diff --git a/extension/react-app/src/components/ComboBox.tsx b/extension/react-app/src/components/ComboBox.tsx index 0cf2bc19..c407a779 100644 --- a/extension/react-app/src/components/ComboBox.tsx +++ b/extension/react-app/src/components/ComboBox.tsx @@ -7,6 +7,8 @@ import React, { import { useCombobox } from "downshift"; import styled from "styled-components"; import { + Button, + TextInput, defaultBorderRadius, lightGray, secondaryDark, @@ -15,12 +17,20 @@ import { } from "."; import PillButton from "./PillButton"; import HeaderButtonWithText from "./HeaderButtonWithText"; -import { DocumentPlusIcon } from "@heroicons/react/24/outline"; +import { + BookmarkIcon, + DocumentPlusIcon, + FolderArrowDownIcon, +} from "@heroicons/react/24/outline"; import { ContextItem } from "../../../schema/FullState"; import { postVscMessage } from "../vscode"; import { GUIClientContext } from "../App"; import { MeiliSearch } from "meilisearch"; -import { setBottomMessage } from "../redux/slices/uiStateSlice"; +import { + setBottomMessage, + setDialogMessage, + setShowDialog, +} from "../redux/slices/uiStateSlice"; import { useDispatch, useSelector } from "react-redux"; import { RootStore } from "../redux/store"; @@ -29,6 +39,38 @@ const SEARCH_INDEX_NAME = "continue_context_items"; // #region styled components const mainInputFontSize = 13; +const MiniPillSpan = styled.span` + padding: 3px; + padding-left: 6px; + padding-right: 6px; + border-radius: ${defaultBorderRadius}; + color: ${vscForeground}; + background-color: #fff3; + overflow: hidden; + font-size: 12px; + display: flex; + align-items: center; + text-align: center; + justify-content: center; +`; + +const ContextGroupSelectDiv = styled.div` + display: flex; + align-items: center; + gap: 8px; + padding: 8px; + border-radius: ${defaultBorderRadius}; + background-color: ${secondaryDark}; + color: ${vscForeground}; + margin-top: 8px; + cursor: pointer; + + &:hover { + background-color: ${vscBackground}; + color: ${vscForeground}; + } +`; + const EmptyPillDiv = styled.div` padding: 4px; padding-left: 8px; @@ -137,6 +179,9 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => { const workspacePaths = useSelector( (state: RootStore) => state.config.workspacePaths ); + const savedContextGroups = useSelector( + (state: RootStore) => state.serverState.saved_context_groups + ); const [history, setHistory] = React.useState([]); // The position of the current command you are typing now, so the one that will be appended to history once you press enter @@ -328,6 +373,87 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => { }; }, [inputRef.current]); + const showSelectContextGroupDialog = () => { + dispatch( + setDialogMessage( +
+

Saved Context Groups

+ + {savedContextGroups && Object.keys(savedContextGroups).length > 0 ? ( +
+ {Object.keys(savedContextGroups).map((key: string) => { + const contextGroup = savedContextGroups[key]; + return ( + { + dispatch(setDialogMessage(undefined)); + dispatch(setShowDialog(false)); + client?.selectContextGroup(key); + }} + > + {key}: + + {contextGroup.map((contextItem) => { + return ( + + {contextItem.description.name} + + ); + })} + + ); + })} +
+ ) : ( +
No saved context groups
+ )} + +
+ ) + ); + dispatch(setShowDialog(true)); + }; + + const showDialogToSaveContextGroup = () => { + let inputElement: HTMLInputElement | null = null; + dispatch( + setDialogMessage( +
+ { + inputElement = input; + }} + /> +
+ +
+ ) + ); + dispatch(setShowDialog(true)); + }; + return ( <>
{ /> ); })} - {props.selectedContextItems.length > 0 && - (props.addingHighlightedCode ? ( - { - props.onToggleAddContext(); - }} - > - Highlight code section - - ) : ( + { + showSelectContextGroupDialog(); + }} + className="pill-button focus:outline-none focus:border-red-600 focus:border focus:border-solid" + onKeyDown={(e: KeyboardEvent) => { + e.preventDefault(); + if (e.key === "Enter") { + showSelectContextGroupDialog(); + } + }} + > + + + {props.selectedContextItems.length > 0 && ( + <> { - props.onToggleAddContext(); + showDialogToSaveContextGroup(); }} className="pill-button focus:outline-none focus:border-red-600 focus:border focus:border-solid" onKeyDown={(e: KeyboardEvent) => { e.preventDefault(); if (e.key === "Enter") { - props.onToggleAddContext(); + showDialogToSaveContextGroup(); } }} > - + - ))} + {props.addingHighlightedCode ? ( + { + props.onToggleAddContext(); + }} + > + Highlight code section + + ) : ( + { + props.onToggleAddContext(); + }} + className="pill-button focus:outline-none focus:border-red-600 focus:border focus:border-solid" + onKeyDown={(e: KeyboardEvent) => { + e.preventDefault(); + if (e.key === "Enter") { + props.onToggleAddContext(); + } + }} + > + + + )} + + )}