summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/plugins/context_providers/embeddings.py
blob: 42d1f7548e6d35241fb40bfe38ba04fa89d045a6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
from typing import List, Optional
import uuid
from pydantic import BaseModel

from ...core.main import ContextItemId
from ...core.context import ContextProvider
from ...core.main import ContextItem, ContextItemDescription, ContextItemId
from ...libs.chroma.query import ChromaIndexManager
from .util import remove_meilisearch_disallowed_chars


class EmbeddingResult(BaseModel):
    """One code chunk produced by an embeddings search of the codebase.

    ``filename`` is the path of the file the chunk was taken from, and
    ``content`` is the raw text of that chunk.
    """

    filename: str
    content: str


class EmbeddingsProvider(ContextProvider):
    """Context provider that selects codebase snippets via embedding search.

    The provider offers a single "Embedding Search" context item. When a
    query is submitted for it, the query is run against a Chroma index of
    ``workspace_directory`` and each matching chunk is appended to
    ``selected_items`` as its own context item.
    """

    title = "embed"

    workspace_directory: str

    EMBEDDINGS_CONTEXT_ITEM_ID = "embeddings"

    # Created lazily by the ``index`` property on first access.
    index_manager: Optional[ChromaIndexManager] = None

    class Config:
        # ChromaIndexManager is not a pydantic type, so pydantic must be
        # allowed to accept it as a field type.
        arbitrary_types_allowed = True

    @property
    def index(self) -> ChromaIndexManager:
        """Return the Chroma index manager, constructing it on first use."""
        if self.index_manager is None:
            self.index_manager = ChromaIndexManager(self.workspace_directory)
        return self.index_manager

    @property
    def BASE_CONTEXT_ITEM(self) -> ContextItem:
        """Build a new, empty "Embedding Search" context item.

        A fresh object is constructed on every access, so callers may
        mutate the returned item (including its nested description).
        """
        return ContextItem(
            content="",
            description=ContextItemDescription(
                name="Embedding Search",
                description="Enter a query to embedding search codebase",
                id=ContextItemId(
                    provider_title=self.title,
                    item_id=self.EMBEDDINGS_CONTEXT_ITEM_ID,
                ),
            ),
        )

    async def _get_query_results(self, query: str) -> List[EmbeddingResult]:
        """Query the codebase index and collect the matching chunks.

        Args:
            query: Free-text search query.

        Returns:
            One EmbeddingResult per source node in the query response. The
            filename is the node's first relationship value up to (not
            including) the first "::"; a value without "::" is used whole.
        """
        results = self.index.query_codebase_index(query)

        ret = []
        for node in results.source_nodes:
            # NOTE(review): assumes the first relationship value is a string
            # id of the form "<filepath>::<suffix>" — confirm against how
            # ChromaIndexManager builds its document ids.
            resource_name = list(node.node.relationships.values())[0]
            # partition() never raises, unlike str.index() when "::" is absent.
            filepath = resource_name.partition("::")[0]
            ret.append(EmbeddingResult(filename=filepath, content=node.node.text))

        return ret

    async def provide_context_items(self) -> List[ContextItem]:
        """Build the codebase index and offer the search entry item."""
        # TODO: create_codebase_index() runs synchronously inside this
        # coroutine, so nothing else on the event loop runs until it returns.
        self.index.create_codebase_index()

        return [self.BASE_CONTEXT_ITEM]

    async def add_context_item(self, id: ContextItemId, query: str) -> None:
        """Run an embedding search for ``query`` and select every match.

        Args:
            id: Must refer to this provider's embeddings search item.
            query: Free-text search query.

        Raises:
            Exception: if ``id.item_id`` is not the embeddings item id.
        """
        if id.item_id != self.EMBEDDINGS_CONTEXT_ITEM_ID:
            raise Exception("Invalid item id")

        for result in await self._get_query_results(query):
            # BASE_CONTEXT_ITEM constructs a new ContextItem on every access,
            # so its nested description can be mutated safely here. A
            # pydantic .copy() is shallow and would not isolate it anyway.
            ctx_item = self.BASE_CONTEXT_ITEM
            ctx_item.description.name = os.path.basename(result.filename)
            ctx_item.content = f"{result.filename}\n```\n{result.content}\n```"
            # Give each selected result its own unique item id.
            ctx_item.description.id.item_id = uuid.uuid4().hex
            self.selected_items.append(ctx_item)