diff options
Diffstat (limited to 'server/continuedev/plugins/context_providers')
13 files changed, 1121 insertions, 0 deletions
diff --git a/server/continuedev/plugins/context_providers/__init__.py b/server/continuedev/plugins/context_providers/__init__.py new file mode 100644 index 00000000..0123bb7b --- /dev/null +++ b/server/continuedev/plugins/context_providers/__init__.py @@ -0,0 +1,7 @@ +from .diff import DiffContextProvider # noqa: F401 +from .filetree import FileTreeContextProvider # noqa: F401 +from .github import GitHubIssuesContextProvider # noqa: F401 +from .google import GoogleContextProvider # noqa: F401 +from .search import SearchContextProvider # noqa: F401 +from .terminal import TerminalContextProvider # noqa: F401 +from .url import URLContextProvider # noqa: F401 diff --git a/server/continuedev/plugins/context_providers/diff.py b/server/continuedev/plugins/context_providers/diff.py new file mode 100644 index 00000000..05da3547 --- /dev/null +++ b/server/continuedev/plugins/context_providers/diff.py @@ -0,0 +1,73 @@ +import subprocess +from typing import List + +from pydantic import Field + +from ...core.context import ContextProvider +from ...core.main import ( + ContextItem, + ContextItemDescription, + ContextItemId, + ContinueCustomException, +) + + +class DiffContextProvider(ContextProvider): + """ + Type '@diff' to reference all of the changes you've made to your current branch. This is useful if you want to summarize what you've done or ask for a general review of your work before committing. 
+ """ + + title = "diff" + display_title = "Diff" + description = "Output of 'git diff' in current repo" + dynamic = True + + _DIFF_CONTEXT_ITEM_ID = "diff" + + workspace_dir: str = Field( + None, description="The workspace directory in which to run `git diff`" + ) + + @property + def BASE_CONTEXT_ITEM(self): + return ContextItem( + content="", + description=ContextItemDescription( + name="Diff", + description="Reference the output of 'git diff' for the current workspace", + id=ContextItemId( + provider_title=self.title, item_id=self._DIFF_CONTEXT_ITEM_ID + ), + ), + ) + + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: + self.workspace_dir = workspace_dir + return [self.BASE_CONTEXT_ITEM] + + async def get_item(self, id: ContextItemId, query: str) -> ContextItem: + if not id.provider_title == self.title: + raise Exception("Invalid provider title for item") + + result = subprocess.run( + ["git", "diff"], cwd=self.workspace_dir, capture_output=True, text=True + ) + diff = result.stdout + error = result.stderr + if error.strip() != "": + if error.startswith("warning: Not a git repository"): + raise ContinueCustomException( + title="Not a git repository", + message="The @diff context provider only works in git repositories.", + ) + raise ContinueCustomException( + title="Error running git diff", + message=f"Error running git diff:\n\n{error}", + ) + + if diff.strip() == "": + diff = "No changes" + + ctx_item = self.BASE_CONTEXT_ITEM.copy() + ctx_item.content = diff + return ctx_item diff --git a/server/continuedev/plugins/context_providers/dynamic.py b/server/continuedev/plugins/context_providers/dynamic.py new file mode 100644 index 00000000..50567621 --- /dev/null +++ b/server/continuedev/plugins/context_providers/dynamic.py @@ -0,0 +1,75 @@ +from abc import ABC, abstractmethod +from typing import List + +from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from 
...libs.util.create_async_task import create_async_task +from .util import remove_meilisearch_disallowed_chars + + +class DynamicProvider(ContextProvider, ABC): + """ + A title representing the provider + """ + + title: str + """A name representing the provider. Probably use capitalized version of title""" + + name: str + + workspace_dir: str = None + dynamic: bool = True + + @property + def BASE_CONTEXT_ITEM(self): + return ContextItem( + content="", + description=ContextItemDescription( + name=self.name, + description=self.description, + id=ContextItemId(provider_title=self.title, item_id=self.title), + ), + ) + + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: + self.workspace_dir = workspace_dir + create_async_task(self.setup()) + return [self.BASE_CONTEXT_ITEM] + + async def get_item(self, id: ContextItemId, query: str) -> ContextItem: + if not id.provider_title == self.title: + raise Exception("Invalid provider title for item") + + query = query.lstrip(self.title + " ") + results = await self.get_content(query) + + ctx_item = self.BASE_CONTEXT_ITEM.copy() + ctx_item.content = results + ctx_item.description.name = f"{self.name}: '{query}'" + ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars(query) + return ctx_item + + @abstractmethod + async def get_content(self, query: str) -> str: + """Retrieve the content given the query + (e.g. search the codebase, return search results)""" + raise NotImplementedError + + @abstractmethod + async def setup(self): + """Run any setup needed (e.g. 
indexing the codebase)""" + raise NotImplementedError + + +""" +class ExampleDynamicProvider(DynamicProvider): + title = "example" + name = "Example" + description = "Example description" + + async def get_content(self, query: str) -> str: + return f"Example content for '{query}'" + + async def setup(self): + print("Example setup") +""" diff --git a/server/continuedev/plugins/context_providers/embeddings.py b/server/continuedev/plugins/context_providers/embeddings.py new file mode 100644 index 00000000..86cba311 --- /dev/null +++ b/server/continuedev/plugins/context_providers/embeddings.py @@ -0,0 +1,81 @@ +import os +import uuid +from typing import List, Optional + +from pydantic import BaseModel + +from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from ...libs.chroma.query import ChromaIndexManager + + +class EmbeddingResult(BaseModel): + filename: str + content: str + + +class EmbeddingsProvider(ContextProvider): + title = "embed" + + display_title = "Embeddings Search" + description = "Search the codebase using embeddings" + dynamic = True + requires_query = True + + workspace_directory: str + + EMBEDDINGS_CONTEXT_ITEM_ID = "embeddings" + + index_manager: Optional[ChromaIndexManager] = None + + class Config: + arbitrary_types_allowed = True + + @property + def index(self): + if self.index_manager is None: + self.index_manager = ChromaIndexManager(self.workspace_directory) + return self.index_manager + + @property + def BASE_CONTEXT_ITEM(self): + return ContextItem( + content="", + description=ContextItemDescription( + name="Embedding Search", + description="Enter a query to embedding search codebase", + id=ContextItemId( + provider_title=self.title, item_id=self.EMBEDDINGS_CONTEXT_ITEM_ID + ), + ), + ) + + async def _get_query_results(self, query: str) -> str: + results = self.index.query_codebase_index(query) + + ret = [] + for node in results.source_nodes: + resource_name = 
list(node.node.relationships.values())[0] + filepath = resource_name[: resource_name.index("::")] + ret.append(EmbeddingResult(filename=filepath, content=node.node.text)) + + return ret + + async def provide_context_items(self) -> List[ContextItem]: + self.index.create_codebase_index() # TODO Synchronous here is not ideal + + return [self.BASE_CONTEXT_ITEM] + + async def add_context_item(self, id: ContextItemId, query: str): + if not id.provider_title == self.title: + raise Exception("Invalid provider title for item") + + results = await self._get_query_results(query) + + for i in range(len(results)): + result = results[i] + ctx_item = self.BASE_CONTEXT_ITEM.copy() + ctx_item.description.name = os.path.basename(result.filename) + ctx_item.content = f"{result.filename}\n```\n{result.content}\n```" + ctx_item.description.id.item_id = uuid.uuid4().hex + self.selected_items.append(ctx_item) diff --git a/server/continuedev/plugins/context_providers/file.py b/server/continuedev/plugins/context_providers/file.py new file mode 100644 index 00000000..4cfbcfdb --- /dev/null +++ b/server/continuedev/plugins/context_providers/file.py @@ -0,0 +1,136 @@ +import asyncio +import os +from typing import List, Optional + +from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from ...core.sdk import ContinueSDK +from ...libs.util.filter_files import DEFAULT_IGNORE_PATTERNS +from ...libs.util.logging import logger +from .util import remove_meilisearch_disallowed_chars + +MAX_SIZE_IN_CHARS = 50_000 + + +async def get_file_contents(filepath: str, sdk: ContinueSDK) -> str: + try: + return (await sdk.ide.readFile(filepath))[:MAX_SIZE_IN_CHARS] + except Exception as _: + return None + + +class FileContextProvider(ContextProvider): + """ + The FileContextProvider is a ContextProvider that allows you to search files in the open workspace. 
+ """ + + title = "file" + ignore_patterns: List[str] = DEFAULT_IGNORE_PATTERNS + + display_title = "Files" + description = "Reference files in the current workspace" + dynamic = False + + async def start(self, *args): + await super().start(*args) + + async def on_file_saved(filepath: str, contents: str): + item = await self.get_context_item_for_filepath(filepath) + if item is None: + return + await self.update_documents([item], self.sdk.ide.workspace_directory) + + async def on_files_created(filepaths: List[str]): + items = await asyncio.gather( + *[ + self.get_context_item_for_filepath(filepath) + for filepath in filepaths + ] + ) + items = [item for item in items if item is not None] + await self.update_documents(items, self.sdk.ide.workspace_directory) + + async def on_files_deleted(filepaths: List[str]): + ids = [self.get_id_for_filepath(filepath) for filepath in filepaths] + + await self.delete_documents(ids) + + async def on_files_renamed(old_filepaths: List[str], new_filepaths: List[str]): + if self.sdk.ide.workspace_directory is None: + return + + old_ids = [self.get_id_for_filepath(filepath) for filepath in old_filepaths] + new_docs = await asyncio.gather( + *[ + self.get_context_item_for_filepath(filepath) + for filepath in new_filepaths + ] + ) + new_docs = [doc for doc in new_docs if doc is not None] + + await self.delete_documents(old_ids) + await self.update_documents(new_docs, self.sdk.ide.workspace_directory) + + self.sdk.ide.subscribeToFileSaved(on_file_saved) + self.sdk.ide.subscribeToFilesCreated(on_files_created) + self.sdk.ide.subscribeToFilesDeleted(on_files_deleted) + self.sdk.ide.subscribeToFilesRenamed(on_files_renamed) + + def get_id_for_filepath(self, absolute_filepath: str) -> str: + return remove_meilisearch_disallowed_chars(absolute_filepath) + + async def get_context_item_for_filepath( + self, absolute_filepath: str + ) -> Optional[ContextItem]: + content = await get_file_contents(absolute_filepath, self.sdk) + if content is None: + 
return None + + workspace_dir = self.sdk.ide.workspace_directory + if ( + os.path.splitdrive(workspace_dir)[0] + != os.path.splitdrive(absolute_filepath)[0] + ): + workspace_dir = ( + os.path.splitdrive(absolute_filepath)[0] + + os.path.splitdrive(workspace_dir)[1] + ) + + try: + relative_to_workspace = os.path.relpath(absolute_filepath, workspace_dir) + except Exception as e: + logger.warning(f"Error getting relative path: {e}") + return None + + return ContextItem( + content=content[: min(2000, len(content))], + description=ContextItemDescription( + name=os.path.basename(absolute_filepath), + # We should add the full path to the ContextItem + # It warrants a data modeling discussion and has no immediate use case + description=relative_to_workspace, + id=ContextItemId( + provider_title=self.title, + item_id=self.get_id_for_filepath(absolute_filepath), + ), + ), + ) + + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: + contents = await self.sdk.ide.listDirectoryContents(workspace_dir, True) + if contents is None: + return [] + + absolute_filepaths: List[str] = [] + for filepath in contents[:1000]: + absolute_filepaths.append(filepath) + + items = await asyncio.gather( + *[ + self.get_context_item_for_filepath(filepath) + for filepath in absolute_filepaths + ] + ) + items = list(filter(lambda item: item is not None, items)) + + return items diff --git a/server/continuedev/plugins/context_providers/filetree.py b/server/continuedev/plugins/context_providers/filetree.py new file mode 100644 index 00000000..5b3d3a50 --- /dev/null +++ b/server/continuedev/plugins/context_providers/filetree.py @@ -0,0 +1,89 @@ +from typing import List + +from pydantic import BaseModel, Field + +from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId + + +class Directory(BaseModel): + name: str + files: List[str] + directories: List["Directory"] + + +def format_file_tree(tree: Directory, 
indentation: str = "") -> str: + result = "" + for file in tree.files: + result += f"{indentation}{file}\n" + + for directory in tree.directories: + result += f"{indentation}{directory.name}/\n" + result += format_file_tree(directory, indentation + " ") + + return result + + +def split_path(path: str, with_root=None) -> List[str]: + parts = path.split("/") if "/" in path else path.split("\\") + if with_root is not None: + root_parts = split_path(with_root) + parts = parts[len(root_parts) - 1 :] + + return parts + + +class FileTreeContextProvider(ContextProvider): + """Type '@tree' to reference the contents of your current workspace. The LLM will be able to see the nested directory structure of your project.""" + + title = "tree" + display_title = "File Tree" + description = "Add a formatted file tree of this directory to the context" + dynamic = True + + workspace_dir: str = Field(None, description="The workspace directory to display") + + async def _get_file_tree(self, directory: str) -> str: + contents = await self.sdk.ide.listDirectoryContents(directory, recursive=True) + + tree = Directory( + name=split_path(self.workspace_dir)[-1], files=[], directories=[] + ) + + for file in contents: + parts = split_path(file, with_root=self.workspace_dir) + + current_tree = tree + for part in parts[:-1]: + if part not in [d.name for d in current_tree.directories]: + current_tree.directories.append( + Directory(name=part, files=[], directories=[]) + ) + + current_tree = [d for d in current_tree.directories if d.name == part][ + 0 + ] + + current_tree.files.append(parts[-1]) + + return format_file_tree(tree) + + async def _filetree_context_item(self): + return ContextItem( + content=await self._get_file_tree(self.workspace_dir), + description=ContextItemDescription( + name="File Tree", + description="Add a formatted file tree of this directory to the context", + id=ContextItemId(provider_title=self.title, item_id=self.title), + ), + ) + + async def provide_context_items(self, 
workspace_dir: str) -> List[ContextItem]: + self.workspace_dir = workspace_dir + return [await self._filetree_context_item()] + + async def get_item(self, id: ContextItemId, query: str) -> ContextItem: + if not id.provider_title == self.title: + raise Exception("Invalid provider title for item") + + return await self._filetree_context_item() diff --git a/server/continuedev/plugins/context_providers/github.py b/server/continuedev/plugins/context_providers/github.py new file mode 100644 index 00000000..c031f310 --- /dev/null +++ b/server/continuedev/plugins/context_providers/github.py @@ -0,0 +1,49 @@ +from typing import List + +from github import Auth, Github +from pydantic import Field + +from ...core.context import ( + ContextItem, + ContextItemDescription, + ContextItemId, + ContextProvider, +) + + +class GitHubIssuesContextProvider(ContextProvider): + """ + The GitHubIssuesContextProvider is a ContextProvider that allows you to search GitHub issues in a repo. Type '@issue' to reference the title and contents of an issue. 
+ """ + + title = "issues" + repo_name: str = Field( + ..., description="The name of the GitHub repo from which to pull issues" + ) + auth_token: str = Field( + ..., + description="The GitHub auth token to use to authenticate with the GitHub API", + ) + + display_title = "GitHub Issues" + description = "Reference GitHub issues" + dynamic = False + + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: + auth = Auth.Token(self.auth_token) + gh = Github(auth=auth) + + repo = gh.get_repo(self.repo_name) + issues = repo.get_issues().get_page(0) + + return [ + ContextItem( + content=issue.body, + description=ContextItemDescription( + name=f"Issue #{issue.number}", + description=issue.title, + id=ContextItemId(provider_title=self.title, item_id=issue.id), + ), + ) + for issue in issues + ] diff --git a/server/continuedev/plugins/context_providers/google.py b/server/continuedev/plugins/context_providers/google.py new file mode 100644 index 00000000..852f4e9a --- /dev/null +++ b/server/continuedev/plugins/context_providers/google.py @@ -0,0 +1,70 @@ +import json +from typing import List + +import aiohttp +from pydantic import Field + +from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from .util import remove_meilisearch_disallowed_chars + + +class GoogleContextProvider(ContextProvider): + """Type '@google' to reference the results of a Google search. For example, type "@google python tutorial" if you want to search and discuss ways of learning Python.""" + + title = "google" + display_title = "Google" + description = "Search Google" + dynamic = True + requires_query = True + + serper_api_key: str = Field( + ..., + description="Your SerpAPI key, used to programmatically make Google searches. 
You can get a key at https://serper.dev.", + ) + + _GOOGLE_CONTEXT_ITEM_ID = "google_search" + + @property + def BASE_CONTEXT_ITEM(self): + return ContextItem( + content="", + description=ContextItemDescription( + name="Google Search", + description="Enter a query to search google", + id=ContextItemId( + provider_title=self.title, item_id=self._GOOGLE_CONTEXT_ITEM_ID + ), + ), + ) + + async def _google_search(self, query: str) -> str: + url = "https://google.serper.dev/search" + + payload = json.dumps({"q": query}) + headers = {"X-API-KEY": self.serper_api_key, "Content-Type": "application/json"} + + async with aiohttp.ClientSession() as session: + async with session.post(url, headers=headers, data=payload) as response: + return await response.text() + + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: + return [self.BASE_CONTEXT_ITEM] + + async def get_item(self, id: ContextItemId, query: str) -> ContextItem: + if not id.provider_title == self.title: + raise Exception("Invalid provider title for item") + + results = await self._google_search(query) + json_results = json.loads(results) + content = f"Google Search: {query}\n\n" + if answerBox := json_results.get("answerBox"): + content += f"Answer Box ({answerBox['title']}): {answerBox['answer']}\n\n" + + for result in json_results["organic"]: + content += f"{result['title']}\n{result['link']}\n{result['snippet']}\n\n" + + ctx_item = self.BASE_CONTEXT_ITEM.copy() + ctx_item.content = content + ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars(query) + return ctx_item diff --git a/server/continuedev/plugins/context_providers/highlighted_code.py b/server/continuedev/plugins/context_providers/highlighted_code.py new file mode 100644 index 00000000..3304a71d --- /dev/null +++ b/server/continuedev/plugins/context_providers/highlighted_code.py @@ -0,0 +1,293 @@ +import os +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel + +from 
...core.context import ( + ContextItem, + ContextItemDescription, + ContextItemId, + ContextProvider, +) +from ...core.main import ChatMessage +from ...models.filesystem import RangeInFileWithContents +from ...models.main import Range + + +class HighlightedRangeContextItem(BaseModel): + rif: RangeInFileWithContents + item: ContextItem + + +class HighlightedCodeContextProvider(ContextProvider): + """ + The ContextProvider class is a plugin that lets you provide new information to the LLM by typing '@'. + When you type '@', the context provider will be asked to populate a list of options. + These options will be updated on each keystroke. + When you hit enter on an option, the context provider will add that item to the autopilot's list of context (which is all stored in the ContextManager object). + """ + + title = "code" + display_title = "Highlighted Code" + description = "Highlight code" + dynamic = True + + ide: Any # IdeProtocolServer + + highlighted_ranges: List[HighlightedRangeContextItem] = [] + adding_highlighted_code: bool = True + # Controls whether you can have more than one highlighted range. Now always True. + + should_get_fallback_context_item: bool = True + last_added_fallback: bool = False + + async def _get_fallback_context_item(self) -> HighlightedRangeContextItem: + # Used to automatically include the currently open file. Disabled for now. 
+ return None + + if not self.should_get_fallback_context_item: + return None + + visible_files = await self.ide.getVisibleFiles() + if len(visible_files) > 0: + content = await self.ide.readFile(visible_files[0]) + rif = RangeInFileWithContents.from_entire_file(visible_files[0], content) + + item = self._rif_to_context_item(rif, 0, True) + item.description.name = self._rif_to_name(rif, show_line_nums=False) + + self.last_added_fallback = True + return HighlightedRangeContextItem(rif=rif, item=item) + + return None + + async def get_selected_items(self) -> List[ContextItem]: + items = [hr.item for hr in self.highlighted_ranges] + + if len(items) == 0 and ( + fallback_item := await self._get_fallback_context_item() + ): + items = [fallback_item.item] + + return items + + async def get_chat_messages(self) -> List[ContextItem]: + ranges = self.highlighted_ranges + if len(ranges) == 0 and ( + fallback_item := await self._get_fallback_context_item() + ): + ranges = [fallback_item] + + return [ + ChatMessage( + role="user", + content=f"Code in this file is highlighted ({r.rif.filepath}):\n```\n{r.rif.contents}\n```", + summary=f"Code in this file is highlighted: {r.rif.filepath}", + ) + for r in ranges + ] + + def _make_sure_is_editing_range(self): + """If none of the highlighted ranges are currently being edited, the first should be selected""" + if len(self.highlighted_ranges) == 0: + return + if not any(map(lambda x: x.item.editing, self.highlighted_ranges)): + self.highlighted_ranges[0].item.editing = True + + def _disambiguate_highlighted_ranges(self): + """If any files have the same name, also display their folder name""" + name_status: Dict[ + str, set + ] = {} # basename -> set of full paths with that basename + for hr in self.highlighted_ranges: + basename = os.path.basename(hr.rif.filepath) + if basename in name_status: + name_status[basename].add(hr.rif.filepath) + else: + name_status[basename] = {hr.rif.filepath} + + for hr in self.highlighted_ranges: + 
basename = os.path.basename(hr.rif.filepath) + if len(name_status[basename]) > 1: + hr.item.description.name = self._rif_to_name( + hr.rif, + display_filename=os.path.join( + os.path.basename(os.path.dirname(hr.rif.filepath)), basename + ), + ) + else: + hr.item.description.name = self._rif_to_name( + hr.rif, display_filename=basename + ) + + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: + return [] + + async def get_item(self, id: ContextItemId, query: str) -> ContextItem: + raise NotImplementedError() + + async def clear_context(self): + self.highlighted_ranges = [] + self.adding_highlighted_code = False + self.should_get_fallback_context_item = True + self.last_added_fallback = False + + async def delete_context_with_ids( + self, ids: List[ContextItemId] + ) -> List[ContextItem]: + ids_to_delete = [id.item_id for id in ids] + + kept_ranges = [] + for hr in self.highlighted_ranges: + if hr.item.description.id.item_id not in ids_to_delete: + kept_ranges.append(hr) + self.highlighted_ranges = kept_ranges + + self._make_sure_is_editing_range() + + if len(self.highlighted_ranges) == 0 and self.last_added_fallback: + self.should_get_fallback_context_item = False + + return [hr.item for hr in self.highlighted_ranges] + + def _rif_to_name( + self, + rif: RangeInFileWithContents, + display_filename: str = None, + show_line_nums: bool = True, + ) -> str: + line_nums = ( + f" ({rif.range.start.line + 1}-{rif.range.end.line + 1})" + if show_line_nums + else "" + ) + return f"{display_filename or os.path.basename(rif.filepath)}{line_nums}" + + def _rif_to_context_item( + self, rif: RangeInFileWithContents, idx: int, editing: bool + ) -> ContextItem: + return ContextItem( + description=ContextItemDescription( + name=self._rif_to_name(rif), + description=rif.filepath, + id=ContextItemId(provider_title=self.title, item_id=str(idx)), + ), + content=rif.contents, + editing=editing if editing is not None else False, + editable=True, + ) + + async 
# NOTE: these are methods of the highlighted-code context provider; its
# ``class`` statement lies before the visible chunk.


