From 2070fee3a47cf1e8d566ad3452ee64f0fc03dd5c Mon Sep 17 00:00:00 2001
From: Nate Sesti <sestinj@gmail.com>
Date: Sun, 11 Jun 2023 12:26:42 -0400
Subject: Step to decide next step form user input

---
 continuedev/src/continuedev/core/policy.py         | 31 +++++++++++-------
 .../src/continuedev/server/state_manager.py        | 21 ++++++++++++
 continuedev/src/continuedev/steps/chat.py          | 13 ++++++++
 continuedev/src/continuedev/steps/react.py         | 37 ++++++++++++++++++++++
 4 files changed, 91 insertions(+), 11 deletions(-)
 create mode 100644 continuedev/src/continuedev/server/state_manager.py
 create mode 100644 continuedev/src/continuedev/steps/chat.py
 create mode 100644 continuedev/src/continuedev/steps/react.py

(limited to 'continuedev/src')

diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
index 8aea8de7..0b0c0fcd 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/core/policy.py
@@ -9,6 +9,8 @@ from ..steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeSte
 from ..recipes.WritePytestsRecipe.main import WritePytestsRecipe
 from ..recipes.ContinueRecipeRecipe.main import ContinueStepStep
 from ..steps.comment_code import CommentCodeStep
+from ..steps.react import NLDecisionStep
+from ..steps.chat import SimpleChatStep
 
 
 class DemoPolicy(Policy):
@@ -26,19 +28,26 @@ class DemoPolicy(Policy):
         observation = history.get_current().observation
         if observation is not None and isinstance(observation, UserInputObservation):
             # This could be defined with ObservationTypePolicy. Ergonomics not right though.
-            if "/pytest" in observation.user_input.lower():
-                return WritePytestsRecipe(instructions=observation.user_input)
-            elif "/dlt" in observation.user_input.lower() or " dlt" in observation.user_input.lower():
+            user_input = observation.user_input
+            if "/pytest" in user_input.lower():
+                return WritePytestsRecipe(instructions=user_input)
+            elif "/dlt" in user_input.lower() or " dlt" in user_input.lower():
                 return CreatePipelineRecipe()
-            elif "/comment" in observation.user_input.lower():
+            elif "/comment" in user_input.lower():
                 return CommentCodeStep()
-            elif "/ask" in observation.user_input:
-                return AnswerQuestionChroma(question=" ".join(observation.user_input.split(" ")[1:]))
-            elif "/edit" in observation.user_input:
-                return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:]))
-            elif "/step" in observation.user_input:
-                return ContinueStepStep(prompt=" ".join(observation.user_input.split(" ")[1:]))
-            return EditHighlightedCodeStep(user_input=observation.user_input)
+            elif "/ask" in user_input:
+                return AnswerQuestionChroma(question=" ".join(user_input.split(" ")[1:]))
+            elif "/edit" in user_input:
+                return EditFileChroma(request=" ".join(user_input.split(" ")[1:]))
+            elif "/step" in user_input:
+                return ContinueStepStep(prompt=" ".join(user_input.split(" ")[1:]))
+            # return EditHighlightedCodeStep(user_input=user_input)
+            return NLDecisionStep(user_input=user_input, steps=[
+                EditHighlightedCodeStep(user_input=user_input),
+                AnswerQuestionChroma(question=user_input),
+                EditFileChroma(request=user_input),
+                SimpleChatStep(user_input=user_input)
+            ])
 
         state = history.get_current()
 
diff --git a/continuedev/src/continuedev/server/state_manager.py b/continuedev/src/continuedev/server/state_manager.py
new file mode 100644
index 00000000..c9bd760b
--- /dev/null
+++ b/continuedev/src/continuedev/server/state_manager.py
@@ -0,0 +1,21 @@
+from typing import Any, List, Tuple, Union
+from fastapi import WebSocket
+from pydantic import BaseModel
+from ..core.main import FullState
+
+# State updates represented as (path, replacement) pairs
+StateUpdate = Tuple[List[Union[str, int]], Any]
+
+
+class StateManager:
+    """
+    A class that acts as the source of truth for state, ingesting changes to the entire object and streaming only the updated portions to client.
+    """
+
+    def __init__(self, ws: WebSocket):
+        self.ws = ws
+
+    def _send_update(self, updates: List[StateUpdate]):
+        self.ws.send_json(
+            [update.dict() for update in updates]
+        )
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
new file mode 100644
index 00000000..817e10dd
--- /dev/null
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -0,0 +1,13 @@
+from textwrap import dedent
+from typing import List
+from ..core.main import Step
+from ..core.sdk import ContinueSDK
+from .main import MessageStep
+
+
+class SimpleChatStep(Step):
+    user_input: str
+
+    async def run(self, sdk: ContinueSDK):
+        # TODO: With history
+        self.description = await sdk.models.gpt35.complete(self.user_input)
diff --git a/continuedev/src/continuedev/steps/react.py b/continuedev/src/continuedev/steps/react.py
new file mode 100644
index 00000000..411adc87
--- /dev/null
+++ b/continuedev/src/continuedev/steps/react.py
@@ -0,0 +1,37 @@
+from textwrap import dedent
+from typing import List
+from ..core.main import Step
+from ..core.sdk import ContinueSDK
+from .main import MessageStep
+
+
+class NLDecisionStep(Step):
+    user_input: str
+    steps: List[Step]
+
+    async def run(self, sdk: ContinueSDK):
+        step_descriptions = "\n".join([
+            f"- {step.name}: {step.description}"
+            for step in self.steps
+        ])
+        prompt = dedent(f"""\
+                        The following steps are available, in the format "- [step name]: [step description]":
+                        {step_descriptions}
+                        
+                        The user gave the following input:
+                        
+                        {self.user_input}
+                        
+                        Select the step which should be taken next. Say only the name of the selected step:""")
+
+        resp = (await sdk.models.gpt35.complete(prompt)).lower()
+
+        step_to_run = None
+        for step in self.steps:
+            if step.name in resp:
+                step_to_run = step
+
+        step_to_run = step_to_run or MessageStep(
+            message="Unable to decide the next step")
+
+        await sdk.run_step(step_to_run)
-- 
cgit v1.2.3-70-g09d2