From d8e821e422678fd4248b472c7f3e67a32ecfefb5 Mon Sep 17 00:00:00 2001
From: Nate Sesti <sestinj@gmail.com>
Date: Wed, 6 Sep 2023 20:18:49 -0700
Subject: fix: :bug: separately load ctx provs, fix filetree

---
 continuedev/src/continuedev/core/context.py        | 46 ++++++++-------
 .../plugins/context_providers/filetree.py          | 66 +++++++++++++++++-----
 continuedev/src/continuedev/server/ide.py          |  2 +-
 3 files changed, 80 insertions(+), 34 deletions(-)

(limited to 'continuedev')

diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index 125ddc23..571e5dc8 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -285,8 +285,6 @@ class ContextManager:
                 await globalSearchIndex.update_filterable_attributes(["workspace_dir"])
 
                 async def load_context_provider(provider: ContextProvider):
-                    ti = time.time()
-
                     context_items = await provider.provide_context_items(workspace_dir)
                     documents = [
                         {
@@ -299,28 +297,34 @@ class ContextManager:
                         for item in context_items
                     ]
                     if len(documents) > 0:
-                        try:
-                            await asyncio.wait_for(
-                                globalSearchIndex.add_documents(documents), timeout=20
-                            )
-                        except asyncio.TimeoutError:
-                            logger.warning(
-                                f"Failed to add documents to meilisearch for context provider {provider.__class__.__name__} in 20 seconds"
-                            )
-                            return
-                        except Exception as e:
-                            logger.warning(
-                                f"Error adding documents to meilisearch for context provider {provider.__class__.__name__}: {e}"
-                            )
-                            return
+                        await globalSearchIndex.add_documents(documents)
+
+                    return len(documents)
+
+                async def safe_load(provider: ContextProvider):
+                    ti = time.time()
+                    try:
+                        num_documents = await asyncio.wait_for(
+                            load_context_provider(provider), timeout=20
+                        )
+                    except asyncio.TimeoutError:
+                        logger.warning(
+                            f"Failed to add documents to meilisearch for context provider {provider.__class__.__name__} in 20 seconds"
+                        )
+                        return
+                    except Exception as e:
+                        logger.warning(
+                            f"Error adding documents to meilisearch for context provider {provider.__class__.__name__}: {e}"
+                        )
+                        return
 
                     tf = time.time()
                     logger.debug(
-                        f"Loaded {len(documents)} documents into meilisearch in {tf - ti} seconds for context provider {provider.title}"
+                        f"Loaded {num_documents} documents into meilisearch in {tf - ti} seconds for context provider {provider.title}"
                     )
 
                 tasks = [
-                    load_context_provider(provider)
+                    safe_load(provider)
                     for _, provider in self.context_providers.items()
                 ]
                 await asyncio.wait_for(asyncio.gather(*tasks), timeout=20)
@@ -330,9 +334,11 @@ class ContextManager:
             if should_retry:
                 await restart_meilisearch()
                 try:
-                    asyncio.wait_for(await poll_meilisearch_running(), timeout=20)
+                    await asyncio.wait_for(poll_meilisearch_running(), timeout=20)
                 except asyncio.TimeoutError:
-                    logger.warning("Meilisearch did not restart in less than 20 seconds. Stopping polling.")
+                    logger.warning(
+                        "Meilisearch did not restart in less than 20 seconds. Stopping polling."
+                    )
                 await self.load_index(workspace_dir, False)
 
     async def select_context_item(self, id: str, query: str):
diff --git a/continuedev/src/continuedev/plugins/context_providers/filetree.py b/continuedev/src/continuedev/plugins/context_providers/filetree.py
index ea86f214..959a0a66 100644
--- a/continuedev/src/continuedev/plugins/context_providers/filetree.py
+++ b/continuedev/src/continuedev/plugins/context_providers/filetree.py
@@ -1,31 +1,71 @@
-import os
 from typing import List
 
+from pydantic import BaseModel
+
 from ...core.context import ContextProvider
 from ...core.main import ContextItem, ContextItemDescription, ContextItemId
 
 
-def format_file_tree(startpath) -> str:
+class Directory(BaseModel):
+    name: str
+    files: List[str]
+    directories: List["Directory"]
+
+
+def format_file_tree(tree: Directory, indentation: str = "") -> str:
     result = ""
-    for root, dirs, files in os.walk(startpath):
-        level = root.replace(startpath, "").count(os.sep)
-        indent = " " * 4 * (level)
-        result += "{}{}/".format(indent, os.path.basename(root)) + "\n"
-        subindent = " " * 4 * (level + 1)
-        for f in files:
-            result += "{}{}".format(subindent, f) + "\n"
+    for file in tree.files:
+        result += f"{indentation}{file}\n"
+
+    for directory in tree.directories:
+        result += f"{indentation}{directory.name}/\n"
+        result += format_file_tree(directory, indentation + "  ")
 
     return result
 
 
+def split_path(path: str, with_root=None) -> List[str]:
+    parts = path.split("/") if "/" in path else path.split("\\")
+    if with_root is not None:
+        root_parts = split_path(with_root)
+        parts = parts[len(root_parts) - 1 :]
+
+    return parts
+
+
 class FileTreeContextProvider(ContextProvider):
     title = "tree"
 
     workspace_dir: str = None
 
-    def _filetree_context_item(self):
+    async def _get_file_tree(self, directory: str) -> str:
+        contents = await self.sdk.ide.listDirectoryContents(directory, recursive=True)
+
+        tree = Directory(
+            name=split_path(self.workspace_dir)[-1], files=[], directories=[]
+        )
+
+        for file in contents:
+            parts = split_path(file, with_root=self.workspace_dir)
+
+            current_tree = tree
+            for part in parts[:-1]:
+                if part not in [d.name for d in current_tree.directories]:
+                    current_tree.directories.append(
+                        Directory(name=part, files=[], directories=[])
+                    )
+
+                current_tree = [d for d in current_tree.directories if d.name == part][
+                    0
+                ]
+
+            current_tree.files.append(parts[-1])
+
+        return format_file_tree(tree)
+
+    async def _filetree_context_item(self):
         return ContextItem(
-            content=format_file_tree(self.workspace_dir),
+            content=await self._get_file_tree(self.workspace_dir),
             description=ContextItemDescription(
                 name="File Tree",
                 description="Add a formatted file tree of this directory to the context",
@@ -35,10 +75,10 @@ class FileTreeContextProvider(ContextProvider):
 
     async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
         self.workspace_dir = workspace_dir
-        return [self._filetree_context_item()]
+        return [await self._filetree_context_item()]
 
     async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
         if not id.item_id == self.title:
             raise Exception("Invalid item id")
 
-        return self._filetree_context_item()
+        return await self._filetree_context_item()
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index 5d1d897e..bdb5817d 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -196,7 +196,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
                     f"Tried to send message, but websocket is disconnected: {message_type}"
                 )
                 return
-            logger.debug(f"Sending IDE message: {message_type}")
+            # logger.debug(f"Sending IDE message: {message_type}")
             await self.websocket.send_json({"messageType": message_type, "data": data})
         except RuntimeError as e:
             logger.warning(f"Error sending IDE message, websocket probably closed: {e}")
-- 
cgit v1.2.3-70-g09d2