async def handle_highlighted_code(
    self,
    range_in_files: List[RangeInFileWithContents],
    edit: Optional[bool] = False,
) -> None:
    """Update ``self.highlighted_ranges`` from the ranges currently highlighted.

    Args:
        range_in_files: The highlighted ranges reported for this update.
        edit: Whether the new ranges should be marked as being edited.

    When ``self.adding_highlighted_code`` is false the stored ranges are
    replaced (or cleared); otherwise the new ranges are merged into or
    appended to the existing ones.
    """
    self.should_get_fallback_context_item = True
    self.last_added_fallback = False

    # Filter out rifs from ~/.continue/diffs folder
    range_in_files = [
        rif
        for rif in range_in_files
        if not os.path.dirname(rif.filepath)
        == os.path.expanduser("~/.continue/diffs")
    ]

    # If not adding highlighted code
    if not self.adding_highlighted_code:
        if (
            len(self.highlighted_ranges) == 1
            and len(range_in_files) <= 1
            and (
                len(range_in_files) == 0
                or range_in_files[0].range.start == range_in_files[0].range.end
            )
        ):
            # If un-highlighting the range to edit, then remove the range
            # (an empty selection, i.e. start == end, counts as un-highlighting)
            self.highlighted_ranges = []
        elif len(range_in_files) > 0:
            # Otherwise, replace the current range with the new one
            # This is the first range to be highlighted
            # Only the first reported range is kept; it always gets item id "0".
            self.highlighted_ranges = [
                HighlightedRangeContextItem(
                    rif=range_in_files[0],
                    item=self._rif_to_context_item(range_in_files[0], 0, edit),
                )
            ]

        # In this mode the merge logic below never runs.
        return

    # If editing, make sure none of the other ranges are editing
    if edit:
        for hr in self.highlighted_ranges:
            hr.item.editing = False

    # If new range overlaps with any existing, keep the existing but merged
    new_ranges = []
    for i, new_hr in enumerate(range_in_files):
        found_overlap_with = None
        for existing_rif in self.highlighted_ranges:
            if (
                new_hr.filepath == existing_rif.rif.filepath
                and new_hr.range.overlaps_with(existing_rif.rif.range)
            ):
                existing_rif.rif.range = existing_rif.rif.range.merge_with(
                    new_hr.range
                )
                found_overlap_with = existing_rif
                # Only the first overlapping stored range is merged.
                break

        if found_overlap_with is None:
            # NOTE(review): the id is "stored count + position in this batch";
            # this looks like it could repeat an existing id after deletions.
            # Confirm.
            new_ranges.append(
                HighlightedRangeContextItem(
                    rif=new_hr,
                    item=self._rif_to_context_item(
                        new_hr, len(self.highlighted_ranges) + i, edit
                    ),
                )
            )
        elif edit:
            # Want to update the range so it's only the newly selected portion
            # (this overrides the merged range assigned in the loop above)
            found_overlap_with.rif.range = new_hr.range
            found_overlap_with.item.editing = True

    self.highlighted_ranges = self.highlighted_ranges + new_ranges

    self._make_sure_is_editing_range()
    self._disambiguate_highlighted_ranges()


