from typing import List, Tuple, Type

from .steps.ty import CreatePipelineStep
from .core import Step, Validator, Policy, History
from .observation import Observation, TracebackObservation, UserInputObservation
from .steps.main import (
    EditCodeStep,
    EditHighlightedCodeStep,
    SolveTracebackStep,
    RunCodeStep,
    FasterEditHighlightedCodeStep,
)
from .steps.nate import WritePytestsStep, CreateTableStep
from .steps.chroma import AnswerQuestionChroma, EditFileChroma


class DemoPolicy(Policy):
    """Policy that routes user input to a step, otherwise runs `cmd` and reacts to tracebacks."""

    # True once RunCodeStep has been issued and no traceback fix has been
    # requested since.
    ran_code_last: bool = False
    # Command handed to RunCodeStep.
    cmd: str

    def next(self, history: History) -> Step:
        """Pick the next step from the most recent observation in `history`.

        Returns None when there is nothing further to do (code was already
        run and the last observation is not a traceback).
        """
        observation = history.last_observation()

        # isinstance() is False for None, so no separate None check is needed.
        if isinstance(observation, UserInputObservation):
            # This could be defined with ObservationTypePolicy. Ergonomics not right though.
            user_input = observation.user_input
            lowered = user_input.lower()
            # Everything after the first space, i.e. the input with its first
            # space-delimited token removed.
            argument = user_input.partition(" ")[2]

            if " test" in lowered:
                return WritePytestsStep(instructions=user_input)
            if "/dlt" in lowered or " dlt" in lowered:
                return CreatePipelineStep()
            # The slash commands below are matched case-sensitively.
            if "/table" in user_input:
                return CreateTableStep(sql_str=argument)
            if "/ask" in user_input:
                return AnswerQuestionChroma(question=argument)
            if "/edit" in user_input:
                return EditFileChroma(request=argument)
            return EditHighlightedCodeStep(user_input=user_input)

        state = history.get_current()
        if state is None or not self.ran_code_last:
            self.ran_code_last = True
            return RunCodeStep(cmd=self.cmd)

        if isinstance(observation, TracebackObservation):
            # Clear the flag so the code is run again after the fix attempt.
            self.ran_code_last = False
            return SolveTracebackStep(traceback=observation.traceback)

        return None


class ObservationTypePolicy(Policy):
    """Wrap a policy so that one observation type always triggers one step type."""

    def __init__(self, base_policy: Policy, observation_type: Type[Observation], step_type: Type[Step]):
        """Store the wrapped policy and the observation-type -> step-type mapping."""
        self.observation_type = observation_type
        self.step_type = step_type
        self.base_policy = base_policy

    def next(self, history: History) -> Step:
        """Return `step_type(observation)` on a matching observation, else defer to the base policy."""
        observation = history.last_observation()
        if observation is not None and isinstance(observation, self.observation_type):
            return self.step_type(observation)
        return self.base_policy.next(history)


class PolicyWrappedWithValidators(Policy):
    """Default is to stop, unless the validator tells what to do next.

    After each step produced by the base policy, the validators in `pairs`
    are returned one at a time, in order. When a validator's observation
    reports failure, the paired step type is instantiated with that
    observation and the same validator is then run again. Once every
    validator has passed, control goes back to the base policy.
    """

    # Position in `pairs` of the validator being processed; equal to
    # len(pairs) when the next step should come from the base policy.
    index: int
    # 0: the validator at `index` still has to be issued.
    # 1: the validator at `index` was issued and its observation is pending.
    stage: int

    def __init__(self, base_policy: Policy, pairs: List[Tuple[Validator, Type[Step]]]):
        """Store the base policy and the (validator, corrective step type) pairs."""
        # Want to pass Type[Validator], or just the Validator? Question of where params are coming from.
        self.pairs = pairs
        # Start past the end so the very first call defers to the base policy.
        self.index = len(pairs)
        # Not read by this class; kept in case external code inspects it.
        self.validating = 0
        # The bare `stage: int` class annotation does not create an attribute,
        # so it has to be initialised here or next() fails on first read.
        self.stage = 0
        self.base_policy = base_policy

    def next(self, history: History) -> Step:
        """Return the next base-policy step, validator, or corrective step."""
        if self.index == len(self.pairs):
            self.index = 0
            return self.base_policy.next(history)

        if self.stage == 1:
            # Previously ran the validator at the current index, now receiving its ValidatorObservation
            observation = history.last_observation()
            # Either way, the validator at `index` needs (re-)issuing next.
            self.stage = 0

            if not observation.passed:
                # Hand the failure to the corrective step; with stage reset,
                # the following call re-runs the same validator.
                _, step_type = self.pairs[self.index]
                return step_type(observation)

            self.index += 1
            if self.index == len(self.pairs):
                self.index = 0
                return self.base_policy.next(history)
            # Otherwise fall through and issue the next validator.

        # Issue the validator at the current index and wait for its result.
        validator, _ = self.pairs[self.index]
        self.stage = 1
        return validator