summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/core/autopilot.py54
-rw-r--r--continuedev/src/continuedev/core/context.py18
-rw-r--r--continuedev/src/continuedev/core/main.py1
-rw-r--r--continuedev/src/continuedev/libs/util/paths.py9
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/highlighted_code.py19
-rw-r--r--continuedev/src/continuedev/server/gui.py18
6 files changed, 119 insertions, 0 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index ded120d2..2e255198 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -1,3 +1,5 @@
+import json
+import os
import time
import traceback
from functools import cached_property
@@ -9,6 +11,7 @@ from pydantic import root_validator
from ..libs.util.create_async_task import create_async_task
from ..libs.util.logging import logger
+from ..libs.util.paths import getSavedContextGroupsPath
from ..libs.util.queue import AsyncSubscriptionQueue
from ..libs.util.strings import remove_quotes_and_escapes
from ..libs.util.telemetry import posthog_logger
@@ -30,6 +33,7 @@ from ..server.ide_protocol import AbstractIdeProtocolServer
from .context import ContextManager
from .main import (
Context,
+ ContextItem,
ContinueCustomException,
FullState,
History,
@@ -111,6 +115,21 @@ class Autopilot(ContinueBaseModel):
self.history = full_state.history
self.session_info = full_state.session_info
+ # Load saved context groups
+ context_groups_file = getSavedContextGroupsPath()
+ try:
+ with open(context_groups_file, "r") as f:
+ json_ob = json.load(f)
+ for title, context_group in json_ob.items():
+ self._saved_context_groups[title] = [
+ ContextItem(**item) for item in context_group
+ ]
+ except Exception as e:
+ logger.warning(
+ f"Failed to load saved_context_groups.json: {e}. Reverting to empty list."
+ )
+ self._saved_context_groups = {}
+
self.started = True
class Config:
@@ -139,6 +158,7 @@ class Autopilot(ContinueBaseModel):
if self.context_manager is not None
else [],
session_info=self.session_info,
+ saved_context_groups=self._saved_context_groups,
)
self.full_state = full_state
return full_state
@@ -521,3 +541,37 @@ class Autopilot(ContinueBaseModel):
async def select_context_item(self, id: str, query: str):
await self.context_manager.select_context_item(id, query)
await self.update_subscribers()
+
+ _saved_context_groups: Dict[str, List[ContextItem]] = {}
+
+ async def save_context_group(self, title: str, context_items: List[ContextItem]):
+ self._saved_context_groups[title] = context_items
+ await self.update_subscribers()
+
+ # Update saved context groups
+ context_groups_file = getSavedContextGroupsPath()
+ if os.path.exists(context_groups_file):
+ with open(context_groups_file, "w") as f:
+ dict_to_save = {
+ title: [item.dict() for item in context_items]
+ for title, context_items in self._saved_context_groups.items()
+ }
+ json.dump(dict_to_save, f)
+
+ posthog_logger.capture_event(
+ "save_context_group", {"title": title, "length": len(context_items)}
+ )
+
+ async def select_context_group(self, id: str):
+ if id not in self._saved_context_groups:
+ logger.warning(f"Context group {id} not found")
+ return
+ context_group = self._saved_context_groups[id]
+ await self.context_manager.clear_context()
+ for item in context_group:
+ await self.context_manager.manually_add_context_item(item)
+ await self.update_subscribers()
+
+ posthog_logger.capture_event(
+ "select_context_group", {"title": id, "length": len(context_group)}
+ )
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index e5d6e13b..7172883f 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -121,6 +121,13 @@ class ContextProvider(BaseModel):
if new_item := await self.get_item(id, query):
self.selected_items.append(new_item)
+ async def manually_add_context_item(self, context_item: ContextItem):
+ for item in self.selected_items:
+ if item.description.id.item_id == context_item.description.id.item_id:
+ return
+
+ self.selected_items.append(context_item)
+
class ContextManager:
"""
@@ -278,6 +285,17 @@ class ContextManager:
for provider in self.context_providers.values():
await self.context_providers[provider.title].clear_context()
+ async def manually_add_context_item(self, item: ContextItem):
+ """
+ Adds the given ContextItem to the list of ContextItems.
+ """
+ if item.description.id.provider_title not in self.provider_titles:
+ return
+
+ await self.context_providers[
+ item.description.id.provider_title
+ ].manually_add_context_item(item)
+
"""
Should define "ArgsTransformer" and "PromptTransformer" classes for the different LLMs. A standard way for them to ingest the
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index 53440dae..e4ee7668 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -287,6 +287,7 @@ class FullState(ContinueBaseModel):
adding_highlighted_code: bool
selected_context_items: List[ContextItem]
session_info: Optional[SessionInfo] = None
+ saved_context_groups: Dict[str, List[ContextItem]] = {}
class ContinueSDK:
diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py
index 93ab16db..a411c5c3 100644
--- a/continuedev/src/continuedev/libs/util/paths.py
+++ b/continuedev/src/continuedev/libs/util/paths.py
@@ -75,3 +75,12 @@ def getLogFilePath():
path = os.path.join(getGlobalFolderPath(), "continue.log")
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
+
+
+def getSavedContextGroupsPath():
+ path = os.path.join(getGlobalFolderPath(), "saved_context_groups.json")
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ if not os.path.exists(path):
+ with open(path, "w") as f:
+ f.write("\{\}")
+ return path
diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
index ed293124..504764b9 100644
--- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
+++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
@@ -11,6 +11,7 @@ from ...core.context import (
)
from ...core.main import ChatMessage
from ...models.filesystem import RangeInFileWithContents
+from ...models.main import Range
class HighlightedRangeContextItem(BaseModel):
@@ -257,3 +258,21 @@ class HighlightedCodeContextProvider(ContextProvider):
self, id: ContextItemId, query: str, prev: List[ContextItem] = None
) -> List[ContextItem]:
raise NotImplementedError()
+
+ async def manually_add_context_item(self, context_item: ContextItem):
+ full_file_content = await self.ide.readFile(
+ context_item.description.description
+ )
+ self.highlighted_ranges.append(
+ HighlightedRangeContextItem(
+ rif=RangeInFileWithContents(
+ filepath=context_item.description.description,
+ range=Range.from_lines_snippet_in_file(
+ content=full_file_content,
+ snippet=context_item.content,
+ ),
+ contents=context_item.content,
+ ),
+ item=context_item,
+ )
+ )
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 7497e777..a4c45a06 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -8,6 +8,7 @@ from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect, WebSocketState
from uvicorn.main import Server
+from ..core.main import ContextItem
from ..libs.util.create_async_task import create_async_task
from ..libs.util.logging import logger
from ..libs.util.queue import AsyncSubscriptionQueue
@@ -104,6 +105,12 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.load_session(data.get("session_id", None))
elif message_type == "edit_step_at_index":
self.edit_step_at_index(data.get("user_input", ""), data["index"])
+ elif message_type == "save_context_group":
+ self.save_context_group(
+ data["title"], [ContextItem(**item) for item in data["context_items"]]
+ )
+ elif message_type == "select_context_group":
+ self.select_context_group(data["id"])
def on_main_input(self, input: str):
# Do something with user input
@@ -186,6 +193,17 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
posthog_logger.capture_event("load_session", {"session_id": session_id})
+ def save_context_group(self, title: str, context_items: List[ContextItem]):
+ create_async_task(
+ self.session.autopilot.save_context_group(title, context_items),
+ self.on_error,
+ )
+
+ def select_context_group(self, id: str):
+ create_async_task(
+ self.session.autopilot.select_context_group(id), self.on_error
+ )
+
@router.websocket("/ws")
async def websocket_endpoint(