summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/core/context.py
blob: 3f5f6fd3758549a23b9bc3779c5c3b83bb9a4f3b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223

from abc import abstractmethod
from typing import Dict, List, Optional

from meilisearch_python_async import Client
from pydantic import BaseModel

from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId
from ..server.meilisearch_server import check_meilisearch_running
from ..libs.util.logging import logger

# Name of the MeiliSearch index that ContextManager.load_index writes context items to
# and that ContextProvider.get_item reads them back from.
SEARCH_INDEX_NAME: str = "continue_context_items"


class ContextProvider(BaseModel):
    """
    The ContextProvider class is a plugin that lets you provide new information to the LLM by typing '@'.
    When you type '@', the context provider will be asked to populate a list of options.
    These options will be updated on each keystroke.
    When you hit enter on an option, the context provider will add that item to the autopilot's list of context (which is all stored in the ContextManager object).
    """

    # Name of this provider. ContextManager keys providers by it, and a
    # ContextItemId's provider_title is matched against it.
    title: str

    # Items the user has selected from this provider.
    selected_items: List[ContextItem] = []

    async def get_selected_items(self) -> List[ContextItem]:
        """
        Returns all of the selected ContextItems.

        Default implementation simply returns self.selected_items.

        Other implementations may add an async processing step.
        """
        return self.selected_items

    @abstractmethod
    async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
        """
        Provide documents for search index. This is run on startup.

        This is the only method that must be implemented.
        """

    async def get_chat_messages(self) -> List[ChatMessage]:
        """
        Returns one user ChatMessage per selected ContextItem.

        Default implementation renders each item as "<name>: <description>",
        a blank line, then the item's content, and uses the description as the
        message summary.
        """
        return [
            ChatMessage(
                role="user",
                content=f"{item.description.name}: {item.description.description}\n\n{item.content}",
                summary=item.description.description,
            )
            for item in await self.get_selected_items()
        ]

    async def get_item(self, id: ContextItemId, query: str) -> Optional[ContextItem]:
        """
        Returns the ContextItem with the given id, or None if it cannot be retrieved.

        Default implementation fetches the document stored under id.to_string()
        from the SEARCH_INDEX_NAME MeiliSearch index. `query` is unused here but
        is available to overriding implementations.

        Any exception raised while fetching or building the item is logged as a
        warning and results in None.
        """
        async with Client('http://localhost:7700') as search_client:
            try:
                result = await search_client.index(
                    SEARCH_INDEX_NAME).get_document(id.to_string())
                # Construction stays inside the try so a malformed document
                # (e.g. a missing key) is also treated as "not retrievable".
                return ContextItem(
                    description=ContextItemDescription(
                        name=result["name"],
                        description=result["description"],
                        id=id
                    ),
                    content=result["content"]
                )
            except Exception as e:
                logger.warning(
                    f"Error while retrieving document from meilisearch: {e}")

            return None

    async def delete_context_with_ids(self, ids: List[ContextItemId]):
        """
        Deletes the ContextItems with the given IDs, lets ContextProviders recalculate.

        Default implementation simply deletes those with the given ids.
        """
        # Compare by string form so distinct-but-equal id objects match.
        id_strings = {item_id.to_string() for item_id in ids}
        self.selected_items = [
            item for item in self.selected_items
            if item.description.id.to_string() not in id_strings
        ]

    async def clear_context(self):
        """
        Clears all context.

        Default implementation simply clears the selected items.
        """
        self.selected_items = []

    async def add_context_item(self, id: ContextItemId, query: str):
        """
        Looks up the item with the given id and appends it to the selected items.

        Default implementation skips ids whose item_id is already selected, and
        skips items that get_item could not retrieve (returned None).

        This method also allows you not to have to load all of the information until an item is selected.
        """
        # Don't add duplicate context
        if any(item.description.id.item_id == id.item_id for item in self.selected_items):
            return

        new_item = await self.get_item(id, query)
        if new_item is not None:
            self.selected_items.append(new_item)