async def set_editing_at_ids(self, ids: List[str]) -> None:
    """Mark exactly the ranges whose item id is in ``ids`` as being edited.

    Does nothing if none of ``ids`` matches a stored range, so the current
    editing state is never wiped by an unrelated id list.
    """
    # Don't do anything if there are no valid ids here
    count = 0
    for hr in self.highlighted_ranges:
        if hr.item.description.id.item_id in ids:
            count += 1

    if count == 0:
        return

    # Ranges not listed in ``ids`` are switched to not-editing.
    for hr in self.highlighted_ranges:
        hr.item.editing = hr.item.description.id.item_id in ids


async def add_context_item(
    self, id: ContextItemId, query: str, prev: List[ContextItem] = None
) -> List[ContextItem]:
    """Unsupported for this provider; always raises ``NotImplementedError``."""
    raise NotImplementedError()


async def manually_add_context_item(self, context_item: ContextItem) -> None:
    """Append an existing context item to the highlighted ranges.

    The item's ``description.description`` is used as the file path: the file
    is read through ``self.ide`` and the item's content is located in it to
    build the range.
    """
    full_file_content = await self.ide.readFile(
        context_item.description.description
    )
    self.highlighted_ranges.append(
        HighlightedRangeContextItem(
            rif=RangeInFileWithContents(
                filepath=context_item.description.description,
                # presumably finds the line range the snippet occupies in
                # the file (verify against Range.from_lines_snippet_in_file)
                range=Range.from_lines_snippet_in_file(
                    content=full_file_content,
                    snippet=context_item.content,
                ),
                contents=context_item.content,
            ),
            item=context_item,
        )
    )
