diff options
Diffstat (limited to 'server/continuedev/plugins/context_providers/embeddings.py')
-rw-r--r-- | server/continuedev/plugins/context_providers/embeddings.py | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/server/continuedev/plugins/context_providers/embeddings.py b/server/continuedev/plugins/context_providers/embeddings.py new file mode 100644 index 00000000..86cba311 --- /dev/null +++ b/server/continuedev/plugins/context_providers/embeddings.py @@ -0,0 +1,81 @@ +import os +import uuid +from typing import List, Optional + +from pydantic import BaseModel + +from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from ...libs.chroma.query import ChromaIndexManager + + +class EmbeddingResult(BaseModel): + filename: str + content: str + + +class EmbeddingsProvider(ContextProvider): + title = "embed" + + display_title = "Embeddings Search" + description = "Search the codebase using embeddings" + dynamic = True + requires_query = True + + workspace_directory: str + + EMBEDDINGS_CONTEXT_ITEM_ID = "embeddings" + + index_manager: Optional[ChromaIndexManager] = None + + class Config: + arbitrary_types_allowed = True + + @property + def index(self): + if self.index_manager is None: + self.index_manager = ChromaIndexManager(self.workspace_directory) + return self.index_manager + + @property + def BASE_CONTEXT_ITEM(self): + return ContextItem( + content="", + description=ContextItemDescription( + name="Embedding Search", + description="Enter a query to embedding search codebase", + id=ContextItemId( + provider_title=self.title, item_id=self.EMBEDDINGS_CONTEXT_ITEM_ID + ), + ), + ) + + async def _get_query_results(self, query: str) -> str: + results = self.index.query_codebase_index(query) + + ret = [] + for node in results.source_nodes: + resource_name = list(node.node.relationships.values())[0] + filepath = resource_name[: resource_name.index("::")] + ret.append(EmbeddingResult(filename=filepath, content=node.node.text)) + + return ret + + async def provide_context_items(self) -> List[ContextItem]: + self.index.create_codebase_index() # TODO Synchronous here is not ideal + + return [self.BASE_CONTEXT_ITEM] + + async def add_context_item(self, id: ContextItemId, query: str): + if not id.provider_title == self.title: + raise Exception("Invalid provider title for item") + + results = await self._get_query_results(query) + + for i in range(len(results)): + result = results[i] + ctx_item = self.BASE_CONTEXT_ITEM.copy() + ctx_item.description.name = os.path.basename(result.filename) + ctx_item.content = f"{result.filename}\n```\n{result.content}\n```" + ctx_item.description.id.item_id = uuid.uuid4().hex + self.selected_items.append(ctx_item) |