import asyncio
import json
import os
import time
import traceback
import uuid
from functools import cached_property
from typing import Callable, Coroutine, Dict, List, Optional

import redbaron
from aiohttp import ClientPayloadError
from pydantic import root_validator

from ..libs.llm.prompts.chat import template_alpaca_messages
from ..libs.util.create_async_task import create_async_task
from ..libs.util.devdata import dev_data_logger
from ..libs.util.edit_config import edit_config_property
from ..libs.util.logging import logger
from ..libs.util.paths import getSavedContextGroupsPath
from ..libs.util.queue import AsyncSubscriptionQueue
from ..libs.util.strings import remove_quotes_and_escapes
from ..libs.util.telemetry import posthog_logger
from ..libs.util.traceback.traceback_parsers import (
    get_javascript_traceback,
    get_python_traceback,
)
from ..models.filesystem import RangeInFileWithContents
from ..models.filesystem_edit import FileEditWithFullContents
from ..models.main import ContinueBaseModel
from ..plugins.context_providers.file import FileContextProvider
from ..plugins.context_providers.highlighted_code import HighlightedCodeContextProvider
from ..plugins.policies.default import DefaultPolicy
from ..plugins.steps.on_traceback import DefaultOnTracebackStep
from ..server.ide_protocol import AbstractIdeProtocolServer
from ..server.meilisearch_server import get_meilisearch_url, stop_meilisearch
from .config import ContinueConfig
from .context import ContextManager
from .main import (
    Context,
    ContextItem,
    ContinueCustomException,
    FullState,
    History,
    HistoryNode,
    Policy,
    SessionInfo,
    Step,
)
from .observation import InternalErrorObservation, Observation
from .sdk import ContinueSDK
from .steps import DisplayErrorStep, ManualEditStep, ReversibleStep, UserInputStep


def get_error_title(e: Exception) -> str:
    """Return a short, user-facing title for exception ``e``.

    Payload errors, the free-tier rate limit and connection failures map to
    friendly messages; anything else falls back to ``str(e)`` or, when that
    is empty, ``repr(e)``.
    """
    if isinstance(e, ClientPayloadError):
        return "The request failed. Please try again."
    message = str(e)
    if "rate_limit_ip_middleware" in message:
        return "You have reached your limit for free usage of our token. You can continue using Continue by entering your own OpenAI API key in VS Code settings."
    if message.startswith("Cannot connect to host"):
        return (
            "The request failed. Please check your internet connection and try again."
        )
    return message or repr(e)