class SearchContextProvider(ContextProvider):
    """Type '@search' to reference the results of codebase search, just like the results you would get from VS Code search."""

    title = "search"
    display_title = "Search"
    description = "Search the workspace for all matches of an exact string (e.g. '@search console.log')"
    dynamic = True
    requires_query = True

    # There is a single search item template, so it has a fixed id.
    _SEARCH_CONTEXT_ITEM_ID = "search"

    workspace_dir: str = Field(None, description="The workspace directory to search")

    @property
    def BASE_CONTEXT_ITEM(self):
        """Empty template item; ``get_item`` copies it and fills it in."""
        return ContextItem(
            content="",
            description=ContextItemDescription(
                name="Search",
                description="Search the workspace for all matches of an exact string (e.g. '@search console.log')",
                id=ContextItemId(
                    provider_title=self.title, item_id=self._SEARCH_CONTEXT_ITEM_ID
                ),
            ),
        )

    async def _search(self, query: str) -> str:
        """Run ripgrep for ``query`` in the workspace and return its output.

        Args:
            query: The string handed to ripgrep.

        Returns:
            A heading naming the query, followed by the raw ripgrep output.
        """
        rg = Ripgrepy(query, self.workspace_dir, rg_path=get_rg_path())
        # presumably ripgrep's -I flag plus two lines of context per match
        # (confirm against the ripgrepy API)
        results = rg.I().context(2).run()
        # TODO: custom display that groups matches per file, one
        # "<line number>: <line>" entry per match. The sketch that used to
        # follow this return could never run and has been removed.
        return f"Search results in workspace for '{query}':\n\n{results}"

    async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
        """Remember the workspace and offer the search template item.

        Returns:
            ``[BASE_CONTEXT_ITEM]``, or ``[]`` when ripgrepy cannot be set up.
        """
        self.workspace_dir = workspace_dir

        # Best-effort probe: if ripgrep is unusable, hide the provider rather
        # than fail later on every query.
        try:
            Ripgrepy("", workspace_dir, rg_path=get_rg_path())
        except Exception as e:
            logger.warning(f"Failed to initialize ripgrepy: {e}")
            return []

        return [self.BASE_CONTEXT_ITEM]

    async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
        """Search the workspace for ``query`` and return the results as an item.

        Args:
            id: Must carry this provider's title.
            query: The search string, optionally prefixed with ``"search "``.

        Raises:
            Exception: if ``id`` belongs to a different provider.
        """
        if not id.provider_title == self.title:
            raise Exception("Invalid provider title for item")

        # Drop an optional leading "search " exactly once.
        # ``str.lstrip("search ")`` is wrong here: it strips any run of the
        # characters s/e/a/r/c/h/space, so "cache.get" became ".get" and
        # "each" became "".
        query = query.lstrip(" ")
        prefix = f"{self.title} "
        if query == self.title:
            # A bare "search" carries no query (same result as before).
            query = ""
        elif query.startswith(prefix):
            query = query[len(prefix) :].lstrip(" ")

        results = await self._search(query)

        ctx_item = self.BASE_CONTEXT_ITEM.copy()
        ctx_item.content = results
        ctx_item.description.name = f"Search: '{query}'"
        ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars(query)
        return ctx_item
