blob: 42d1f7548e6d35241fb40bfe38ba04fa89d045a6 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
|
import os
from typing import List, Optional
import uuid
from pydantic import BaseModel
from ...core.main import ContextItemId
from ...core.context import ContextProvider
from ...core.main import ContextItem, ContextItemDescription, ContextItemId
from ...libs.chroma.query import ChromaIndexManager
from .util import remove_meilisearch_disallowed_chars
class EmbeddingResult(BaseModel):
    """One hit produced by an embedding search over the codebase."""

    # Path of the file that the matched chunk was taken from.
    filename: str
    # Raw text of the matched chunk.
    content: str
class EmbeddingsProvider(ContextProvider):
    """Context provider that runs an embedding search over the workspace.

    It exposes a single "Embedding Search" context item. When that item is
    added with a query, the codebase index is queried and every returned
    chunk is appended to ``self.selected_items`` as its own context item.
    """

    title = "embed"

    # Directory passed to ChromaIndexManager when the index is first needed.
    workspace_directory: str

    EMBEDDINGS_CONTEXT_ITEM_ID = "embeddings"

    # Created lazily by the ``index`` property; None until first access.
    index_manager: Optional[ChromaIndexManager] = None

    class Config:
        # ChromaIndexManager is not a pydantic type, so pydantic must be told
        # to accept it as a field type.
        arbitrary_types_allowed = True

    @property
    def index(self):
        """Return the ChromaIndexManager, constructing it on first access."""
        if self.index_manager is None:
            self.index_manager = ChromaIndexManager(self.workspace_directory)
        return self.index_manager

    @property
    def BASE_CONTEXT_ITEM(self):
        """Build the empty "Embedding Search" context item.

        A new ContextItem is constructed on every access.
        """
        return ContextItem(
            content="",
            description=ContextItemDescription(
                name="Embedding Search",
                description="Enter a query to embedding search codebase",
                id=ContextItemId(
                    provider_title=self.title,
                    item_id=self.EMBEDDINGS_CONTEXT_ITEM_ID,
                ),
            ),
        )

    async def _get_query_results(self, query: str) -> List[EmbeddingResult]:
        """Query the codebase index and return one EmbeddingResult per source node.

        Args:
            query: Free-text query forwarded to ``query_codebase_index``.

        Returns:
            A list of EmbeddingResult, in the order of ``results.source_nodes``.
        """
        results = self.index.query_codebase_index(query)
        ret = []
        for node in results.source_nodes:
            # The file path is taken from the node's first relationship value;
            # presumably formatted as "<filepath>::<suffix>" — verify against
            # ChromaIndexManager if the id scheme changes.
            resource_name = list(node.node.relationships.values())[0]
            # Keep everything before the first "::". Unlike str.index(),
            # partition() does not raise when "::" is missing; the whole id is
            # used as the path instead of aborting the entire search.
            filepath, _, _ = resource_name.partition("::")
            ret.append(EmbeddingResult(filename=filepath, content=node.node.text))
        return ret

    async def provide_context_items(self) -> List[ContextItem]:
        """Build the codebase index, then return the single search item."""
        self.index.create_codebase_index()  # TODO Synchronous here is not ideal
        return [self.BASE_CONTEXT_ITEM]

    async def add_context_item(self, id: ContextItemId, query: str):
        """Run an embedding search for *query* and add each hit to ``selected_items``.

        Args:
            id: Must refer to this provider's embeddings item.
            query: Text to search the codebase index with.

        Raises:
            Exception: If ``id.item_id`` is not ``EMBEDDINGS_CONTEXT_ITEM_ID``.
        """
        if id.item_id != self.EMBEDDINGS_CONTEXT_ITEM_ID:
            raise Exception("Invalid item id")

        results = await self._get_query_results(query)
        for result in results:
            # Deep copy: the nested description/id models are mutated below,
            # and a shallow copy would share them with the template item.
            ctx_item = self.BASE_CONTEXT_ITEM.copy(deep=True)
            ctx_item.description.name = os.path.basename(result.filename)
            ctx_item.content = f"{result.filename}\n```\n{result.content}\n```"
            # Give every hit its own id so the items can be told apart.
            ctx_item.description.id.item_id = uuid.uuid4().hex
            self.selected_items.append(ctx_item)
|