import json import os import time import traceback import uuid from functools import cached_property from typing import Callable, Coroutine, Dict, List, Optional import redbaron from aiohttp import ClientPayloadError from openai import error as openai_errors from pydantic import root_validator from ..libs.llm.prompts.chat import template_alpaca_messages from ..libs.util.create_async_task import create_async_task from ..libs.util.devdata import dev_data_logger from ..libs.util.edit_config import edit_config_property from ..libs.util.logging import logger from ..libs.util.paths import getSavedContextGroupsPath from ..libs.util.queue import AsyncSubscriptionQueue from ..libs.util.strings import remove_quotes_and_escapes from ..libs.util.telemetry import posthog_logger from ..libs.util.traceback.traceback_parsers import ( get_javascript_traceback, get_python_traceback, ) from ..models.filesystem import RangeInFileWithContents from ..models.filesystem_edit import FileEditWithFullContents from ..models.main import ContinueBaseModel from ..plugins.context_providers.file import FileContextProvider from ..plugins.context_providers.highlighted_code import HighlightedCodeContextProvider from ..plugins.policies.default import DefaultPolicy from ..plugins.steps.on_traceback import DefaultOnTracebackStep from ..server.ide_protocol import AbstractIdeProtocolServer from ..server.meilisearch_server import get_meilisearch_url, stop_meilisearch from .config import ContinueConfig from .context import ContextManager from .main import ( Context, ContextItem, ContinueCustomException, FullState, History, HistoryNode, Policy, SessionInfo, Step, ) from .observation import InternalErrorObservation, Observation from .sdk import ContinueSDK from .steps import DisplayErrorStep, ManualEditStep, ReversibleStep, UserInputStep def get_error_title(e: Exception) -> str: if isinstance(e, openai_errors.APIError): return "OpenAI is overloaded with requests. Please try again." elif isinstance(e, openai_errors.RateLimitError): return "This OpenAI API key has been rate limited. Please try again." elif isinstance(e, openai_errors.Timeout): return "OpenAI timed out. Please try again." elif ( isinstance(e, openai_errors.InvalidRequestError) and e.code == "context_length_exceeded" ): return e._message elif isinstance(e, ClientPayloadError): return "The request failed. Please try again." elif isinstance(e, openai_errors.APIConnectionError): return 'The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to ""' elif isinstance(e, openai_errors.InvalidRequestError): return "Invalid request sent to OpenAI. Please try again." elif "rate_limit_ip_middleware" in e.__str__(): return "You have reached your limit for free usage of our token. You can continue using Continue by entering your own OpenAI API key in VS Code settings." elif e.__str__().startswith("Cannot connect to host"): return ( "The request failed. Please check your internet connection and try again." ) return e.__str__() or e.__repr__() class Autopilot(ContinueBaseModel): ide: AbstractIdeProtocolServer policy: Policy = DefaultPolicy() history: History = History.from_empty() context: Context = Context() full_state: Optional[FullState] = None session_info: Optional[SessionInfo] = None context_manager: ContextManager = ContextManager() continue_sdk: ContinueSDK = None _on_update_callbacks: List[Callable[[FullState], None]] = [] _active: bool = False _should_halt: bool = False _main_user_input_queue: List[str] = [] _user_input_queue = AsyncSubscriptionQueue() _retry_queue = AsyncSubscriptionQueue() started: bool = False async def load( self, config: Optional[ContinueConfig] = None, only_reloading: bool = False ): self.continue_sdk = await ContinueSDK.create(self, config=config) if override_policy := self.continue_sdk.config.policy_override: self.policy = override_policy # Load documents into the search index logger.debug("Starting context manager") await self.context_manager.start( self.continue_sdk.config.context_providers + [ HighlightedCodeContextProvider(ide=self.ide), FileContextProvider(workspace_dir=self.ide.workspace_directory), ], self.continue_sdk, only_reloading=only_reloading, ) async def start( self, full_state: Optional[FullState] = None, config: Optional[ContinueConfig] = None, ): await self.load(config=config, only_reloading=False) if full_state is not None: self.history = full_state.history self.session_info = full_state.session_info # Load saved context groups context_groups_file = getSavedContextGroupsPath() try: with open(context_groups_file, "r") as f: json_ob = json.load(f) for title, context_group in json_ob.items(): self._saved_context_groups[title] = [ ContextItem(**item) for item in context_group ] except Exception as e: logger.warning( f"Failed to load saved_context_groups.json: {e}. Reverting to empty list." ) self._saved_context_groups = {} self.started = True async def reload_config(self): await self.load(config=None, only_reloading=True) await self.update_subscribers() async def cleanup(self): stop_meilisearch() class Config: arbitrary_types_allowed = True keep_untouched = (cached_property,) @root_validator(pre=True) def fill_in_values(cls, values): full_state: FullState = values.get("full_state") if full_state is not None: values["history"] = full_state.history return values async def get_full_state(self) -> FullState: full_state = FullState( history=self.history, active=self._active, user_input_queue=self._main_user_input_queue, slash_commands=self.get_available_slash_commands(), adding_highlighted_code=self.context_manager.context_providers[ "code" ].adding_highlighted_code if "code" in self.context_manager.context_providers else False, selected_context_items=await self.context_manager.get_selected_items() if self.context_manager is not None else [], session_info=self.session_info, config=self.continue_sdk.config, saved_context_groups=self._saved_context_groups, context_providers=self.context_manager.get_provider_descriptions(), meilisearch_url=get_meilisearch_url(), ) self.full_state = full_state return full_state def get_available_slash_commands(self) -> List[Dict]: custom_commands = ( list( map( lambda x: {"name": x.name, "description": x.description}, self.continue_sdk.config.custom_commands, ) ) or [] ) slash_commands = ( list( map( lambda x: {"name": x.name, "description": x.description}, self.continue_sdk.config.slash_commands, ) ) or [] ) cmds = custom_commands + slash_commands cmds.sort(key=lambda x: x["name"] == "edit", reverse=True) return cmds async def clear_history(self): # Reset history self.history = History.from_empty() self._main_user_input_queue = [] self._active = False # Clear context # await self.context_manager.clear_context() await self.update_subscribers() def on_update(self, callback: Coroutine["FullState", None, None]): """Subscribe to changes to state""" self._on_update_callbacks.append(callback) async def update_subscribers(self): full_state = await self.get_full_state() for callback in self._on_update_callbacks: await callback(full_state) def give_user_input(self, input: str, index: int): self._user_input_queue.post(str(index), input) async def wait_for_user_input(self) -> str: self._active = False await self.update_subscribers() user_input = await self._user_input_queue.get(str(self.history.current_index)) self._active = True await self.update_subscribers() return user_input _manual_edits_buffer: List[FileEditWithFullContents] = [] async def reverse_to_index(self, index: int): try: while self.history.get_current_index() >= index: current_step = self.history.get_current().step self.history.step_back() if issubclass(current_step.__class__, ReversibleStep): await current_step.reverse(self.continue_sdk) await self.update_subscribers() except Exception as e: logger.debug(e) def handle_manual_edits(self, edits: List[FileEditWithFullContents]): for edit in edits: self._manual_edits_buffer.append(edit) # TODO: You're storing a lot of unnecessary data here. Can compress into EditDiffs on the spot, and merge. # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit) # Note that this is being overridden to do nothing in DemoAgent async def handle_command_output(self, output: str): get_traceback_funcs = [get_python_traceback, get_javascript_traceback] for get_tb_func in get_traceback_funcs: traceback = get_tb_func(output) if ( traceback is not None and self.continue_sdk.config.on_traceback is not None ): step = self.continue_sdk.config.on_traceback(output=output) await self._run_singular_step(step) async def handle_debug_terminal(self, content: str): """Run the debug terminal step""" # step = self.continue_sdk.config.on_traceback(output=content) step = DefaultOnTracebackStep(output=content) await self._run_singular_step(step) async def handle_highlighted_code( self, range_in_files: List[RangeInFileWithContents], edit: Optional[bool] = False, ): if "code" not in self.context_manager.context_providers: return # Add to context manager await self.context_manager.context_providers["code"].handle_highlighted_code( range_in_files, edit ) await self.update_subscribers() _step_depth: int = 0 async def retry_at_index(self, index: int): self.history.timeline[index].step.hide = True self._retry_queue.post(str(index), None) async def delete_at_index(self, index: int): if not self.history.timeline[index].active: self.history.timeline[index].step.hide = True self.history.timeline[index].deleted = True self.history.timeline[index].active = False await self.update_subscribers() async def edit_step_at_index(self, user_input: str, index: int): node_to_rerun = self.history.timeline[index].copy() step_to_rerun = node_to_rerun.step step_to_rerun.user_input = user_input step_to_rerun.description = user_input # Halt the agent's currently running jobs (delete them) while len(self.history.timeline) > index: # Remove from timeline node_to_delete = self.history.timeline.pop() # Delete so it is stopped if in the middle of running node_to_delete.deleted = True self.history.current_index = index - 1 # Set the context to the context used by that step await self.context_manager.clear_context() for context_item in node_to_rerun.context_used: await self.context_manager.manually_add_context_item(context_item) await self.update_subscribers() # Rerun from the current step await self.run_from_step(step_to_rerun) async def delete_context_with_ids( self, ids: List[str], index: Optional[int] = None ): if index is None: await self.context_manager.delete_context_with_ids(ids) else: self.history.timeline[index].context_used = list( filter( lambda item: item.description.id.to_string() not in ids, self.history.timeline[index].context_used, ) ) await self.update_subscribers() async def toggle_adding_highlighted_code(self): if "code" not in self.context_manager.context_providers: return self.context_manager.context_providers[ "code" ].adding_highlighted_code = not self.context_manager.context_providers[ "code" ].adding_highlighted_code await self.update_subscribers() async def set_editing_at_ids(self, ids: List[str]): if "code" not in self.context_manager.context_providers: return await self.context_manager.context_providers["code"].set_editing_at_ids(ids) await self.update_subscribers() async def _run_singular_step( self, step: "Step", is_future_step: bool = False ) -> Coroutine[Observation, None, None]: # Allow config to set disallowed steps if step.__class__.__name__ in self.continue_sdk.config.disallowed_steps: return None # If a parent step is deleted/cancelled, don't run this step # TODO: This was problematic because when running a step after deleting one, it seemed to think that was the parent # last_depth = self._step_depth # i = self.history.current_index # while i >= 0 and self.history.timeline[i].depth == last_depth - 1: # if self.history.timeline[i].deleted: # return None # last_depth = self.history.timeline[i].depth # i -= 1 # Log the context and step to dev data context_used = await self.context_manager.get_selected_items() posthog_logger.capture_event( "step run", {"step_name": step.name, "params": step.dict()} ) step_id = uuid.uuid4().hex dev_data_logger.capture( "step_run", {"step_name": step.name, "params": step.dict(), "step_id": step_id}, ) dev_data_logger.capture( "context_used", { "context": list( map( lambda item: item.dict(), context_used, ) ), "step_id": step_id, }, ) if not is_future_step: # Check manual edits buffer, clear out if needed by creating a ManualEditStep if len(self._manual_edits_buffer) > 0: manualEditsStep = ManualEditStep.from_sequence( self._manual_edits_buffer ) self._manual_edits_buffer = [] await self._run_singular_step(manualEditsStep) # Update history - do this first so we get top-first tree ordering index_of_history_node = self.history.add_node( HistoryNode( step=step, observation=None, depth=self._step_depth, context_used=context_used, ) ) # Call all subscribed callbacks await self.update_subscribers() # Try to run step and handle errors self._step_depth += 1 caught_error = False try: observation = await step(self.continue_sdk) except Exception as e: if ( index_of_history_node >= len(self.history.timeline) or self.history.timeline[index_of_history_node].deleted ): # If step was deleted/cancelled, don't show error or allow retry return None caught_error = True is_continue_custom_exception = ( issubclass(e.__class__, ContinueCustomException) or e.__class__.__name__ == ContinueCustomException.__name__ ) error_string = ( e.message if is_continue_custom_exception else "\n".join(traceback.format_exception(e)) ) error_title = ( e.title if is_continue_custom_exception else get_error_title(e) ) # Attach an InternalErrorObservation to the step and unhide it. logger.error(f"Error while running step: \n{error_string}\n{error_title}") posthog_logger.capture_event( "step error", { "error_message": error_string, "error_title": error_title, "step_name": step.name, "params": step.dict(), }, ) observation = InternalErrorObservation( error=error_string, title=error_title ) # Reveal this step, but hide all of the following steps (its substeps) step_was_hidden = step.hide step.hide = False i = self.history.get_current_index() while self.history.timeline[i].step.name != step.name: self.history.timeline[i].step.hide = True i -= 1 # i is now the index of the step that we want to show/rerun self.history.timeline[i].observation = observation self.history.timeline[i].active = False await self.update_subscribers() # ContinueCustomException can optionally specify a step to run on the error if is_continue_custom_exception and e.with_step is not None: await self._run_singular_step(e.with_step) # Wait for a retry signal and then resume the step self._active = False await self._retry_queue.get(str(i)) self._active = True # You might consider a "ignore and continue" button # want it to have same step depth, so have to decrement self._step_depth -= 1 copy_step = step.copy() copy_step.hide = step_was_hidden observation = await self._run_singular_step(copy_step) self._step_depth += 1 self._step_depth -= 1 # Add observation to history, unless already attached error observation if not caught_error and index_of_history_node < len(self.history.timeline): self.history.timeline[index_of_history_node].observation = observation self.history.timeline[index_of_history_node].active = False await self.update_subscribers() # Update its description async def update_description(): if self.continue_sdk.config.disable_summaries: return description = await step.describe(self.continue_sdk.models) if description is not None: step.description = description # Update subscribers with new description await self.update_subscribers() create_async_task( update_description(), on_error=lambda e: self.continue_sdk.run_step( DisplayErrorStep.from_exception(e) ), ) # Create the session title if not done yet if self.session_info is None or self.session_info.title is None: visible_nodes = list( filter(lambda node: not node.step.hide, self.history.timeline) ) user_input = None should_create_title = False for visible_node in visible_nodes: if isinstance(visible_node.step, UserInputStep): if user_input is None: user_input = visible_node.step.user_input else: # More than one user input, so don't create title should_create_title = False break elif user_input is None: continue else: # Already have user input, now have the next step should_create_title = True break # Only create the title if the step after the first input is done if should_create_title: create_async_task( self.create_title(backup=user_input), on_error=lambda e: self.continue_sdk.run_step( DisplayErrorStep.from_exception(e) ), ) return observation async def run_from_step(self, step: "Step"): # if self._active: # raise RuntimeError("Autopilot is already running") self._active = True next_step = step is_future_step = False while not (next_step is None or self._should_halt): if is_future_step: # If future step, then we are replaying and need to delete the step from history so it can be replaced self.history.remove_current_and_substeps() await self._run_singular_step(next_step, is_future_step) if next_step := self.policy.next(self.continue_sdk.config, self.history): is_future_step = False elif next_step := self.history.take_next_step(): is_future_step = True else: next_step = None self._active = False # Doing this so active can make it to the frontend after steps are done. But want better state syncing tools await self.update_subscribers() async def run_from_observation(self, observation: Observation): next_step = self.policy.next(self.continue_sdk.config, self.history) await self.run_from_step(next_step) async def run_policy(self): first_step = self.policy.next(self.continue_sdk.config, self.history) await self.run_from_step(first_step) async def _request_halt(self): if self._active: self._should_halt = True while self._active: time.sleep(0.1) self._should_halt = False return None def set_current_session_title(self, title: str): self.session_info = SessionInfo( title=title, session_id=self.ide.session_id, date_created=str(time.time()), workspace_directory=self.ide.workspace_directory, ) async def create_title(self, backup: str = None): # Use the first input and first response to create title for session info, and make the session saveable if self.session_info is not None and self.session_info.title is not None: return if self.continue_sdk.config.disable_summaries: if backup is not None: title = backup else: title = "New Session" else: chat_history = list( map(lambda x: x.dict(), await self.continue_sdk.get_chat_context()) ) chat_history_str = template_alpaca_messages(chat_history) title = await self.continue_sdk.models.summarize.complete( f"{chat_history_str}\n\nGive a short title to describe the above chat session. Do not put quotes around the title. Do not use more than 6 words. The title is: ", max_tokens=20, log=False, ) title = remove_quotes_and_escapes(title) self.set_current_session_title(title) await self.update_subscribers() dev_data_logger.capture("new_session", self.session_info.dict()) async def accept_user_input(self, user_input: str): self._main_user_input_queue.append(user_input) # await self.update_subscribers() if len(self._main_user_input_queue) > 1: return # await self._request_halt() # Just run the step that takes user input, and # then up to the policy to decide how to deal with it. self._main_user_input_queue.pop(0) # await self.update_subscribers() await self.run_from_step(UserInputStep(user_input=user_input)) while len(self._main_user_input_queue) > 0: await self.run_from_step( UserInputStep(user_input=self._main_user_input_queue.pop(0)) ) async def accept_refinement_input(self, user_input: str, index: int): await self._request_halt() await self.reverse_to_index(index) await self.run_from_step(UserInputStep(user_input=user_input)) async def reject_diff(self, step_index: int): # Hide the edit step and the UserInputStep before it self.history.timeline[step_index].step.hide = True for i in range(step_index - 1, -1, -1): if isinstance(self.history.timeline[i].step, UserInputStep): self.history.timeline[i].step.hide = True break await self.update_subscribers() async def select_context_item(self, id: str, query: str): await self.context_manager.select_context_item(id, query) await self.update_subscribers() async def select_context_item_at_index(self, id: str, query: str, index: int): # TODO: This is different from how it works for the main input # Ideally still tracked through the ContextProviders # so they can watch for duplicates context_item = await self.context_manager.get_context_item(id, query) if context_item is None: return self.history.timeline[index].context_used.append(context_item) await self.update_subscribers() async def set_config_attr(self, key_path: List[str], value: redbaron.RedBaron): edit_config_property(key_path, value) await self.update_subscribers() _saved_context_groups: Dict[str, List[ContextItem]] = {} def _persist_context_groups(self): context_groups_file = getSavedContextGroupsPath() if os.path.exists(context_groups_file): with open(context_groups_file, "w") as f: dict_to_save = { title: [item.dict() for item in context_items] for title, context_items in self._saved_context_groups.items() } json.dump(dict_to_save, f) async def save_context_group(self, title: str, context_items: List[ContextItem]): self._saved_context_groups[title] = context_items await self.update_subscribers() # Update saved context groups self._persist_context_groups() posthog_logger.capture_event( "save_context_group", {"title": title, "length": len(context_items)} ) async def select_context_group(self, id: str): if id not in self._saved_context_groups: logger.warning(f"Context group {id} not found") return context_group = self._saved_context_groups[id] await self.context_manager.clear_context() for item in context_group: await self.context_manager.manually_add_context_item(item) await self.update_subscribers() posthog_logger.capture_event( "select_context_group", {"title": id, "length": len(context_group)} ) dev_data_logger.capture( "select_context_group", {"title": id, "items": context_group} ) async def delete_context_group(self, id: str): if id not in self._saved_context_groups: logger.warning(f"Context group {id} not found") return del self._saved_context_groups[id] await self.update_subscribers() # Update saved context groups self._persist_context_groups() posthog_logger.capture_event("delete_context_group", {"title": id})