summaryrefslogtreecommitdiff
path: root/server/continuedev/plugins/context_providers/embeddings.py
blob: 86cba3112ef7d04d5cf9af0df5bf285cdaf9cc47 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import uuid
from typing import List, Optional

from pydantic import BaseModel

from ...core.context import ContextProvider
from ...core.main import ContextItem, ContextItemDescription, ContextItemId
from ...libs.chroma.query import ChromaIndexManager


class EmbeddingResult(BaseModel):
    """One matching chunk returned by an embeddings search of the codebase."""

    # Path of the file the matching chunk was taken from.
    filename: str
    # Raw text of the matching chunk.
    content: str


class EmbeddingsProvider(ContextProvider):
    """Context provider that searches the workspace with an embeddings index.

    The index over ``workspace_directory`` is managed by a lazily created
    ``ChromaIndexManager``. A query adds one context item per matching chunk
    to ``selected_items``.
    """

    title = "embed"

    display_title = "Embeddings Search"
    description = "Search the codebase using embeddings"
    dynamic = True
    requires_query = True

    # Root directory passed to ChromaIndexManager when the index is created.
    workspace_directory: str

    # item_id of the placeholder item offered before a query is entered.
    EMBEDDINGS_CONTEXT_ITEM_ID = "embeddings"

    # Created on first access of the ``index`` property.
    index_manager: Optional[ChromaIndexManager] = None

    class Config:
        # ChromaIndexManager is not a pydantic type, so pydantic must be told
        # to accept it as a field type.
        arbitrary_types_allowed = True

    @property
    def index(self) -> ChromaIndexManager:
        """Return the index manager, creating it on first use."""
        if self.index_manager is None:
            self.index_manager = ChromaIndexManager(self.workspace_directory)
        return self.index_manager

    @property
    def BASE_CONTEXT_ITEM(self) -> ContextItem:
        """Build the placeholder item that prompts the user for a query.

        A brand-new ``ContextItem`` is constructed on every access, so callers
        may mutate the returned object without affecting other items.
        """
        return ContextItem(
            content="",
            description=ContextItemDescription(
                name="Embedding Search",
                description="Enter a query to embedding search codebase",
                id=ContextItemId(
                    provider_title=self.title, item_id=self.EMBEDDINGS_CONTEXT_ITEM_ID
                ),
            ),
        )

    async def _get_query_results(self, query: str) -> List[EmbeddingResult]:
        """Query the codebase index and return one result per matching node.

        Args:
            query: Free-text search query.

        Returns:
            A list of ``EmbeddingResult``, in the order of ``source_nodes``.
            ``filename`` is the part of the node's first relationship value
            before the first ``"::"``. If that value contains no ``"::"``,
            the whole value is used.
        """
        response = self.index.query_codebase_index(query)

        matches: List[EmbeddingResult] = []
        for scored_node in response.source_nodes:
            node = scored_node.node
            # NOTE(review): assumes the first relationship value is a string
            # shaped like "<filepath>::<suffix>" -- confirm against the indexer.
            resource_name = list(node.relationships.values())[0]
            # partition() never raises, unlike str.index() when "::" is absent.
            filepath, _, _ = resource_name.partition("::")
            matches.append(EmbeddingResult(filename=filepath, content=node.text))

        return matches

    async def provide_context_items(self) -> List[ContextItem]:
        """(Re)build the codebase index and offer the query placeholder item."""
        self.index.create_codebase_index()  # TODO Synchronous here is not ideal

        return [self.BASE_CONTEXT_ITEM]

    async def add_context_item(self, id: ContextItemId, query: str):
        """Search the index for ``query`` and select every matching chunk.

        Each hit becomes a context item named after the file's basename. Its
        content is the file path followed by the chunk in a fenced block, and
        it gets a fresh random ``item_id``. Items are appended to
        ``self.selected_items``, which is presumably provided by
        ``ContextProvider`` (not visible here).

        Raises:
            Exception: if ``id`` was not issued by this provider.
        """
        if id.provider_title != self.title:
            raise Exception("Invalid provider title for item")

        results = await self._get_query_results(query)

        for result in results:
            # BASE_CONTEXT_ITEM constructs a new object on every access, so it
            # is safe to mutate it directly; no copy is needed.
            ctx_item = self.BASE_CONTEXT_ITEM
            ctx_item.description.name = os.path.basename(result.filename)
            ctx_item.content = f"{result.filename}\n```\n{result.content}\n```"
            # Random id so several hits from the same file remain distinct.
            ctx_item.description.id.item_id = uuid.uuid4().hex
            self.selected_items.append(ctx_item)