summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-06-03 14:20:21 -0400
committerNate Sesti <sestinj@gmail.com>2023-06-03 14:20:21 -0400
commitb8067876bc5dd425d491863bd5338f325fea35ed (patch)
tree4c3167043c2bffaff5931b77e080a8f84ad52c96
parent4f3ceee573268fbe9db80fea372198523b5757a6 (diff)
downloadsncontinue-b8067876bc5dd425d491863bd5338f325fea35ed.tar.gz
sncontinue-b8067876bc5dd425d491863bd5338f325fea35ed.tar.bz2
sncontinue-b8067876bc5dd425d491863bd5338f325fea35ed.zip
error handling and step retry
-rw-r--r--continuedev/src/continuedev/core/agent.py180
-rw-r--r--continuedev/src/continuedev/core/autopilot.py52
-rw-r--r--continuedev/src/continuedev/core/main.py9
-rw-r--r--continuedev/src/continuedev/core/observation.py4
-rw-r--r--continuedev/src/continuedev/libs/llm/hf_inference_api.py9
-rw-r--r--continuedev/src/continuedev/server/gui.py6
-rw-r--r--continuedev/src/continuedev/server/gui_protocol.py4
-rw-r--r--extension/react-app/src/components/StepContainer.tsx43
-rw-r--r--extension/react-app/src/hooks/useContinueGUIProtocol.ts4
-rw-r--r--extension/react-app/src/tabs/gui.tsx4
10 files changed, 106 insertions, 209 deletions
diff --git a/continuedev/src/continuedev/core/agent.py b/continuedev/src/continuedev/core/agent.py
deleted file mode 100644
index 1996abb1..00000000
--- a/continuedev/src/continuedev/core/agent.py
+++ /dev/null
@@ -1,180 +0,0 @@
-import traceback
-import time
-from typing import Callable, Coroutine, List
-from ..models.filesystem_edit import FileEditWithFullContents
-from ..libs.llm import LLM
-from .observation import Observation
-from ..server.ide_protocol import AbstractIdeProtocolServer
-from ..libs.util.queue import AsyncSubscriptionQueue
-from ..models.main import ContinueBaseModel
-from .main import Policy, History, FullState, Step, HistoryNode
-from ..libs.steps.core.core import ReversibleStep, ManualEditStep, UserInputStep
-from .sdk import ContinueSDK
-
-
-class Autopilot(ContinueBaseModel):
- llm: LLM
- policy: Policy
- ide: AbstractIdeProtocolServer
- history: History = History.from_empty()
- _on_update_callbacks: List[Callable[[FullState], None]] = []
-
- _active: bool = False
- _should_halt: bool = False
- _main_user_input_queue: List[str] = []
-
- _user_input_queue = AsyncSubscriptionQueue()
-
- class Config:
- arbitrary_types_allowed = True
-
- def get_full_state(self) -> FullState:
- return FullState(history=self.history, active=self._active, user_input_queue=self._main_user_input_queue)
-
- def on_update(self, callback: Callable[["FullState"], None]):
- """Subscribe to changes to state"""
- self._on_update_callbacks.append(callback)
-
- def update_subscribers(self):
- full_state = self.get_full_state()
- for callback in self._on_update_callbacks:
- callback(full_state)
-
- def __get_step_params(self, step: "Step"):
- return ContinueSDK(autopilot=self, llm=self.llm.with_system_message(step.system_message))
-
- def give_user_input(self, input: str, index: int):
- self._user_input_queue.post(index, input)
-
- async def wait_for_user_input(self) -> str:
- self._active = False
- self.update_subscribers()
- user_input = await self._user_input_queue.get(self.history.current_index)
- self._active = True
- self.update_subscribers()
- return user_input
-
- _manual_edits_buffer: List[FileEditWithFullContents] = []
-
- async def reverse_to_index(self, index: int):
- try:
- while self.history.get_current_index() >= index:
- current_step = self.history.get_current().step
- self.history.step_back()
- if issubclass(current_step.__class__, ReversibleStep):
- await current_step.reverse(self.__get_step_params(current_step))
-
- self.update_subscribers()
- except Exception as e:
- print(e)
-
- def handle_manual_edits(self, edits: List[FileEditWithFullContents]):
- for edit in edits:
- self._manual_edits_buffer.append(edit)
- # TODO: You're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge.
- # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit)
-
- def handle_traceback(self, traceback: str):
- raise NotImplementedError
-
- _step_depth: int = 0
-
- async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]:
- if not is_future_step:
- # Check manual edits buffer, clear out if needed by creating a ManualEditStep
- if len(self._manual_edits_buffer) > 0:
- manualEditsStep = ManualEditStep.from_sequence(
- self._manual_edits_buffer)
- self._manual_edits_buffer = []
- await self._run_singular_step(manualEditsStep)
-
- # Update history - do this first so we get top-first tree ordering
- self.history.add_node(HistoryNode(
- step=step, observation=None, depth=self._step_depth))
-
- # Run step
- self._step_depth += 1
- observation = await step(self.__get_step_params(step))
- self._step_depth -= 1
-
- # Add observation to history
- self.history.get_current().observation = observation
-
- # Update its description
- step._set_description(await step.describe(self.llm))
-
- # Call all subscribed callbacks
- self.update_subscribers()
-
- return observation
-
- async def run_from_step(self, step: "Step"):
- # if self._active:
- # raise RuntimeError("Autopilot is already running")
- self._active = True
-
- next_step = step
- is_future_step = False
- while not (next_step is None or self._should_halt):
- try:
- if is_future_step:
- # If future step, then we are replaying and need to delete the step from history so it can be replaced
- self.history.remove_current_and_substeps()
-
- observation = await self._run_singular_step(next_step, is_future_step)
- if next_step := self.policy.next(self.history):
- is_future_step = False
- elif next_step := self.history.take_next_step():
- is_future_step = True
- else:
- next_step = None
-
- except Exception as e:
- print(
- f"Error while running step: \n{''.join(traceback.format_tb(e.__traceback__))}\n{e}")
- next_step = None
-
- self._active = False
-
- # Doing this so active can make it to the frontend after steps are done. But want better state syncing tools
- for callback in self._on_update_callbacks:
- callback(None)
-
- async def run_from_observation(self, observation: Observation):
- next_step = self.policy.next(self.history)
- await self.run_from_step(next_step)
-
- async def run_policy(self):
- first_step = self.policy.next(self.history)
- await self.run_from_step(first_step)
-
- async def _request_halt(self):
- if self._active:
- self._should_halt = True
- while self._active:
- time.sleep(0.1)
- self._should_halt = False
- return None
-
- async def accept_user_input(self, user_input: str):
- self._main_user_input_queue.append(user_input)
- self.update_subscribers()
-
- if len(self._main_user_input_queue) > 1:
- return
-
- # await self._request_halt()
- # Just run the step that takes user input, and
- # then up to the policy to decide how to deal with it.
- self._main_user_input_queue.pop(0)
- self.update_subscribers()
- await self.run_from_step(UserInputStep(user_input=user_input))
-
- while len(self._main_user_input_queue) > 0:
- await self.run_from_step(UserInputStep(
- user_input=self._main_user_input_queue.pop(0)))
-
- async def accept_refinement_input(self, user_input: str, index: int):
- await self._request_halt()
- await self.reverse_to_index(index)
- await self.run_from_step(UserInputStep(user_input=user_input))
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 85f65dc3..db06c975 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -3,7 +3,7 @@ import time
from typing import Callable, Coroutine, List
from ..models.filesystem_edit import FileEditWithFullContents
from ..libs.llm import LLM
-from .observation import Observation
+from .observation import Observation, InternalErrorObservation
from ..server.ide_protocol import AbstractIdeProtocolServer
from ..libs.util.queue import AsyncSubscriptionQueue
from ..models.main import ContinueBaseModel
@@ -77,6 +77,11 @@ class Autopilot(ContinueBaseModel):
_step_depth: int = 0
+ async def retry_at_index(self, index: int):
+ last_step = self.history.pop_last_step()
+ await self.update_subscribers()
+ await self._run_singular_step(last_step)
+
async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]:
capture_event(
'step run', {'step_name': step.name, 'params': step.dict()})
@@ -96,14 +101,28 @@ class Autopilot(ContinueBaseModel):
# Call all subscribed callbacks
await self.update_subscribers()
- # Run step
+ # Try to run step and handle errors
self._step_depth += 1
- observation = await step(ContinueSDK(self))
+
+ try:
+ observation = await step(ContinueSDK(self))
+ except Exception as e:
+ # Attach an InternalErrorObservation to the step and unhide it.
+ error_string = '\n\n'.join(
+ traceback.format_tb(e.__traceback__)) + f"\n\n{e.__repr__()}"
+ print(
+ f"Error while running step: \n{error_string}\n{e}")
+
+ observation = InternalErrorObservation(
+ error=error_string)
+ step.hide = False
+
self._step_depth -= 1
# Add observation to history
self.history.get_last_at_depth(
self._step_depth, include_current=True).observation = observation
+ await self.update_subscribers()
# Update its description
async def update_description():
@@ -122,22 +141,17 @@ class Autopilot(ContinueBaseModel):
next_step = step
is_future_step = False
while not (next_step is None or self._should_halt):
- try:
- if is_future_step:
- # If future step, then we are replaying and need to delete the step from history so it can be replaced
- self.history.remove_current_and_substeps()
-
- observation = await self._run_singular_step(next_step, is_future_step)
- if next_step := self.policy.next(self.history):
- is_future_step = False
- elif next_step := self.history.take_next_step():
- is_future_step = True
- else:
- next_step = None
-
- except Exception as e:
- print(
- f"Error while running step: \n{''.join(traceback.format_tb(e.__traceback__))}\n{e}")
+ if is_future_step:
+ # If future step, then we are replaying and need to delete the step from history so it can be replaced
+ self.history.remove_current_and_substeps()
+
+ await self._run_singular_step(next_step, is_future_step)
+
+ if next_step := self.policy.next(self.history):
+ is_future_step = False
+ elif next_step := self.history.take_next_step():
+ is_future_step = True
+ else:
next_step = None
self._active = False
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index a2336671..b2b97bae 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -67,6 +67,13 @@ class History(ContinueBaseModel):
return None
return state.observation
+ def pop_last_step(self) -> Union[HistoryNode, None]:
+ if self.current_index < 0:
+ return None
+ node = self.timeline.pop(self.current_index)
+ self.current_index -= 1
+ return node.step
+
@classmethod
def from_empty(cls):
return cls(timeline=[], current_index=-1)
@@ -118,7 +125,7 @@ class Step(ContinueBaseModel):
if self._description is not None:
d["description"] = self._description
else:
- d["description"] = self.name
+ d["description"] = "`Description loading...`"
return d
@validator("name", pre=True, always=True)
diff --git a/continuedev/src/continuedev/core/observation.py b/continuedev/src/continuedev/core/observation.py
index fef04311..b6117236 100644
--- a/continuedev/src/continuedev/core/observation.py
+++ b/continuedev/src/continuedev/core/observation.py
@@ -33,3 +33,7 @@ class TextObservation(Observation):
if v is None:
return ""
return v
+
+
+class InternalErrorObservation(Observation):
+ error: str
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 83852d27..734da160 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -22,4 +22,11 @@ class HuggingFaceInferenceAPI(LLM):
"return_full_text": False,
}
})
- return response.json()[0]["generated_text"]
+ data = response.json()
+
+ # Error if the response is not a list
+ if not isinstance(data, list):
+ raise Exception(
+ "Hugging Face returned an error response: \n\n", data)
+
+ return data[0]["generated_text"]
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 3d1a5a82..b873a88f 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -75,6 +75,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.on_refinement_input(data["input"], data["index"])
elif message_type == "reverse_to_index":
self.on_reverse_to_index(data["index"])
+ elif message_type == "retry_at_index":
+ self.on_retry_at_index(data["index"])
except Exception as e:
print(e)
@@ -100,6 +102,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
asyncio.create_task(
self.session.autopilot.accept_refinement_input(input, index))
+ def on_retry_at_index(self, index: int):
+ asyncio.create_task(
+ self.session.autopilot.retry_at_index(index))
+
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py
index e32d80ef..287f9e3b 100644
--- a/continuedev/src/continuedev/server/gui_protocol.py
+++ b/continuedev/src/continuedev/server/gui_protocol.py
@@ -26,3 +26,7 @@ class AbstractGUIProtocolServer(ABC):
@abstractmethod
async def send_state_update(self, state: dict):
"""Send a state update to the client"""
+
+ @abstractmethod
+ def on_retry_at_index(self, index: int):
+ """Called when the user requests a retry at a previous index"""
diff --git a/extension/react-app/src/components/StepContainer.tsx b/extension/react-app/src/components/StepContainer.tsx
index f962cbc9..903f9b94 100644
--- a/extension/react-app/src/components/StepContainer.tsx
+++ b/extension/react-app/src/components/StepContainer.tsx
@@ -18,6 +18,7 @@ import {
ChevronDown,
ChevronRight,
Backward,
+ ArrowPath,
} from "@styled-icons/heroicons-outline";
import { HistoryNode } from "../../../schema/HistoryNode";
import ReactMarkdown from "react-markdown";
@@ -29,6 +30,7 @@ interface StepContainerProps {
inFuture: boolean;
onRefinement: (input: string) => void;
onUserInput: (input: string) => void;
+ onRetry: () => void;
}
const MainDiv = styled.div<{ stepDepth: number; inFuture: boolean }>`
@@ -135,19 +137,44 @@ function StepContainer(props: StepContainerProps) {
>
<Backward size="1.6em" onClick={props.onReverse}></Backward>
</HeaderButton> */}
+
+ {props.historyNode.observation?.error ? (
+ <HeaderButton
+ onClick={(e) => {
+ e.stopPropagation();
+ props.onRetry();
+ }}
+ >
+ <ArrowPath size="1.6em" onClick={props.onRetry}></ArrowPath>
+ </HeaderButton>
+ ) : (
+ <></>
+ )}
</HeaderDiv>
{open && (
- <pre>
- Step Details:
- <br />
- {JSON.stringify(props.historyNode.step, null, 2)}
- </pre>
+ <>
+ <pre className="overflow-scroll">
+ Step Details:
+ <br />
+ {JSON.stringify(props.historyNode.step, null, 2)}
+ </pre>
+ </>
)}
- <ReactMarkdown key={1} className="overflow-scroll">
- {props.historyNode.step.description as any}
- </ReactMarkdown>
+ {props.historyNode.observation?.error ? (
+ <>
+ Error while running step:
+ <br />
+ <pre className="overflow-scroll">
+ {props.historyNode.observation.error as string}
+ </pre>
+ </>
+ ) : (
+ <ReactMarkdown key={1} className="overflow-scroll">
+ {props.historyNode.step.description as any}
+ </ReactMarkdown>
+ )}
{props.historyNode.step.name === "Waiting for user input" && (
<input
diff --git a/extension/react-app/src/hooks/useContinueGUIProtocol.ts b/extension/react-app/src/hooks/useContinueGUIProtocol.ts
index a3a1d0c9..f27895fb 100644
--- a/extension/react-app/src/hooks/useContinueGUIProtocol.ts
+++ b/extension/react-app/src/hooks/useContinueGUIProtocol.ts
@@ -44,6 +44,10 @@ class ContinueGUIClientProtocol extends AbstractContinueGUIClientProtocol {
}
});
}
+
+ retryAtIndex(index: number) {
+ this.messenger.send("retry_at_index", { index });
+ }
}
export default ContinueGUIClientProtocol;
diff --git a/extension/react-app/src/tabs/gui.tsx b/extension/react-app/src/tabs/gui.tsx
index 42ad4ed5..a08698a4 100644
--- a/extension/react-app/src/tabs/gui.tsx
+++ b/extension/react-app/src/tabs/gui.tsx
@@ -231,6 +231,10 @@ function GUI(props: GUIProps) {
onReverse={() => {
client?.reverseToIndex(index);
}}
+ onRetry={() => {
+ client?.retryAtIndex(index);
+ setWaitingForSteps(true);
+ }}
/>
);
})}