diff options
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 54 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/context.py | 18 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/main.py | 1 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/util/paths.py | 9 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/context_providers/highlighted_code.py | 19 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 18 |
6 files changed, 119 insertions, 0 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index ded120d2..2e255198 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -1,3 +1,5 @@ +import json +import os import time import traceback from functools import cached_property @@ -9,6 +11,7 @@ from pydantic import root_validator from ..libs.util.create_async_task import create_async_task from ..libs.util.logging import logger +from ..libs.util.paths import getSavedContextGroupsPath from ..libs.util.queue import AsyncSubscriptionQueue from ..libs.util.strings import remove_quotes_and_escapes from ..libs.util.telemetry import posthog_logger @@ -30,6 +33,7 @@ from ..server.ide_protocol import AbstractIdeProtocolServer from .context import ContextManager from .main import ( Context, + ContextItem, ContinueCustomException, FullState, History, @@ -111,6 +115,21 @@ class Autopilot(ContinueBaseModel): self.history = full_state.history self.session_info = full_state.session_info + # Load saved context groups + context_groups_file = getSavedContextGroupsPath() + try: + with open(context_groups_file, "r") as f: + json_ob = json.load(f) + for title, context_group in json_ob.items(): + self._saved_context_groups[title] = [ + ContextItem(**item) for item in context_group + ] + except Exception as e: + logger.warning( + f"Failed to load saved_context_groups.json: {e}. Reverting to empty list." + ) + self._saved_context_groups = {} + self.started = True class Config: @@ -139,6 +158,7 @@ class Autopilot(ContinueBaseModel): if self.context_manager is not None else [], session_info=self.session_info, + saved_context_groups=self._saved_context_groups, ) self.full_state = full_state return full_state @@ -521,3 +541,37 @@ class Autopilot(ContinueBaseModel): async def select_context_item(self, id: str, query: str): await self.context_manager.select_context_item(id, query) await self.update_subscribers() + + _saved_context_groups: Dict[str, List[ContextItem]] = {} + + async def save_context_group(self, title: str, context_items: List[ContextItem]): + self._saved_context_groups[title] = context_items + await self.update_subscribers() + + # Update saved context groups + context_groups_file = getSavedContextGroupsPath() + if os.path.exists(context_groups_file): + with open(context_groups_file, "w") as f: + dict_to_save = { + title: [item.dict() for item in context_items] + for title, context_items in self._saved_context_groups.items() + } + json.dump(dict_to_save, f) + + posthog_logger.capture_event( + "save_context_group", {"title": title, "length": len(context_items)} + ) + + async def select_context_group(self, id: str): + if id not in self._saved_context_groups: + logger.warning(f"Context group {id} not found") + return + context_group = self._saved_context_groups[id] + await self.context_manager.clear_context() + for item in context_group: + await self.context_manager.manually_add_context_item(item) + await self.update_subscribers() + + posthog_logger.capture_event( + "select_context_group", {"title": id, "length": len(context_group)} + ) diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index e5d6e13b..7172883f 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -121,6 +121,13 @@ class ContextProvider(BaseModel): if new_item := await self.get_item(id, query): self.selected_items.append(new_item) + async def manually_add_context_item(self, context_item: ContextItem): + for item in self.selected_items: + if item.description.id.item_id == context_item.description.id.item_id: + return + + self.selected_items.append(context_item) + class ContextManager: """ @@ -278,6 +285,17 @@ class ContextManager: for provider in self.context_providers.values(): await self.context_providers[provider.title].clear_context() + async def manually_add_context_item(self, item: ContextItem): + """ + Adds the given ContextItem to the list of ContextItems. + """ + if item.description.id.provider_title not in self.provider_titles: + return + + await self.context_providers[ + item.description.id.provider_title + ].manually_add_context_item(item) + """ Should define "ArgsTransformer" and "PromptTransformer" classes for the different LLMs. A standard way for them to ingest the diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 53440dae..e4ee7668 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -287,6 +287,7 @@ class FullState(ContinueBaseModel): adding_highlighted_code: bool selected_context_items: List[ContextItem] session_info: Optional[SessionInfo] = None + saved_context_groups: Dict[str, List[ContextItem]] = {} class ContinueSDK: diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py index 93ab16db..a411c5c3 100644 --- a/continuedev/src/continuedev/libs/util/paths.py +++ b/continuedev/src/continuedev/libs/util/paths.py @@ -75,3 +75,12 @@ def getLogFilePath(): path = os.path.join(getGlobalFolderPath(), "continue.log") os.makedirs(os.path.dirname(path), exist_ok=True) return path + + +def getSavedContextGroupsPath(): + path = os.path.join(getGlobalFolderPath(), "saved_context_groups.json") + os.makedirs(os.path.dirname(path), exist_ok=True) + if not os.path.exists(path): + with open(path, "w") as f: + f.write("\{\}") + return path diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py index ed293124..504764b9 100644 --- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py +++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py @@ -11,6 +11,7 @@ from ...core.context import ( ) from ...core.main import ChatMessage from ...models.filesystem import RangeInFileWithContents +from ...models.main import Range class HighlightedRangeContextItem(BaseModel): @@ -257,3 +258,21 @@ class HighlightedCodeContextProvider(ContextProvider): self, id: ContextItemId, query: str, prev: List[ContextItem] = None ) -> List[ContextItem]: raise NotImplementedError() + + async def manually_add_context_item(self, context_item: ContextItem): + full_file_content = await self.ide.readFile( + context_item.description.description + ) + self.highlighted_ranges.append( + HighlightedRangeContextItem( + rif=RangeInFileWithContents( + filepath=context_item.description.description, + range=Range.from_lines_snippet_in_file( + content=full_file_content, + snippet=context_item.content, + ), + contents=context_item.content, + ), + item=context_item, + ) + ) diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 7497e777..a4c45a06 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -8,6 +8,7 @@ from pydantic import BaseModel from starlette.websockets import WebSocketDisconnect, WebSocketState from uvicorn.main import Server +from ..core.main import ContextItem from ..libs.util.create_async_task import create_async_task from ..libs.util.logging import logger from ..libs.util.queue import AsyncSubscriptionQueue @@ -104,6 +105,12 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.load_session(data.get("session_id", None)) elif message_type == "edit_step_at_index": self.edit_step_at_index(data.get("user_input", ""), data["index"]) + elif message_type == "save_context_group": + self.save_context_group( + data["title"], [ContextItem(**item) for item in data["context_items"]] + ) + elif message_type == "select_context_group": + self.select_context_group(data["id"]) def on_main_input(self, input: str): # Do something with user input @@ -186,6 +193,17 @@ class GUIProtocolServer(AbstractGUIProtocolServer): posthog_logger.capture_event("load_session", {"session_id": session_id}) + def save_context_group(self, title: str, context_items: List[ContextItem]): + create_async_task( + self.session.autopilot.save_context_group(title, context_items), + self.on_error, + ) + + def select_context_group(self, id: str): + create_async_task( + self.session.autopilot.select_context_group(id), self.on_error + ) + @router.websocket("/ws") async def websocket_endpoint( |