class TerminalContextProvider(ContextProvider):
    """Type '@terminal' to reference the contents of your IDE's terminal."""

    title = "terminal"
    display_title = "Terminal"
    description = "Reference the contents of the terminal"
    dynamic = True

    get_last_n_commands: int = Field(
        3, description="The number of previous commands to reference"
    )

    def _terminal_context_item(self, content: str = "") -> ContextItem:
        """Build the single terminal item, optionally filled with ``content``."""
        return ContextItem(
            content=content,
            description=ContextItemDescription(
                name="Terminal",
                description="Reference the contents of the VS Code terminal",
                # One item only, so the provider title doubles as the item id.
                id=ContextItemId(provider_title=self.title, item_id=self.title),
            ),
        )

    async def get_chat_messages(self) -> List[ChatMessage]:
        """Return the parent's chat messages with a shortened summary.

        The annotation names the awaited value: an ``async def`` already
        returns a coroutine, so spelling ``Coroutine[...]`` here described a
        coroutine that yields another coroutine.
        """
        msgs = await super().get_chat_messages()
        for msg in msgs:
            # Summarise with the last 1000 characters of the content.
            msg.summary = msg.content[-1000:]
        return msgs

    async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
        """Offer the (still empty) terminal item."""
        return [self._terminal_context_item()]

    async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
        """Fetch the terminal contents from the IDE and return them as an item.

        Args:
            id: Must carry this provider's title.
            query: Unused.

        Raises:
            Exception: if ``id`` belongs to a different provider.
        """
        if not id.provider_title == self.title:
            raise Exception("Invalid provider title for item")

        terminal_contents = await self.sdk.ide.getTerminalContents(
            self.get_last_n_commands
        )
        # Keep only the last 5000 characters.
        terminal_contents = terminal_contents[-5000:]

        return self._terminal_context_item(terminal_contents)
