summaryrefslogtreecommitdiff
path: root/server/continuedev/plugins/context_providers/dynamic.py
blob: 505676213475a9f4b409e707abffbb4985f12351 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from abc import ABC, abstractmethod
from typing import List

from ...core.context import ContextProvider
from ...core.main import ContextItem, ContextItemDescription, ContextItemId
from ...libs.util.create_async_task import create_async_task
from .util import remove_meilisearch_disallowed_chars


class DynamicProvider(ContextProvider, ABC):
    """
    Abstract base for context providers whose content is produced on demand
    from a user query.

    Subclasses supply ``title``, ``name`` and (as read by
    ``BASE_CONTEXT_ITEM``) ``description``, and implement ``get_content``
    and ``setup``.
    """

    # A title representing the provider
    title: str

    # A name representing the provider. Probably use capitalized version of title
    name: str

    # Set by provide_context_items(); None until then.
    workspace_dir: str = None
    dynamic: bool = True

    @property
    def BASE_CONTEXT_ITEM(self):
        """Build a fresh, empty-content ContextItem describing this provider.

        A new object is constructed on every access. Both the provider title
        and the item id are set to ``self.title``.
        """
        return ContextItem(
            content="",
            description=ContextItemDescription(
                name=self.name,
                # NOTE(review): `description` is not declared in this class;
                # presumably it comes from ContextProvider or the subclass.
                description=self.description,
                id=ContextItemId(provider_title=self.title, item_id=self.title),
            ),
        )

    async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
        """Record the workspace directory, kick off setup, and return the base item.

        The ``setup()`` coroutine is handed to ``create_async_task`` rather
        than awaited here, so this method does not wait on it directly.

        Returns:
            A single-element list containing ``BASE_CONTEXT_ITEM``.
        """
        self.workspace_dir = workspace_dir
        create_async_task(self.setup())
        return [self.BASE_CONTEXT_ITEM]

    async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
        """Return a context item whose content is the result of ``query``.

        Args:
            id: Identifier of the requested item; its ``provider_title`` must
                equal this provider's ``title``.
            query: The raw query, optionally prefixed with ``"<title> "``.

        Returns:
            A ContextItem with ``content`` set to ``get_content(query)``, a
            display name of ``"<name>: '<query>'"`` and an item id derived
            from the query via ``remove_meilisearch_disallowed_chars``.

        Raises:
            Exception: If ``id.provider_title`` does not match ``self.title``.
        """
        if id.provider_title != self.title:
            raise Exception("Invalid provider title for item")

        # Remove the leading "<title> " as a *prefix*. The previous code used
        # str.lstrip(), which treats its argument as a set of characters and
        # therefore also ate the start of any query made of letters that
        # happen to occur in the title.
        query = query.lstrip(" ")
        prefix = self.title + " "
        if query.startswith(prefix):
            # Also drop extra spaces after the prefix, as lstrip() used to.
            query = query[len(prefix):].lstrip(" ")
        elif query == self.title:
            # Title only, no actual query text.
            query = ""

        results = await self.get_content(query)

        ctx_item = self.BASE_CONTEXT_ITEM.copy()
        ctx_item.content = results
        ctx_item.description.name = f"{self.name}: '{query}'"
        ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars(query)
        return ctx_item

    @abstractmethod
    async def get_content(self, query: str) -> str:
        """Retrieve the content given the query
        (e.g. search the codebase, return search results)"""
        raise NotImplementedError

    @abstractmethod
    async def setup(self):
        """Run any setup needed (e.g. indexing the codebase)"""
        raise NotImplementedError

# Illustrative DynamicProvider subclass, kept inside a bare string literal so
# it is never executed. It shows the pieces a concrete provider supplies:
# `title`, `name`, `description`, and the `get_content` / `setup` coroutines.
"""
class ExampleDynamicProvider(DynamicProvider):
    title = "example"
    name = "Example"
    description = "Example description"

    async def get_content(self, query: str) -> str:
        return f"Example content for '{query}'"

    async def setup(self):
        print("Example setup")
"""