summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-08-25 13:38:41 -0700
committerNate Sesti <sestinj@gmail.com>2023-08-25 13:38:41 -0700
commite5f56308c5fd87695278682b2a36ca60df0db863 (patch)
treec7d66f5a3b56ce762bfd26033890597a07099007 /continuedev
parenta55d64127a1e972d03f54a175b54eb0ad78e2b0e (diff)
downloadsncontinue-e5f56308c5fd87695278682b2a36ca60df0db863.tar.gz
sncontinue-e5f56308c5fd87695278682b2a36ca60df0db863.tar.bz2
sncontinue-e5f56308c5fd87695278682b2a36ca60df0db863.zip
fix: :bug: ssh compatibility by reading from vscode.workspace.fs
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/file.py66
-rw-r--r--continuedev/src/continuedev/server/ide.py14
-rw-r--r--continuedev/src/continuedev/server/ide_protocol.py4
3 files changed, 53 insertions, 31 deletions
diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py
index a748379e..9846dd3e 100644
--- a/continuedev/src/continuedev/plugins/context_providers/file.py
+++ b/continuedev/src/continuedev/plugins/context_providers/file.py
@@ -1,25 +1,20 @@
import asyncio
import os
-from fnmatch import fnmatch
from typing import List
from ...core.context import ContextProvider
from ...core.main import ContextItem, ContextItemDescription, ContextItemId
+from ...core.sdk import ContinueSDK
from .util import remove_meilisearch_disallowed_chars
-MAX_SIZE_IN_BYTES = 1024 * 1024 * 1
+MAX_SIZE_IN_CHARS = 25_000
-def get_file_contents(filepath: str) -> str:
+async def get_file_contents(filepath: str, sdk: ContinueSDK) -> str:
try:
- filesize = os.path.getsize(filepath)
- if filesize > MAX_SIZE_IN_BYTES:
- return None
-
- with open(filepath, "r") as f:
- return f.read()
- except Exception:
- # Some files cannot be read, e.g. binary files
+ return (await sdk.ide.readFile(filepath))[:MAX_SIZE_IN_CHARS]
+ except Exception as e:
+ print(f"Failed to read file {filepath}: {e}")
return None
@@ -105,7 +100,7 @@ class FileContextProvider(ContextProvider):
async def get_context_item_for_filepath(
self, absolute_filepath: str
) -> ContextItem:
- content = get_file_contents(absolute_filepath)
+ content = await get_file_contents(absolute_filepath, self.sdk)
if content is None:
return None
@@ -128,26 +123,35 @@ class FileContextProvider(ContextProvider):
)
async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
+ contents = await self.sdk.ide.listDirectoryContents(workspace_dir)
+ if contents is None:
+ return []
+
absolute_filepaths: List[str] = []
- for root, dir_names, file_names in os.walk(workspace_dir):
- dir_names[:] = [
- d
- for d in dir_names
- if not any(fnmatch(d, pattern) for pattern in self.ignore_patterns)
+ for filepath in contents[:1000]:
+ absolute_filepaths.append(filepath)
+
+ # for root, dir_names, file_names in os.walk(workspace_dir):
+ # dir_names[:] = [
+ # d
+ # for d in dir_names
+ # if not any(fnmatch(d, pattern) for pattern in self.ignore_patterns)
+ # ]
+ # for file_name in file_names:
+ # absolute_filepaths.append(os.path.join(root, file_name))
+
+ # if len(absolute_filepaths) > 1000:
+ # break
+
+ # if len(absolute_filepaths) > 1000:
+ # break
+
+ items = await asyncio.gather(
+ *[
+ self.get_context_item_for_filepath(filepath)
+ for filepath in absolute_filepaths
]
- for file_name in file_names:
- absolute_filepaths.append(os.path.join(root, file_name))
-
- if len(absolute_filepaths) > 1000:
- break
-
- if len(absolute_filepaths) > 1000:
- break
-
- items = []
- for absolute_filepath in absolute_filepaths:
- item = await self.get_context_item_for_filepath(absolute_filepath)
- if item is not None:
- items.append(item)
+ )
+ items = list(filter(lambda item: item is not None, items))
return items
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index 610a1a48..871724db 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -120,6 +120,10 @@ class TerminalContentsResponse(BaseModel):
contents: str
+class ListDirectoryContentsResponse(BaseModel):
+ contents: List[str]
+
+
T = TypeVar("T", bound=BaseModel)
@@ -241,6 +245,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
"getUserSecret",
"runCommand",
"getTerminalContents",
+ "listDirectoryContents",
]:
self.sub_queue.post(message_type, data)
elif message_type == "workspaceDirectory":
@@ -477,6 +482,15 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
)
return resp.fileEdit
+ async def listDirectoryContents(self, directory: str) -> List[str]:
+ """List the contents of a directory"""
+ resp = await self._send_and_receive_json(
+ {"directory": directory},
+ ListDirectoryContentsResponse,
+ "listDirectoryContents",
+ )
+ return resp.contents
+
async def applyFileSystemEdit(self, edit: FileSystemEdit) -> EditDiff:
"""Apply a file edit"""
backward = None
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
index 435c82e2..f37c1737 100644
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ b/continuedev/src/continuedev/server/ide_protocol.py
@@ -147,5 +147,9 @@ class AbstractIdeProtocolServer(ABC):
def onFileSaved(self, filepath: str, contents: str):
"""Called when a file is saved"""
+ @abstractmethod
+ async def listDirectoryContents(self, directory: str) -> List[str]:
+ """List directory contents"""
+
workspace_directory: str
unique_id: str