class URLContextProvider(ContextProvider):
    """Type '@url' to reference the contents of a URL. You can either reference preset URLs, or reference one dynamically by typing '@url https://example.com'. The text contents of the page will be fetched and used as context."""

    title = "url"
    display_title = "URL"
    description = "Reference the contents of a webpage"
    dynamic = True
    requires_query = True

    # Allows users to provide a list of preset urls
    preset_urls: List[str] = Field(
        [],
        description="A list of preset URLs that you will be able to quickly reference by typing '@url'",
    )

    # Static items loaded from preset_urls
    static_url_context_items: List[ContextItem] = []

    # There is only a single dynamic url context item, so it has a static id
    _DYNAMIC_URL_CONTEXT_ITEM_ID = "url"

    # This is a template dynamic item that will generate context item on demand
    # when get item is called
    @property
    def DYNAMIC_CONTEXT_ITEM(self):
        """Empty template item; ``get_item`` copies it and fills it in."""
        return ContextItem(
            content="",
            description=ContextItemDescription(
                name="Dynamic URL",
                description="Reference the contents of a webpage (e.g. '@url https://www.w3schools.com/python/python_ref_functions.asp')",
                id=ContextItemId(
                    provider_title=self.title, item_id=self._DYNAMIC_URL_CONTEXT_ITEM_ID
                ),
            ),
        )

    def static_url_context_item_from_url(self, url: str) -> ContextItem:
        """Fetch ``url`` and wrap its text in a context item.

        The item id is the URL with disallowed characters removed.
        """
        content, title = self._get_url_text_contents_and_title(url)
        return ContextItem(
            content=content,
            description=ContextItemDescription(
                name=title,
                description=f"Contents of {url}",
                id=ContextItemId(
                    provider_title=self.title,
                    item_id=remove_meilisearch_disallowed_chars(url),
                ),
            ),
        )

    def _get_url_text_contents_and_title(self, url: str) -> (str, str):
        """Download ``url`` and return ``(page text, title)``.

        The title is the page's ``<title>`` text when it has one, otherwise
        the URL without its scheme and ``www.``.
        """
        # NOTE(review): this is a blocking call with no timeout, and it is
        # reached from async methods. Consider a timeout.
        response = requests.get(url)
        soup = BeautifulSoup(response.text, "html.parser")
        title = url.replace("https://", "").replace("http://", "").replace("www.", "")
        # ``.string`` is None for an empty <title> or one with nested tags;
        # keep the URL-based fallback rather than overwrite it with None.
        if soup.title is not None and soup.title.string:
            title = soup.title.string
        return soup.get_text(), title

    async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
        """Fetch every preset URL and offer those items plus the dynamic one."""
        self.static_url_context_items = [
            self.static_url_context_item_from_url(url) for url in self.preset_urls
        ]

        return [self.DYNAMIC_CONTEXT_ITEM] + self.static_url_context_items

    async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
        """Return the preset item matching ``id``, or fetch the URL in ``query``.

        Args:
            id: Id of a preset item, or any id carrying this provider's title.
            query: The URL to fetch, optionally prefixed with ``"url "``.

        Returns:
            The context item, or ``None`` when ``query`` holds no URL.

        Raises:
            Exception: if no preset item matches and ``id`` belongs to a
                different provider.
        """
        # Check if the item is a static item
        matching_static_item = next(
            (
                item
                for item in self.static_url_context_items
                if item.description.id.item_id == id.item_id
            ),
            None,
        )
        if matching_static_item:
            return matching_static_item

        # Check if the item is the dynamic item
        if not id.provider_title == self.title:
            raise Exception("Invalid provider title for item")

        # Generate the dynamic item.
        # Drop an optional leading "url " exactly once.
        # ``str.lstrip("url ")`` is wrong here: it strips any run of the
        # characters u/r/l/space, so "url rust-lang.org" became "st-lang.org".
        url = query.strip()
        prefix = f"{self.title} "
        if url == self.title:
            # A bare "url" carries no address (same result as before).
            url = ""
        elif url.startswith(prefix):
            url = url[len(prefix) :].strip()
        if not url:
            return None
        content, title = self._get_url_text_contents_and_title(url)

        ctx_item = self.DYNAMIC_CONTEXT_ITEM.copy()
        ctx_item.content = content
        ctx_item.description.name = title
        ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars(url)
        return ctx_item
import re


def remove_meilisearch_disallowed_chars(id: str) -> str:
    """Return ``id`` reduced to ASCII letters, digits, ``_`` and ``-``.

    Every other character is dropped, not replaced.
    """
    disallowed = re.compile(r"[^0-9a-zA-Z_-]")
    cleaned = disallowed.sub("", id)
    return cleaned