class ContextManager:
    """
    The context manager is responsible for storing the context to be passed to the LLM, including
    - ContextItems (highlighted code, GitHub Issues, etc.)
    - ChatMessages in the history
    - System Message
    - Functions

    It is responsible for compiling all of this information into a single prompt without exceeding the token limit.
    """

    def __init__(self):
        # Maps provider title -> ContextProvider; populated by start().
        self.context_providers = {}
        # Titles of the registered providers; select_context_item validates ids against it.
        self.provider_titles = set()

    async def get_selected_items(self) -> List[ContextItem]:
        """
        Returns all of the selected ContextItems, concatenated across providers
        in registration order.
        """
        items = []
        for provider in self.context_providers.values():
            items.extend(await provider.get_selected_items())
        return items

    async def get_chat_messages(self) -> List[ChatMessage]:
        """
        Returns chat messages from each provider, concatenated in registration order.
        """
        messages = []
        for provider in self.context_providers.values():
            messages.extend(await provider.get_chat_messages())
        return messages

    async def start(self, context_providers: List[ContextProvider]):
        """
        Starts the context manager.

        Registers the given providers by title, then checks MeiliSearch health.
        If MeiliSearch is unreachable or not "available", only the provider
        titled "code" is kept.
        """
        self.context_providers = {
            prov.title: prov for prov in context_providers}
        self.provider_titles = {
            provider.title for provider in context_providers}

        async with Client('http://localhost:7700') as search_client:
            try:
                health = await search_client.health()
                meilisearch_running = health.status == "available"
            except Exception:
                # Narrower than a bare except: don't swallow task cancellation
                # or KeyboardInterrupt while probing the server.
                meilisearch_running = False

            if not meilisearch_running:
                logger.warning(
                    "MeiliSearch not running, avoiding any dependent context providers")
                # Keep the title -> provider dict shape that every other method
                # relies on (.values(), .items(), indexing by title), and keep
                # provider_titles in sync so removed providers are rejected.
                self.context_providers = {
                    title: prov for title, prov in self.context_providers.items()
                    if prov.title == "code"}
                self.provider_titles = set(self.context_providers)

    async def load_index(self, workspace_dir: str):
        """
        Asks each registered provider for its context items and adds them to the
        SEARCH_INDEX_NAME MeiliSearch index.

        Indexing errors are logged at debug level and do not stop the remaining
        providers from being indexed.
        """
        for provider in self.context_providers.values():
            context_items = await provider.provide_context_items(workspace_dir)
            documents = [
                {
                    "id": item.description.id.to_string(),
                    "name": item.description.name,
                    "description": item.description.description,
                    "content": item.content
                }
                for item in context_items
            ]
            if not documents:
                continue
            try:
                async with Client('http://localhost:7700') as search_client:
                    await search_client.index(SEARCH_INDEX_NAME).add_documents(documents)
            except Exception as e:
                logger.debug(f"Error loading meilisearch index: {e}")

    async def select_context_item(self, id: str, query: str):
        """
        Selects the ContextItem with the given id.

        Raises:
            ValueError: if the id's provider title is not a registered provider.
        """
        item_id = ContextItemId.from_string(id)
        if item_id.provider_title not in self.provider_titles:
            raise ValueError(
                f"Context provider with title {item_id.provider_title} not found")

        await self.context_providers[item_id.provider_title].add_context_item(item_id, query)

    async def delete_context_with_ids(self, ids: List[str]):
        """
        Deletes the ContextItems with the given IDs, lets ContextProviders recalculate.

        Ids are grouped by provider title so each affected provider is called once.
        """
        # Group by provider title
        provider_title_to_ids: Dict[str, List[ContextItemId]] = {}
        for id_string in ids:
            item_id = ContextItemId.from_string(id_string)
            provider_title_to_ids.setdefault(
                item_id.provider_title, []).append(item_id)

        # Recalculate context for each updated provider
        for provider_title, provider_ids in provider_title_to_ids.items():
            await self.context_providers[provider_title].delete_context_with_ids(provider_ids)

    async def clear_context(self):
        """
        Clears all context in every registered provider.
        """
        for provider in self.context_providers.values():
            await provider.clear_context()


"""
TODO: Define "ArgsTransformer" and "PromptTransformer" classes for the different LLMs, giving them
a standard way to ingest prompts in one shared format so this logic doesn't have to be redone for each model.
"""