diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-09-08 14:17:50 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-09-08 14:17:50 -0700 |
commit | 8d423fd8d1d5b136e8138a906e8594ab93ec1982 (patch) | |
tree | 11498fcf6b11683a6b5aaab4ea7f7b2855d6c6cc | |
parent | 354a3f493074b1fb63ff4f206a94c35f05673e99 (diff) | |
download | sncontinue-8d423fd8d1d5b136e8138a906e8594ab93ec1982.tar.gz sncontinue-8d423fd8d1d5b136e8138a906e8594ab93ec1982.tar.bz2 sncontinue-8d423fd8d1d5b136e8138a906e8594ab93ec1982.zip |
feat: :lipstick: nested context provider dropdown
18 files changed, 319 insertions, 83 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index cee7a2f9..a943a35f 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -177,6 +177,7 @@ class Autopilot(ContinueBaseModel): session_info=self.session_info, config=self.continue_sdk.config, saved_context_groups=self._saved_context_groups, + context_providers=self.context_manager.get_provider_descriptions(), ) self.full_state = full_state return full_state diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index 571e5dc8..f1f309ba 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -11,7 +11,13 @@ from ..libs.util.devdata import dev_data_logger from ..libs.util.logging import logger from ..libs.util.telemetry import posthog_logger from ..server.meilisearch_server import poll_meilisearch_running, restart_meilisearch -from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId +from .main import ( + ChatMessage, + ContextItem, + ContextItemDescription, + ContextItemId, + ContextProviderDescription, +) class ContinueSDK(BaseModel): @@ -36,6 +42,10 @@ class ContextProvider(BaseModel): delete_documents: Callable[[List[str]], Awaitable] = None update_documents: Callable[[List[ContextItem], str], Awaitable] = None + display_title: str + description: str + dynamic: bool + selected_items: List[ContextItem] = [] def dict(self, *args, **kwargs): @@ -168,6 +178,20 @@ class ContextManager: It is responsible for compiling all of this information into a single prompt without exceeding the token limit. """ + def get_provider_descriptions(self) -> List[ContextProviderDescription]: + """ + Returns a list of ContextProviderDescriptions for each context provider. + """ + return [ + ContextProviderDescription( + title=provider.title, + display_title=provider.display_title, + description=provider.description, + dynamic=provider.dynamic, + ) + for provider in self.context_providers.values() + ] + async def get_selected_items(self) -> List[ContextItem]: """ Returns all of the selected ContextItems. @@ -242,6 +266,7 @@ class ContextManager: "description": item.description.description, "content": item.content, "workspace_dir": workspace_dir, + "provider_name": item.description.id.provider_title, } for item in context_items ] @@ -282,7 +307,9 @@ class ContextManager: await globalSearchIndex.update_searchable_attributes( ["name", "description"] ) - await globalSearchIndex.update_filterable_attributes(["workspace_dir"]) + await globalSearchIndex.update_filterable_attributes( + ["workspace_dir", "provider_name"] + ) async def load_context_provider(provider: ContextProvider): context_items = await provider.provide_context_items(workspace_dir) @@ -293,6 +320,7 @@ class ContextManager: "description": item.description.description, "content": item.content, "workspace_dir": workspace_dir, + "provider_name": provider.title, } for item in context_items ] diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index ec2e2a07..3d3bef15 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -291,6 +291,13 @@ class ContinueConfig(ContinueBaseModel): return original_dict +class ContextProviderDescription(BaseModel): + title: str + display_title: str + description: str + dynamic: bool + + class FullState(ContinueBaseModel): """A full state of the program, including the history""" @@ -303,6 +310,7 @@ class FullState(ContinueBaseModel): session_info: Optional[SessionInfo] = None config: ContinueConfig saved_context_groups: Dict[str, List[ContextItem]] = {} + context_providers: List[ContextProviderDescription] = [] class ContinueSDK: diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 85ca1969..653d2d6b 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -137,6 +137,7 @@ class LLM(ContinueBaseModel): async def stream_complete( self, prompt: str, + raw: bool = False, model: str = None, temperature: float = None, top_p: float = None, @@ -163,7 +164,9 @@ class LLM(ContinueBaseModel): prompt = prune_raw_prompt_from_top( self.model, self.context_length, prompt, options.max_tokens ) - prompt = self.template_prompt_like_messages(prompt) + + if not raw: + prompt = self.template_prompt_like_messages(prompt) self.write_log(f"Prompt: \n\n{prompt}") @@ -181,6 +184,7 @@ class LLM(ContinueBaseModel): async def complete( self, prompt: str, + raw: bool = False, model: str = None, temperature: float = None, top_p: float = None, @@ -207,7 +211,9 @@ class LLM(ContinueBaseModel): prompt = prune_raw_prompt_from_top( self.model, self.context_length, prompt, options.max_tokens ) - prompt = self.template_prompt_like_messages(prompt) + + if not raw: + prompt = self.template_prompt_like_messages(prompt) self.write_log(f"Prompt: \n\n{prompt}") diff --git a/continuedev/src/continuedev/libs/llm/prompts/edit.py b/continuedev/src/continuedev/libs/llm/prompts/edit.py index a234fa61..275473bc 100644 --- a/continuedev/src/continuedev/libs/llm/prompts/edit.py +++ b/continuedev/src/continuedev/libs/llm/prompts/edit.py @@ -11,3 +11,5 @@ simplified_edit_prompt = dedent( Output nothing except for the code. No code block, no English explanation, no start/end tags. [/INST]""" ) + +codellama_infill_edit_prompt = "{{file_prefix}}<FILL>{{file_suffix}}" diff --git a/continuedev/src/continuedev/plugins/context_providers/diff.py b/continuedev/src/continuedev/plugins/context_providers/diff.py index c8345d02..4c16cabf 100644 --- a/continuedev/src/continuedev/plugins/context_providers/diff.py +++ b/continuedev/src/continuedev/plugins/context_providers/diff.py @@ -7,6 +7,9 @@ from ...core.main import ContextItem, ContextItemDescription, ContextItemId class DiffContextProvider(ContextProvider): title = "diff" + display_title = "Diff" + description = "Output of 'git diff' in current repo" + dynamic = True DIFF_CONTEXT_ITEM_ID = "diff" @@ -30,8 +33,8 @@ class DiffContextProvider(ContextProvider): return [self.BASE_CONTEXT_ITEM] async def get_item(self, id: ContextItemId, query: str) -> ContextItem: - if not id.item_id == self.DIFF_CONTEXT_ITEM_ID: - raise Exception("Invalid item id") + if not id.provider_title == self.title: + raise Exception("Invalid provider title for item") diff = subprocess.check_output(["git", "diff"], cwd=self.workspace_dir).decode( "utf-8" diff --git a/continuedev/src/continuedev/plugins/context_providers/dynamic.py b/continuedev/src/continuedev/plugins/context_providers/dynamic.py index e6ea0e88..50567621 100644 --- a/continuedev/src/continuedev/plugins/context_providers/dynamic.py +++ b/continuedev/src/continuedev/plugins/context_providers/dynamic.py @@ -13,16 +13,12 @@ class DynamicProvider(ContextProvider, ABC): """ title: str - """ - A name representing the provider. Probably use capitalized version of title - """ + """A name representing the provider. Probably use capitalized version of title""" + name: str - """ - A description for the provider - """ - description: str workspace_dir: str = None + dynamic: bool = True @property def BASE_CONTEXT_ITEM(self): @@ -41,8 +37,8 @@ class DynamicProvider(ContextProvider, ABC): return [self.BASE_CONTEXT_ITEM] async def get_item(self, id: ContextItemId, query: str) -> ContextItem: - if id.item_id != self.title: - raise Exception("Invalid item id") + if not id.provider_title == self.title: + raise Exception("Invalid provider title for item") query = query.lstrip(self.title + " ") results = await self.get_content(query) diff --git a/continuedev/src/continuedev/plugins/context_providers/embeddings.py b/continuedev/src/continuedev/plugins/context_providers/embeddings.py index 3f37232e..bd63eab8 100644 --- a/continuedev/src/continuedev/plugins/context_providers/embeddings.py +++ b/continuedev/src/continuedev/plugins/context_providers/embeddings.py @@ -17,6 +17,10 @@ class EmbeddingResult(BaseModel): class EmbeddingsProvider(ContextProvider): title = "embed" + display_title = "Embeddings Search" + description = "Search the codebase using embeddings" + dynamic = True + workspace_directory: str EMBEDDINGS_CONTEXT_ITEM_ID = "embeddings" @@ -62,8 +66,8 @@ class EmbeddingsProvider(ContextProvider): return [self.BASE_CONTEXT_ITEM] async def add_context_item(self, id: ContextItemId, query: str): - if not id.item_id == self.EMBEDDINGS_CONTEXT_ITEM_ID: - raise Exception("Invalid item id") + if not id.provider_title == self.title: + raise Exception("Invalid provider title for item") results = await self._get_query_results(query) diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index 6b27a889..c4a61193 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -27,6 +27,10 @@ class FileContextProvider(ContextProvider): title = "file" ignore_patterns: List[str] = DEFAULT_IGNORE_PATTERNS + display_title = "Files" + description = "Reference files in the current workspace" + dynamic = False + async def start(self, *args): await super().start(*args) diff --git a/continuedev/src/continuedev/plugins/context_providers/filetree.py b/continuedev/src/continuedev/plugins/context_providers/filetree.py index 959a0a66..968b761d 100644 --- a/continuedev/src/continuedev/plugins/context_providers/filetree.py +++ b/continuedev/src/continuedev/plugins/context_providers/filetree.py @@ -35,6 +35,9 @@ def split_path(path: str, with_root=None) -> List[str]: class FileTreeContextProvider(ContextProvider): title = "tree" + display_title = "File Tree" + description = "Add a formatted file tree of this directory to the context" + dynamic = True workspace_dir: str = None @@ -78,7 +81,7 @@ class FileTreeContextProvider(ContextProvider): return [await self._filetree_context_item()] async def get_item(self, id: ContextItemId, query: str) -> ContextItem: - if not id.item_id == self.title: - raise Exception("Invalid item id") + if not id.provider_title == self.title: + raise Exception("Invalid provider title for item") return await self._filetree_context_item() diff --git a/continuedev/src/continuedev/plugins/context_providers/github.py b/continuedev/src/continuedev/plugins/context_providers/github.py index d394add1..7a16d3c9 100644 --- a/continuedev/src/continuedev/plugins/context_providers/github.py +++ b/continuedev/src/continuedev/plugins/context_providers/github.py @@ -20,6 +20,10 @@ class GitHubIssuesContextProvider(ContextProvider): repo_name: str auth_token: str + display_title = "GitHub Issues" + description = "Reference GitHub issues" + dynamic = False + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: auth = Auth.Token(self.auth_token) gh = Github(auth=auth) diff --git a/continuedev/src/continuedev/plugins/context_providers/google.py b/continuedev/src/continuedev/plugins/context_providers/google.py index fc9d4555..06681db0 100644 --- a/continuedev/src/continuedev/plugins/context_providers/google.py +++ b/continuedev/src/continuedev/plugins/context_providers/google.py @@ -10,6 +10,9 @@ from .util import remove_meilisearch_disallowed_chars class GoogleContextProvider(ContextProvider): title = "google" + display_title = "Google" + description = "Search Google" + dynamic = True serper_api_key: str @@ -42,8 +45,8 @@ class GoogleContextProvider(ContextProvider): return [self.BASE_CONTEXT_ITEM] async def get_item(self, id: ContextItemId, query: str) -> ContextItem: - if not id.item_id == self.GOOGLE_CONTEXT_ITEM_ID: - raise Exception("Invalid item id") + if not id.provider_title == self.title: + raise Exception("Invalid provider title for item") results = await self._google_search(query) json_results = json.loads(results) diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py index 504764b9..0610a8c3 100644 --- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py +++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py @@ -28,6 +28,9 @@ class HighlightedCodeContextProvider(ContextProvider): """ title = "code" + display_title = "Highlighted Code" + description = "Highlight code" + dynamic = True ide: Any # IdeProtocolServer diff --git a/continuedev/src/continuedev/plugins/context_providers/search.py b/continuedev/src/continuedev/plugins/context_providers/search.py index 8a3a3689..19fc15bc 100644 --- a/continuedev/src/continuedev/plugins/context_providers/search.py +++ b/continuedev/src/continuedev/plugins/context_providers/search.py @@ -11,6 +11,9 @@ from .util import remove_meilisearch_disallowed_chars class SearchContextProvider(ContextProvider): title = "search" + display_title = "Search" + description = "Search the workspace for all matches of an exact string (e.g. '@search console.log')" + dynamic = True SEARCH_CONTEXT_ITEM_ID = "search" @@ -86,8 +89,8 @@ class SearchContextProvider(ContextProvider): return [self.BASE_CONTEXT_ITEM] async def get_item(self, id: ContextItemId, query: str) -> ContextItem: - if not id.item_id == self.SEARCH_CONTEXT_ITEM_ID: - raise Exception("Invalid item id") + if not id.provider_title == self.title: + raise Exception("Invalid provider title for item") query = query.lstrip("search ") results = await self._search(query) diff --git a/continuedev/src/continuedev/plugins/context_providers/terminal.py b/continuedev/src/continuedev/plugins/context_providers/terminal.py index 57d06a4f..f63ed676 100644 --- a/continuedev/src/continuedev/plugins/context_providers/terminal.py +++ b/continuedev/src/continuedev/plugins/context_providers/terminal.py @@ -6,6 +6,9 @@ from ...core.main import ChatMessage, ContextItem, ContextItemDescription, Conte class TerminalContextProvider(ContextProvider): title = "terminal" + display_title = "Terminal" + description = "Reference the contents of the terminal" + dynamic = True workspace_dir: str = None get_last_n_commands: int = 3 @@ -31,8 +34,8 @@ class TerminalContextProvider(ContextProvider): return [self._terminal_context_item()] async def get_item(self, id: ContextItemId, query: str) -> ContextItem: - if not id.item_id == self.title: - raise Exception("Invalid item id") + if not id.provider_title == self.title: + raise Exception("Invalid provider title for item") terminal_contents = await self.sdk.ide.getTerminalContents( self.get_last_n_commands diff --git a/continuedev/src/continuedev/plugins/context_providers/url.py b/continuedev/src/continuedev/plugins/context_providers/url.py index a5ec8990..b9dc0e1d 100644 --- a/continuedev/src/continuedev/plugins/context_providers/url.py +++ b/continuedev/src/continuedev/plugins/context_providers/url.py @@ -10,6 +10,9 @@ from .util import remove_meilisearch_disallowed_chars class URLContextProvider(ContextProvider): title = "url" + display_title = "URL" + description = "Reference the contents of a webpage" + dynamic = True # Allows users to provide a list of preset urls preset_urls: List[str] = [] @@ -78,8 +81,8 @@ class URLContextProvider(ContextProvider): return matching_static_item # Check if the item is the dynamic item - if not id.item_id == self.DYNAMIC_URL_CONTEXT_ITEM_ID: - raise Exception("Invalid item id") + if not id.provider_title == self.title: + raise Exception("Invalid provider title for item") # Generate the dynamic item url = query.lstrip("url ").strip() diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index 97235e6f..bf5eb144 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -604,7 +604,12 @@ Please output the code to be inserted at the cursor in order to fulfill the user rendered = render_prompt_template( template, messages[:-1], - {"code_to_edit": rif.contents, "user_input": self.user_input}, + { + "code_to_edit": rif.contents, + "user_input": self.user_input, + "file_prefix": file_prefix, + "file_suffix": file_suffix, + }, ) if isinstance(rendered, str): messages = [ @@ -617,11 +622,25 @@ Please output the code to be inserted at the cursor in order to fulfill the user else: messages = rendered - generator = model_to_use.stream_chat( - messages, - temperature=sdk.config.temperature, - max_tokens=min(max_tokens, model_to_use.context_length // 2), - ) + generator = model_to_use.stream_complete( + rendered, + raw=True, + temperature=sdk.config.temperature, + max_tokens=min(max_tokens, model_to_use.context_length // 2), + ) + + else: + + async def gen(): + async for chunk in model_to_use.stream_chat( + messages, + temperature=sdk.config.temperature, + max_tokens=min(max_tokens, model_to_use.context_length // 2), + ): + if "content" in chunk: + yield chunk["content"] + + generator = gen() posthog_logger.capture_event( "model_use", @@ -641,9 +660,6 @@ Please output the code to be inserted at the cursor in order to fulfill the user return # Accumulate lines - if "content" not in chunk: - continue # ayo - chunk = chunk["content"] chunk_lines = chunk.split("\n") chunk_lines[0] = unfinished_line + chunk_lines[0] if chunk.endswith("\n"): diff --git a/extension/react-app/src/components/ComboBox.tsx b/extension/react-app/src/components/ComboBox.tsx index 41b44684..cf7c4298 100644 --- a/extension/react-app/src/components/ComboBox.tsx +++ b/extension/react-app/src/components/ComboBox.tsx @@ -1,4 +1,5 @@ import React, { + useCallback, useContext, useEffect, useImperativeHandle, @@ -19,6 +20,7 @@ import { BookmarkIcon, DocumentPlusIcon, FolderArrowDownIcon, + ArrowLeftIcon, } from "@heroicons/react/24/outline"; import { ContextItem } from "../../../schema/FullState"; import { postVscMessage } from "../vscode"; @@ -164,37 +166,53 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => { const [items, setItems] = React.useState(props.items); const inputRef = React.useRef<HTMLInputElement>(null); - const [inputBoxHeight, setInputBoxHeight] = useState<string | undefined>( - undefined - ); // Whether the current input follows an '@' and should be treated as context query const [currentlyInContextQuery, setCurrentlyInContextQuery] = useState(false); + const [nestedContextProvider, setNestedContextProvider] = useState< + any | undefined + >(undefined); - const { getInputProps, ...downshiftProps } = useCombobox({ - onSelectedItemChange: ({ selectedItem }) => { - if (selectedItem?.id) { - // Get the query from the input value - const segs = downshiftProps.inputValue.split("@"); - const query = segs[segs.length - 1]; - const restOfInput = segs.splice(0, segs.length - 1).join("@"); + useEffect(() => { + if (!currentlyInContextQuery) { + setNestedContextProvider(undefined); + } + }, [currentlyInContextQuery]); - // Tell server the context item was selected - client?.selectContextItem(selectedItem.id, query); + const contextProviders = useSelector( + (state: RootStore) => state.serverState.context_providers + ) as any[]; - // Remove the '@' and the context query from the input - if (downshiftProps.inputValue.includes("@")) { - downshiftProps.setInputValue(restOfInput); - } - } - }, - onInputValueChange({ inputValue, highlightedIndex }) { + const goBackToContextProviders = () => { + setCurrentlyInContextQuery(false); + setNestedContextProvider(undefined); + downshiftProps.setInputValue("@"); + }; + + useEffect(() => { + if (!nestedContextProvider) { + console.log("setting items", nestedContextProvider); + setItems( + contextProviders?.map((provider) => ({ + name: provider.display_title, + description: provider.description, + id: provider.title, + })) || [] + ); + } + }, [nestedContextProvider]); + + const onInputValueChangeCallback = useCallback( + ({ inputValue, highlightedIndex }: any) => { + // Clear the input if (!inputValue) { setItems([]); + setNestedContextProvider(undefined); return; } props.onInputValueChange(inputValue); + // Handle context selection if (inputValue.endsWith("@") || currentlyInContextQuery) { const segs = inputValue?.split("@") || []; @@ -202,46 +220,114 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => { // Get search results and return setCurrentlyInContextQuery(true); const providerAndQuery = segs[segs.length - 1] || ""; - // Only return context items from the current workspace - the index is currently shared between all sessions - const workspaceFilter = - workspacePaths && workspacePaths.length > 0 - ? `workspace_dir IN [ ${workspacePaths - .map((path) => `"${path}"`) - .join(", ")} ]` - : undefined; - searchClient - .index(SEARCH_INDEX_NAME) - .search(providerAndQuery, { - filter: workspaceFilter, - }) - .then((res) => { - setItems( - res.hits.map((hit) => { - return { - name: hit.name, - description: hit.description, - id: hit.id, - content: hit.content, - }; - }) - ); - }) - .catch(() => { - // Swallow errors, because this simply is not supported on Windows at the moment + + if (nestedContextProvider && !inputValue.endsWith("@")) { + // Search only within this specific context provider + getFilteredContextItemsForProvider( + nestedContextProvider.title, + providerAndQuery + ).then((res) => { + setItems(res); }); + } else { + // Search through the list of context providers + const filteredItems = + contextProviders + ?.filter( + (provider) => + `@${provider.title}` + .toLowerCase() + .startsWith(inputValue.toLowerCase()) || + `@${provider.display_title}` + .toLowerCase() + .startsWith(inputValue.toLowerCase()) + ) + .map((provider) => ({ + name: provider.display_title, + description: provider.description, + id: provider.title, + })) || []; + setItems(filteredItems); + setCurrentlyInContextQuery(true); + } return; } else { // Exit the '@' context menu setCurrentlyInContextQuery(false); - setItems; + setNestedContextProvider(undefined); } } + + setNestedContextProvider(undefined); + + // Handle slash commands setItems( props.items.filter((item) => item.name.toLowerCase().startsWith(inputValue.toLowerCase()) ) ); }, + [props.items, currentlyInContextQuery, nestedContextProvider] + ); + + const getFilteredContextItemsForProvider = async ( + provider: string, + query: string + ) => { + // Only return context items from the current workspace - the index is currently shared between all sessions + const workspaceFilter = + workspacePaths && workspacePaths.length > 0 + ? `workspace_dir IN [ ${workspacePaths + .map((path) => `"${path}"`) + .join(", ")} ] AND provider_name = '${provider}'` + : undefined; + try { + const res = await searchClient.index(SEARCH_INDEX_NAME).search(query, { + filter: workspaceFilter, + }); + return ( + res?.hits.map((hit) => { + return { + name: hit.name, + description: hit.description, + id: hit.id, + content: hit.content, + }; + }) || [] + ); + } catch (e) { + console.log("Error searching context items", e); + return []; + } + }; + + const { getInputProps, ...downshiftProps } = useCombobox({ + onSelectedItemChange: ({ selectedItem }) => { + if (!selectedItem) return; + if (selectedItem.id) { + // Get the query from the input value + const segs = downshiftProps.inputValue.split("@"); + const query = segs[segs.length - 1]; + + // Tell server the context item was selected + client?.selectContextItem(selectedItem.id, query); + if (downshiftProps.inputValue.includes("@")) { + const selectedNestedContextProvider = contextProviders.find( + (provider) => provider.title === selectedItem.id + ); + if ( + !nestedContextProvider && + !selectedNestedContextProvider?.dynamic + ) { + downshiftProps.setInputValue(`@${selectedItem.id} `); + setNestedContextProvider(selectedNestedContextProvider); + } else { + downshiftProps.setInputValue(""); + } + } + } + }, + onInputValueChange: onInputValueChangeCallback, items, itemToString(item) { return item ? item.name : ""; @@ -467,7 +553,6 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => { target.scrollHeight, 300 ).toString()}px`; - setInputBoxHeight(target.style.height); // setShowContextDropdown(target.value.endsWith("@")); }, @@ -498,6 +583,31 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => { props.onEnter(event); } setCurrentlyInContextQuery(false); + } else if ( + event.key === "Enter" && + currentlyInContextQuery && + nestedContextProvider === undefined + ) { + const newProviderName = + items[downshiftProps.highlightedIndex].name; + const newProvider = contextProviders.find( + (provider) => provider.display_title === newProviderName + ); + + if (!newProvider) { + (event.nativeEvent as any).preventDownshiftDefault = true; + return; + } else if (newProvider.dynamic) { + return; + } + + setNestedContextProvider(newProvider); + downshiftProps.setInputValue(`@${newProvider.title} `); + (event.nativeEvent as any).preventDownshiftDefault = true; + event.preventDefault(); + getFilteredContextItemsForProvider(newProvider.title, "").then( + (items) => setItems(items) + ); } else if (event.key === "Tab" && items.length > 0) { downshiftProps.setInputValue(items[0].name); event.preventDefault(); @@ -545,6 +655,12 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => { ); setCurrentlyInContextQuery(false); } else if (event.key === "Escape") { + if (nestedContextProvider) { + goBackToContextProviders(); + (event.nativeEvent as any).preventDownshiftDefault = true; + return; + } + setCurrentlyInContextQuery(false); if (downshiftProps.isOpen && items.length > 0) { downshiftProps.closeMenu(); @@ -578,6 +694,30 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => { ulHeightPixels={ulRef.current?.getBoundingClientRect().height || 0} hidden={!downshiftProps.isOpen || items.length === 0} > + {nestedContextProvider && ( + <div + style={{ + backgroundColor: secondaryDark, + borderBottom: `1px solid ${lightGray}`, + display: "flex", + gap: "4px", + position: "sticky", + top: "0px", + }} + className="py-2 px-4 my-0" + > + <ArrowLeftIcon + width="1.4em" + height="1.4em" + className="cursor-pointer" + onClick={() => { + goBackToContextProviders(); + }} + /> + {nestedContextProvider.display_title} -{" "} + {nestedContextProvider.description} + </div> + )} {downshiftProps.isOpen && items.map((item, index) => ( <Li @@ -586,6 +726,12 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => { {...downshiftProps.getItemProps({ item, index })} highlighted={downshiftProps.highlightedIndex === index} selected={downshiftProps.selectedItem === item} + onClick={(e) => { + e.stopPropagation(); + e.preventDefault(); + (e.nativeEvent as any).preventDownshiftDefault = true; + downshiftProps.selectItem(item); + }} > <span> {item.name} |