diff options
| -rw-r--r-- | continuedev/src/continuedev/core/policy.py | 31 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/state_manager.py | 21 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/chat.py | 13 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/react.py | 37 | 
4 files changed, 91 insertions, 11 deletions
| diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index 8aea8de7..0b0c0fcd 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -9,6 +9,8 @@ from ..steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeSte  from ..recipes.WritePytestsRecipe.main import WritePytestsRecipe  from ..recipes.ContinueRecipeRecipe.main import ContinueStepStep  from ..steps.comment_code import CommentCodeStep +from ..steps.react import NLDecisionStep +from ..steps.chat import SimpleChatStep  class DemoPolicy(Policy): @@ -26,19 +28,26 @@ class DemoPolicy(Policy):          observation = history.get_current().observation          if observation is not None and isinstance(observation, UserInputObservation):              # This could be defined with ObservationTypePolicy. Ergonomics not right though. -            if "/pytest" in observation.user_input.lower(): -                return WritePytestsRecipe(instructions=observation.user_input) -            elif "/dlt" in observation.user_input.lower() or " dlt" in observation.user_input.lower(): +            user_input = observation.user_input +            if "/pytest" in user_input.lower(): +                return WritePytestsRecipe(instructions=user_input) +            elif "/dlt" in user_input.lower() or " dlt" in user_input.lower():                  return CreatePipelineRecipe() -            elif "/comment" in observation.user_input.lower(): +            elif "/comment" in user_input.lower():                  return CommentCodeStep() -            elif "/ask" in observation.user_input: -                return AnswerQuestionChroma(question=" ".join(observation.user_input.split(" ")[1:])) -            elif "/edit" in observation.user_input: -                return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:])) -            elif "/step" in observation.user_input: -                return ContinueStepStep(prompt=" ".join(observation.user_input.split(" ")[1:])) -            return EditHighlightedCodeStep(user_input=observation.user_input) +            elif "/ask" in user_input: +                return AnswerQuestionChroma(question=" ".join(user_input.split(" ")[1:])) +            elif "/edit" in user_input: +                return EditFileChroma(request=" ".join(user_input.split(" ")[1:])) +            elif "/step" in user_input: +                return ContinueStepStep(prompt=" ".join(user_input.split(" ")[1:])) +            # return EditHighlightedCodeStep(user_input=user_input) +            return NLDecisionStep(user_input=user_input, steps=[ +                EditHighlightedCodeStep(user_input=user_input), +                AnswerQuestionChroma(question=user_input), +                EditFileChroma(request=user_input), +                SimpleChatStep(user_input=user_input) +            ])          state = history.get_current() diff --git a/continuedev/src/continuedev/server/state_manager.py b/continuedev/src/continuedev/server/state_manager.py new file mode 100644 index 00000000..c9bd760b --- /dev/null +++ b/continuedev/src/continuedev/server/state_manager.py @@ -0,0 +1,21 @@ +from typing import Any, List, Tuple, Union +from fastapi import WebSocket +from pydantic import BaseModel +from ..core.main import FullState + +# State updates represented as (path, replacement) pairs +StateUpdate = Tuple[List[Union[str, int]], Any] + + +class StateManager: +    """ +    A class that acts as the source of truth for state, ingesting changes to the entire object and streaming only the updated portions to client. +    """ + +    def __init__(self, ws: WebSocket): +        self.ws = ws + +    def _send_update(self, updates: List[StateUpdate]): +        self.ws.send_json( +            [update.dict() for update in updates] +        ) diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py new file mode 100644 index 00000000..817e10dd --- /dev/null +++ b/continuedev/src/continuedev/steps/chat.py @@ -0,0 +1,13 @@ +from textwrap import dedent +from typing import List +from ..core.main import Step +from ..core.sdk import ContinueSDK +from .main import MessageStep + + +class SimpleChatStep(Step): +    user_input: str + +    async def run(self, sdk: ContinueSDK): +        # TODO: With history +        self.description = await sdk.models.gpt35.complete(self.user_input) diff --git a/continuedev/src/continuedev/steps/react.py b/continuedev/src/continuedev/steps/react.py new file mode 100644 index 00000000..411adc87 --- /dev/null +++ b/continuedev/src/continuedev/steps/react.py @@ -0,0 +1,37 @@ +from textwrap import dedent +from typing import List +from ..core.main import Step +from ..core.sdk import ContinueSDK +from .main import MessageStep + + +class NLDecisionStep(Step): +    user_input: str +    steps: List[Step] + +    async def run(self, sdk: ContinueSDK): +        step_descriptions = "\n".join([ +            f"- {step.name}: {step.description}" +            for step in self.steps +        ]) +        prompt = dedent(f"""\ +                        The following steps are available, in the format "- [step name]: [step description]": +                        {step_descriptions} +                         +                        The user gave the following input: +                         +                        {self.user_input} +                         +                        Select the step which should be taken next. Say only the name of the selected step:""") + +        resp = (await sdk.models.gpt35.complete(prompt)).lower() + +        step_to_run = None +        for step in self.steps: +            if step.name in resp: +                step_to_run = step + +        step_to_run = step_to_run or MessageStep( +            message="Unable to decide the next step") + +        await sdk.run_step(step_to_run) | 