class Autopilot(ContinueBaseModel):
    """Drives one session.

    Runs steps chosen by the policy (or replayed from history), records them
    in ``history``, and pushes ``FullState`` snapshots to the subscribers
    registered with ``on_update``.
    """

    ide: AbstractIdeProtocolServer

    policy: Policy = DefaultPolicy()
    history: History = History.from_empty()
    context: Context = Context()
    full_state: Optional[FullState] = None
    session_info: Optional[SessionInfo] = None
    context_manager: ContextManager = ContextManager()
    continue_sdk: ContinueSDK = None

    # Async callbacks awaited with every FullState snapshot (see update_subscribers)
    _on_update_callbacks: List[Callable[[FullState], None]] = []

    # True while steps are being run; read by _request_halt and the frontend
    _active: bool = False
    # Checked by run_from_step between steps to stop the loop early
    _should_halt: bool = False
    # User inputs received while another input is still being processed
    _main_user_input_queue: List[str] = []

    # Keyed by history index (as str): user input for waiting steps, retry signals
    _user_input_queue = AsyncSubscriptionQueue()
    _retry_queue = AsyncSubscriptionQueue()

    started: bool = False

    async def load(
        self, config: Optional[ContinueConfig] = None, only_reloading: bool = False
    ):
        """Create the SDK from ``config`` and start the context manager.

        A ``policy_override`` in the loaded config replaces ``self.policy``.
        The highlighted-code and file context providers are always appended
        to the configured ones.
        """
        self.continue_sdk = await ContinueSDK.create(self, config=config)
        if override_policy := self.continue_sdk.config.policy_override:
            self.policy = override_policy

        # Load documents into the search index
        logger.debug("Starting context manager")
        await self.context_manager.start(
            self.continue_sdk.config.context_providers
            + [
                HighlightedCodeContextProvider(ide=self.ide),
                FileContextProvider(workspace_dir=self.ide.workspace_directory),
            ],
            self.continue_sdk,
            only_reloading=only_reloading,
        )

    async def start(
        self,
        full_state: Optional[FullState] = None,
        config: Optional[ContinueConfig] = None,
    ):
        """Load config, restore state and read saved context groups.

        History and session info are restored from ``full_state`` when one is
        given. The saved context groups are then read from disk; if that fails
        they fall back to an empty dict.
        """
        await self.load(config=config, only_reloading=False)

        if full_state is not None:
            self.history = full_state.history
            self.session_info = full_state.session_info

        # Load saved context groups
        context_groups_file = getSavedContextGroupsPath()
        try:
            with open(context_groups_file, "r") as f:
                json_ob = json.load(f)
                for title, context_group in json_ob.items():
                    self._saved_context_groups[title] = [
                        ContextItem(**item) for item in context_group
                    ]
        except Exception as e:
            logger.warning(
                f"Failed to load saved_context_groups.json: {e}. Reverting to empty list."
            )
            self._saved_context_groups = {}

        self.started = True

    async def reload_config(self):
        """Reload the config (keeping context-manager state) and notify subscribers."""
        await self.load(config=None, only_reloading=True)
        await self.update_subscribers()

    async def cleanup(self):
        """Release session resources by calling ``stop_meilisearch()``."""
        stop_meilisearch()

    class Config:
        arbitrary_types_allowed = True
        keep_untouched = (cached_property,)

    @root_validator(pre=True)
    def fill_in_values(cls, values):
        """When constructed with ``full_state``, seed ``history`` from it."""
        full_state: FullState = values.get("full_state")
        if full_state is not None:
            values["history"] = full_state.history
        return values

    async def get_full_state(self) -> FullState:
        """Build a ``FullState`` snapshot, cache it in ``self.full_state`` and return it."""
        full_state = FullState(
            history=self.history,
            active=self._active,
            user_input_queue=self._main_user_input_queue,
            slash_commands=self.get_available_slash_commands(),
            adding_highlighted_code=self.context_manager.context_providers[
                "code"
            ].adding_highlighted_code
            if "code" in self.context_manager.context_providers
            else False,
            selected_context_items=await self.context_manager.get_selected_items()
            if self.context_manager is not None
            else [],
            session_info=self.session_info,
            config=self.continue_sdk.config,
            saved_context_groups=self._saved_context_groups,
            context_providers=self.context_manager.get_provider_descriptions(),
            meilisearch_url=get_meilisearch_url(),
        )
        self.full_state = full_state
        return full_state

    def get_available_slash_commands(self) -> List[Dict]:
        """Return ``{"name", "description"}`` dicts for all slash commands.

        Custom commands come first, then built-in slash commands. Any command
        named "edit" is then moved to the front.
        """
        config = self.continue_sdk.config
        custom_commands = [
            {"name": cmd.name, "description": cmd.description}
            for cmd in config.custom_commands
        ]
        slash_commands = [
            {"name": cmd.name, "description": cmd.description}
            for cmd in config.slash_commands
        ]
        cmds = custom_commands + slash_commands
        # list.sort is stable even with reverse=True, so non-"edit" commands
        # keep their relative order.
        cmds.sort(key=lambda x: x["name"] == "edit", reverse=True)
        return cmds

    async def clear_history(self):
        """Reset history, pending inputs and the active flag, then notify subscribers."""
        # Reset history
        self.history = History.from_empty()
        self._main_user_input_queue = []
        self._active = False

        await self.update_subscribers()

    def on_update(self, callback: Callable[[FullState], Coroutine]):
        """Subscribe to changes to state.

        ``callback`` is awaited with each new ``FullState`` snapshot.
        """
        self._on_update_callbacks.append(callback)

    async def update_subscribers(self):
        """Build a fresh snapshot and await every callback with it, in registration order."""
        full_state = await self.get_full_state()
        for callback in self._on_update_callbacks:
            await callback(full_state)

    def give_user_input(self, input: str, index: int):
        """Deliver ``input`` to whatever is waiting on history index ``index``."""
        self._user_input_queue.post(str(index), input)

    async def wait_for_user_input(self) -> str:
        """Wait for user input posted for the current history index.

        Autopilot is marked inactive while waiting and active again afterwards;
        subscribers are notified on both transitions.
        """
        self._active = False
        await self.update_subscribers()
        user_input = await self._user_input_queue.get(str(self.history.current_index))
        self._active = True
        await self.update_subscribers()
        return user_input

    # Manual edits collected between steps; flushed into a ManualEditStep by
    # _run_singular_step before the next non-future step runs.
    _manual_edits_buffer: List[FileEditWithFullContents] = []

    async def reverse_to_index(self, index: int):
        """Step history back until its current index is below ``index``.

        Steps that are ``ReversibleStep`` instances are reversed on the way.
        Subscribers are notified after each step back. Errors are
        deliberately swallowed and only logged at debug level.
        """
        try:
            while self.history.get_current_index() >= index:
                current_step = self.history.get_current().step
                self.history.step_back()
                if issubclass(current_step.__class__, ReversibleStep):
                    await current_step.reverse(self.continue_sdk)

                await self.update_subscribers()
        except Exception as e:
            logger.debug(e)

    def handle_manual_edits(self, edits: List[FileEditWithFullContents]):
        """Buffer manual edits until the next step runs."""
        self._manual_edits_buffer.extend(edits)
        # TODO: You're storing a lot of unnecessary data here. Can compress into EditDiffs on the spot, and merge.
        # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit)

    # Note that this is being overridden to do nothing in DemoAgent
    async def handle_command_output(self, output: str):
        """Run the configured ``on_traceback`` step if ``output`` contains a traceback.

        The Python parser is tried first, then the JavaScript one. The step
        runs at most once, because it is always built from the whole
        ``output`` regardless of which parser matched.
        """
        get_traceback_funcs = [get_python_traceback, get_javascript_traceback]
        for get_tb_func in get_traceback_funcs:
            # Named ``tb`` so the ``traceback`` module is not shadowed
            tb = get_tb_func(output)
            if tb is not None and self.continue_sdk.config.on_traceback is not None:
                step = self.continue_sdk.config.on_traceback(output=output)
                await self._run_singular_step(step)
                break

    async def handle_debug_terminal(self, content: str):
        """Run the debug terminal step (``DefaultOnTracebackStep``) on ``content``."""
        step = DefaultOnTracebackStep(output=content)
        await self._run_singular_step(step)

    async def handle_highlighted_code(
        self,
        range_in_files: List[RangeInFileWithContents],
        edit: Optional[bool] = False,
    ):
        """Forward highlighted ranges to the "code" context provider, if registered."""
        if "code" not in self.context_manager.context_providers:
            return

        # Add to context manager
        await self.context_manager.context_providers["code"].handle_highlighted_code(
            range_in_files, edit
        )

        await self.update_subscribers()

    # Nesting depth recorded on each new HistoryNode; incremented while a
    # step runs so its substeps get depth + 1.
    _step_depth: int = 0

    async def retry_at_index(self, index: int):
        """Hide the failed step at ``index`` and release its retry wait."""
        self.history.timeline[index].step.hide = True
        self._retry_queue.post(str(index), None)

    async def delete_at_index(self, index: int):
        """Mark the node at ``index`` as deleted and inactive.

        Its step is hidden only if the node was not active.
        """
        if not self.history.timeline[index].active:
            self.history.timeline[index].step.hide = True

        self.history.timeline[index].deleted = True
        self.history.timeline[index].active = False

        await self.update_subscribers()

    async def edit_step_at_index(self, user_input: str, index: int):
        """Re-run the step at ``index`` with new user input.

        The node at ``index`` and every node after it are removed (and marked
        deleted). The context that node used is then restored and the edited
        step is run again.
        """
        node_to_rerun = self.history.timeline[index].copy()
        step_to_rerun = node_to_rerun.step
        step_to_rerun.user_input = user_input
        step_to_rerun.description = user_input

        # Halt the agent's currently running jobs (delete them)
        while len(self.history.timeline) > index:
            # Remove from timeline
            node_to_delete = self.history.timeline.pop()
            # Delete so it is stopped if in the middle of running
            node_to_delete.deleted = True

        self.history.current_index = index - 1

        # Set the context to the context used by that step
        await self.context_manager.clear_context()
        for context_item in node_to_rerun.context_used:
            await self.context_manager.manually_add_context_item(context_item)

        await self.update_subscribers()

        # Rerun from the current step
        await self.run_from_step(step_to_rerun)

    async def delete_context_with_ids(
        self, ids: List[str], index: Optional[int] = None
    ):
        """Remove context items whose ids are in ``ids``.

        With ``index=None`` they are removed from the current selection.
        Otherwise they are removed from that history node's ``context_used``.
        """
        if index is None:
            await self.context_manager.delete_context_with_ids(ids)
        else:
            self.history.timeline[index].context_used = [
                item
                for item in self.history.timeline[index].context_used
                if item.description.id.to_string() not in ids
            ]
        await self.update_subscribers()

    async def toggle_adding_highlighted_code(self):
        """Flip ``adding_highlighted_code`` on the "code" provider, if registered."""
        if "code" not in self.context_manager.context_providers:
            return

        code_provider = self.context_manager.context_providers["code"]
        code_provider.adding_highlighted_code = (
            not code_provider.adding_highlighted_code
        )
        await self.update_subscribers()

    async def set_editing_at_ids(self, ids: List[str]):
        """Forward ``ids`` to the "code" provider's ``set_editing_at_ids``, if registered."""
        if "code" not in self.context_manager.context_providers:
            return

        await self.context_manager.context_providers["code"].set_editing_at_ids(ids)
        await self.update_subscribers()

    async def _run_singular_step(
        self, step: "Step", is_future_step: bool = False
    ) -> Optional[Observation]:
        """Run ``step`` once, recording it in history and handling errors.

        Returns the step's observation. Returns ``None`` if the config
        disallows the step, or if the step's history node was deleted while
        it was running.

        On any other error, the step is revealed with an
        ``InternalErrorObservation`` and this call waits on ``_retry_queue``.
        After the retry signal it re-runs a copy of the step.
        """
        # Allow config to set disallowed steps
        if step.__class__.__name__ in self.continue_sdk.config.disallowed_steps:
            return None

        # If a parent step is deleted/cancelled, don't run this step
        # TODO: This was problematic because when running a step after deleting one, it seemed to think that was the parent
        # last_depth = self._step_depth
        # i = self.history.current_index
        # while i >= 0 and self.history.timeline[i].depth == last_depth - 1:
        #     if self.history.timeline[i].deleted:
        #         return None
        #     last_depth = self.history.timeline[i].depth
        #     i -= 1

        # Log the context and step to dev data
        context_used = await self.context_manager.get_selected_items()
        posthog_logger.capture_event(
            "step run", {"step_name": step.name, "params": step.dict()}
        )
        step_id = uuid.uuid4().hex
        dev_data_logger.capture(
            "step_run",
            {"step_name": step.name, "params": step.dict(), "step_id": step_id},
        )
        dev_data_logger.capture(
            "context_used",
            {
                "context": [item.dict() for item in context_used],
                "step_id": step_id,
            },
        )

        if not is_future_step:
            # Check manual edits buffer, clear out if needed by creating a ManualEditStep
            if len(self._manual_edits_buffer) > 0:
                manual_edits_step = ManualEditStep.from_sequence(
                    self._manual_edits_buffer
                )
                self._manual_edits_buffer = []
                await self._run_singular_step(manual_edits_step)

        # Update history - do this first so we get top-first tree ordering
        index_of_history_node = self.history.add_node(
            HistoryNode(
                step=step,
                observation=None,
                depth=self._step_depth,
                context_used=context_used,
            )
        )

        # Call all subscribed callbacks
        await self.update_subscribers()

        # Try to run step and handle errors
        self._step_depth += 1

        caught_error = False
        try:
            observation = await step(self.continue_sdk)
        except Exception as e:
            if (
                index_of_history_node >= len(self.history.timeline)
                or self.history.timeline[index_of_history_node].deleted
            ):
                # If step was deleted/cancelled, don't show error or allow retry.
                # Undo the depth increment above; otherwise every cancelled
                # step would permanently shift the depth of later steps.
                self._step_depth -= 1
                return None

            caught_error = True

            is_continue_custom_exception = (
                issubclass(e.__class__, ContinueCustomException)
                or e.__class__.__name__ == ContinueCustomException.__name__
            )

            # The three-argument form works on all Python 3 versions (the
            # single-argument form needs 3.10+). The returned lines already
            # end with newlines, so they are joined with "".
            error_string = (
                e.message
                if is_continue_custom_exception
                else "".join(traceback.format_exception(type(e), e, e.__traceback__))
            )
            error_title = (
                e.title if is_continue_custom_exception else get_error_title(e)
            )

            # Attach an InternalErrorObservation to the step and unhide it.
            logger.error(f"Error while running step: \n{error_string}\n{error_title}")
            posthog_logger.capture_event(
                "step error",
                {
                    "error_message": error_string,
                    "error_title": error_title,
                    "step_name": step.name,
                    "params": step.dict(),
                },
            )

            observation = InternalErrorObservation(
                error=error_string, title=error_title
            )

            # Reveal this step, but hide all of the following steps (its substeps)
            step_was_hidden = step.hide

            step.hide = False
            i = self.history.get_current_index()
            while self.history.timeline[i].step.name != step.name:
                self.history.timeline[i].step.hide = True
                i -= 1

            # i is now the index of the step that we want to show/rerun
            self.history.timeline[i].observation = observation
            self.history.timeline[i].active = False

            await self.update_subscribers()

            # ContinueCustomException can optionally specify a step to run on the error
            if is_continue_custom_exception and e.with_step is not None:
                await self._run_singular_step(e.with_step)

            # Wait for a retry signal and then resume the step
            self._active = False
            await self._retry_queue.get(str(i))
            self._active = True
            # You might consider a "ignore and continue" button
            # want it to have same step depth, so have to decrement
            self._step_depth -= 1
            copy_step = step.copy()
            copy_step.hide = step_was_hidden
            observation = await self._run_singular_step(copy_step)
            self._step_depth += 1

        self._step_depth -= 1

        # Add observation to history, unless already attached error observation
        if not caught_error and index_of_history_node < len(self.history.timeline):
            self.history.timeline[index_of_history_node].observation = observation
            self.history.timeline[index_of_history_node].active = False
            await self.update_subscribers()

        # Update its description
        async def update_description():
            if self.continue_sdk.config.disable_summaries:
                return

            description = await step.describe(self.continue_sdk.models)
            if description is not None:
                step.description = description
            # Update subscribers with new description
            await self.update_subscribers()

        create_async_task(
            update_description(),
            on_error=lambda e: self.continue_sdk.run_step(
                DisplayErrorStep.from_exception(e)
            ),
        )

        # Create the session title if not done yet
        if self.session_info is None or self.session_info.title is None:
            visible_nodes = [
                node for node in self.history.timeline if not node.step.hide
            ]

            user_input = None
            should_create_title = False
            for visible_node in visible_nodes:
                if isinstance(visible_node.step, UserInputStep):
                    if user_input is None:
                        user_input = visible_node.step.user_input
                    else:
                        # More than one user input, so don't create title
                        should_create_title = False
                        break
                elif user_input is None:
                    continue
                else:
                    # Already have user input, now have the next step
                    should_create_title = True
                    break

            # Only create the title if the step after the first input is done
            if should_create_title:
                create_async_task(
                    self.create_title(backup=user_input),
                    on_error=lambda e: self.continue_sdk.run_step(
                        DisplayErrorStep.from_exception(e)
                    ),
                )

        return observation

    async def run_from_step(self, step: "Step"):
        """Run ``step``, then keep running until no steps remain or a halt is requested.

        After each step, the policy chooses the next one. If the policy
        returns nothing, the next step is replayed from history (a "future"
        step). The loop also stops when ``_should_halt`` is set. Subscribers
        are notified once the loop ends.
        """
        self._active = True

        next_step = step
        is_future_step = False
        while not (next_step is None or self._should_halt):
            if is_future_step:
                # If future step, then we are replaying and need to delete the step from history so it can be replaced
                self.history.remove_current_and_substeps()

            await self._run_singular_step(next_step, is_future_step)

            if next_step := self.policy.next(self.continue_sdk.config, self.history):
                is_future_step = False
            elif next_step := self.history.take_next_step():
                is_future_step = True
            else:
                next_step = None

        self._active = False

        # Doing this so active can make it to the frontend after steps are done. But want better state syncing tools
        await self.update_subscribers()

    async def run_from_observation(self, observation: Observation):
        """Run from whatever step the policy picks next (``observation`` is unused)."""
        next_step = self.policy.next(self.continue_sdk.config, self.history)
        await self.run_from_step(next_step)

    async def run_policy(self):
        """Run from the first step the policy picks for the current history."""
        first_step = self.policy.next(self.continue_sdk.config, self.history)
        await self.run_from_step(first_step)

    async def _request_halt(self):
        """Ask the running loop to stop and wait until Autopilot is no longer active.

        The wait polls with ``asyncio.sleep``, so the event loop keeps running.
        A blocking ``time.sleep`` here would stop the running coroutine from
        ever clearing ``_active``, and the loop would never end.
        """
        if self._active:
            self._should_halt = True
            while self._active:
                await asyncio.sleep(0.1)

        self._should_halt = False
        return None

    def set_current_session_title(self, title: str):
        """Replace ``session_info`` with a new one titled ``title``."""
        self.session_info = SessionInfo(
            title=title,
            session_id=self.ide.session_id,
            date_created=str(time.time()),
            workspace_directory=self.ide.workspace_directory,
        )

    async def create_title(self, backup: Optional[str] = None):
        """Give the session a title, unless it already has one.

        When summaries are disabled, the title is ``backup`` (or "New Session").
        Otherwise the summarize model generates a short title from the chat
        context.
        """
        # Use the first input and first response to create title for session info, and make the session saveable
        if self.session_info is not None and self.session_info.title is not None:
            return

        if self.continue_sdk.config.disable_summaries:
            if backup is not None:
                title = backup
            else:
                title = "New Session"
        else:
            chat_history = [
                msg.dict() for msg in await self.continue_sdk.get_chat_context()
            ]
            chat_history_str = template_alpaca_messages(chat_history)
            title = await self.continue_sdk.models.summarize.complete(
                f"{chat_history_str}\n\nGive a short title to describe the above chat session. Do not put quotes around the title. Do not use more than 6 words. The title is: ",
                max_tokens=20,
                log=False,
            )
            title = remove_quotes_and_escapes(title)

        self.set_current_session_title(title)
        await self.update_subscribers()
        dev_data_logger.capture("new_session", self.session_info.dict())

    async def accept_user_input(self, user_input: str):
        """Queue ``user_input``; if nothing else was queued, run it and drain the queue.

        Inputs that arrive while an earlier one is being processed stay in
        ``_main_user_input_queue``. The first caller then runs them in order.
        """
        self._main_user_input_queue.append(user_input)

        if len(self._main_user_input_queue) > 1:
            return

        # Just run the step that takes user input, and
        # then up to the policy to decide how to deal with it.
        self._main_user_input_queue.pop(0)
        await self.run_from_step(UserInputStep(user_input=user_input))

        while len(self._main_user_input_queue) > 0:
            await self.run_from_step(
                UserInputStep(user_input=self._main_user_input_queue.pop(0))
            )

    async def accept_refinement_input(self, user_input: str, index: int):
        """Halt, reverse history to ``index``, then run ``user_input`` as a new step."""
        await self._request_halt()
        await self.reverse_to_index(index)
        await self.run_from_step(UserInputStep(user_input=user_input))

    async def reject_diff(self, step_index: int):
        """Hide the step at ``step_index`` and the nearest UserInputStep before it."""
        # Hide the edit step and the UserInputStep before it
        self.history.timeline[step_index].step.hide = True
        for i in range(step_index - 1, -1, -1):
            if isinstance(self.history.timeline[i].step, UserInputStep):
                self.history.timeline[i].step.hide = True
                break
        await self.update_subscribers()

    async def select_context_item(self, id: str, query: str):
        """Select a context item through the context manager and notify subscribers."""
        await self.context_manager.select_context_item(id, query)
        await self.update_subscribers()

    async def select_context_item_at_index(self, id: str, query: str, index: int):
        """Append the context item for ``id``/``query`` to history node ``index``.

        Does nothing if the context manager returns no item.
        """
        # TODO: This is different from how it works for the main input
        # Ideally still tracked through the ContextProviders
        # so they can watch for duplicates
        context_item = await self.context_manager.get_context_item(id, query)
        if context_item is None:
            return
        self.history.timeline[index].context_used.append(context_item)
        await self.update_subscribers()

    async def set_config_attr(self, key_path: List[str], value: redbaron.RedBaron):
        """Write ``value`` at ``key_path`` in the config, then notify subscribers."""
        edit_config_property(key_path, value)
        await self.update_subscribers()

    # Saved context groups by title; loaded in start() and persisted to disk
    # by _persist_context_groups.
    _saved_context_groups: Dict[str, List[ContextItem]] = {}

    def _persist_context_groups(self):
        """Write the saved context groups as JSON to the saved-groups file.

        Nothing is written unless the file already exists.
        """
        context_groups_file = getSavedContextGroupsPath()
        if os.path.exists(context_groups_file):
            with open(context_groups_file, "w") as f:
                dict_to_save = {
                    title: [item.dict() for item in context_items]
                    for title, context_items in self._saved_context_groups.items()
                }
                json.dump(dict_to_save, f)

    async def save_context_group(self, title: str, context_items: List[ContextItem]):
        """Save ``context_items`` under ``title`` and persist all groups."""
        self._saved_context_groups[title] = context_items
        await self.update_subscribers()

        # Update saved context groups
        self._persist_context_groups()

        posthog_logger.capture_event(
            "save_context_group", {"title": title, "length": len(context_items)}
        )

    async def select_context_group(self, id: str):
        """Replace the current context with the saved group ``id``.

        If the group does not exist, a warning is logged and nothing changes.
        """
        if id not in self._saved_context_groups:
            logger.warning(f"Context group {id} not found")
            return
        context_group = self._saved_context_groups[id]
        await self.context_manager.clear_context()
        for item in context_group:
            await self.context_manager.manually_add_context_item(item)
        await self.update_subscribers()

        posthog_logger.capture_event(
            "select_context_group", {"title": id, "length": len(context_group)}
        )
        dev_data_logger.capture(
            "select_context_group", {"title": id, "items": context_group}
        )

    async def delete_context_group(self, id: str):
        """Delete the saved group ``id`` and persist the remaining groups.

        If the group does not exist, a warning is logged and nothing changes.
        """
        if id not in self._saved_context_groups:
            logger.warning(f"Context group {id} not found")
            return
        del self._saved_context_groups[id]
        await self.update_subscribers()

        # Update saved context groups
        self._persist_context_groups()

        posthog_logger.capture_event("delete_context_group", {"title": id})