diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-08-25 13:38:41 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-08-25 13:38:41 -0700 |
commit | e5f56308c5fd87695278682b2a36ca60df0db863 (patch) | |
tree | c7d66f5a3b56ce762bfd26033890597a07099007 /continuedev | |
parent | a55d64127a1e972d03f54a175b54eb0ad78e2b0e (diff) | |
download | sncontinue-e5f56308c5fd87695278682b2a36ca60df0db863.tar.gz sncontinue-e5f56308c5fd87695278682b2a36ca60df0db863.tar.bz2 sncontinue-e5f56308c5fd87695278682b2a36ca60df0db863.zip |
fix: :bug: ssh compatibility by reading from vscode.workspace.fs
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/src/continuedev/plugins/context_providers/file.py | 66 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide.py | 14 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide_protocol.py | 4 |
3 files changed, 53 insertions, 31 deletions
diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index a748379e..9846dd3e 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -1,25 +1,20 @@ import asyncio import os -from fnmatch import fnmatch from typing import List from ...core.context import ContextProvider from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from ...core.sdk import ContinueSDK from .util import remove_meilisearch_disallowed_chars -MAX_SIZE_IN_BYTES = 1024 * 1024 * 1 +MAX_SIZE_IN_CHARS = 25_000 -def get_file_contents(filepath: str) -> str: +async def get_file_contents(filepath: str, sdk: ContinueSDK) -> str: try: - filesize = os.path.getsize(filepath) - if filesize > MAX_SIZE_IN_BYTES: - return None - - with open(filepath, "r") as f: - return f.read() - except Exception: - # Some files cannot be read, e.g. binary files + return (await sdk.ide.readFile(filepath))[:MAX_SIZE_IN_CHARS] + except Exception as e: + print(f"Failed to read file {filepath}: {e}") return None @@ -105,7 +100,7 @@ class FileContextProvider(ContextProvider): async def get_context_item_for_filepath( self, absolute_filepath: str ) -> ContextItem: - content = get_file_contents(absolute_filepath) + content = await get_file_contents(absolute_filepath, self.sdk) if content is None: return None @@ -128,26 +123,35 @@ class FileContextProvider(ContextProvider): ) async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: + contents = await self.sdk.ide.listDirectoryContents(workspace_dir) + if contents is None: + return [] + absolute_filepaths: List[str] = [] - for root, dir_names, file_names in os.walk(workspace_dir): - dir_names[:] = [ - d - for d in dir_names - if not any(fnmatch(d, pattern) for pattern in self.ignore_patterns) + for filepath in contents[:1000]: + absolute_filepaths.append(filepath) + + # for root, dir_names, file_names in os.walk(workspace_dir): + # dir_names[:] = [ + # d + # for d in dir_names + # if not any(fnmatch(d, pattern) for pattern in self.ignore_patterns) + # ] + # for file_name in file_names: + # absolute_filepaths.append(os.path.join(root, file_name)) + + # if len(absolute_filepaths) > 1000: + # break + + # if len(absolute_filepaths) > 1000: + # break + + items = await asyncio.gather( + *[ + self.get_context_item_for_filepath(filepath) + for filepath in absolute_filepaths ] - for file_name in file_names: - absolute_filepaths.append(os.path.join(root, file_name)) - - if len(absolute_filepaths) > 1000: - break - - if len(absolute_filepaths) > 1000: - break - - items = [] - for absolute_filepath in absolute_filepaths: - item = await self.get_context_item_for_filepath(absolute_filepath) - if item is not None: - items.append(item) + ) + items = list(filter(lambda item: item is not None, items)) return items diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 610a1a48..871724db 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -120,6 +120,10 @@ class TerminalContentsResponse(BaseModel): contents: str +class ListDirectoryContentsResponse(BaseModel): + contents: List[str] + + T = TypeVar("T", bound=BaseModel) @@ -241,6 +245,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): "getUserSecret", "runCommand", "getTerminalContents", + "listDirectoryContents", ]: self.sub_queue.post(message_type, data) elif message_type == "workspaceDirectory": @@ -477,6 +482,15 @@ class IdeProtocolServer(AbstractIdeProtocolServer): ) return resp.fileEdit + async def listDirectoryContents(self, directory: str) -> List[str]: + """List the contents of a directory""" + resp = await self._send_and_receive_json( + {"directory": directory}, + ListDirectoryContentsResponse, + "listDirectoryContents", + ) + return resp.contents + async def applyFileSystemEdit(self, edit: FileSystemEdit) -> EditDiff: """Apply a file edit""" backward = None diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index 435c82e2..f37c1737 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -147,5 +147,9 @@ class AbstractIdeProtocolServer(ABC): def onFileSaved(self, filepath: str, contents: str): """Called when a file is saved""" + @abstractmethod + async def listDirectoryContents(self, directory: str) -> List[str]: + """List directory contents""" + workspace_directory: str unique_id: str |