From 373c985296e9f4e408386c167b4206808d986046 Mon Sep 17 00:00:00 2001
From: Nate Sesti <sestinj@gmail.com>
Date: Fri, 9 Jun 2023 14:35:21 -0400
Subject: touching up transform recipe, chat context

---
 continuedev/src/continuedev/core/abstract_sdk.py   | 12 ++++++--
 continuedev/src/continuedev/core/autopilot.py      |  1 +
 continuedev/src/continuedev/core/main.py           | 24 ++++++++++++++-
 continuedev/src/continuedev/core/policy.py         |  3 +-
 continuedev/src/continuedev/core/sdk.py            | 10 +++++-
 continuedev/src/continuedev/libs/llm/__init__.py   |  6 ++--
 .../src/continuedev/libs/llm/hf_inference_api.py   |  4 ++-
 continuedev/src/continuedev/libs/llm/openai.py     | 15 +++++----
 .../continuedev/recipes/AddTransformRecipe/main.py |  2 +-
 .../recipes/AddTransformRecipe/steps.py            | 19 +++++++++---
 .../recipes/CreatePipelineRecipe/main.py           |  2 +-
 .../recipes/CreatePipelineRecipe/steps.py          |  4 +--
 continuedev/src/continuedev/server/ide.py          |  2 +-
 continuedev/src/continuedev/steps/core/core.py     | 36 ++++++++++++++++++++--
 continuedev/src/continuedev/steps/main.py          | 13 +-------
 15 files changed, 116 insertions(+), 37 deletions(-)

(limited to 'continuedev')

diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py
index 1c800875..417971cd 100644
--- a/continuedev/src/continuedev/core/abstract_sdk.py
+++ b/continuedev/src/continuedev/core/abstract_sdk.py
@@ -1,10 +1,10 @@
-from abc import ABC, abstractmethod
+from abc import ABC, abstractmethod, abstractproperty
 from typing import Coroutine, List, Union
 
 from .config import ContinueConfig
 from ..models.filesystem_edit import FileSystemEdit
 from .observation import Observation
-from .main import History, Step
+from .main import ChatMessage, History, Step, ChatMessageRole
 
 
 """
@@ -83,3 +83,11 @@ class AbstractContinueSDK(ABC):
     @abstractmethod
     def set_loading_message(self, message: str):
         pass
+
+    @abstractmethod
+    def add_chat_context(self, content: str, role: ChatMessageRole = "assistent"):
+        pass
+
+    @abstractproperty
+    def chat_context(self) -> List[ChatMessage]:
+        pass
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index b82e1fef..c979d53a 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -35,6 +35,7 @@ class Autopilot(ContinueBaseModel):
 
     class Config:
         arbitrary_types_allowed = True
+        keep_untouched = (cached_property,)
 
     def get_full_state(self) -> FullState:
         return FullState(history=self.history, active=self._active, user_input_queue=self._main_user_input_queue)
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index 37d80de3..19b36a6a 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -1,10 +1,18 @@
-from typing import Callable, Coroutine, Dict, Generator, List, Tuple, Union
+from textwrap import dedent
+from typing import Callable, Coroutine, Dict, Generator, List, Literal, Tuple, Union
 
 from ..models.main import ContinueBaseModel
 from pydantic import validator
 from ..libs.llm import LLM
 from .observation import Observation
 
+ChatMessageRole = Literal["assistant", "user", "system"]
+
+
+class ChatMessage(ContinueBaseModel):
+    role: ChatMessageRole
+    content: str
+
 
 class HistoryNode(ContinueBaseModel):
     """A point in history, a list of which make up History"""
@@ -12,12 +20,25 @@ class HistoryNode(ContinueBaseModel):
     observation: Union[Observation, None]
     depth: int
 
+    def to_chat_messages(self) -> List[ChatMessage]:
+        return self.step.chat_context + [ChatMessage(role="assistant", content=self.step.description)]
+
 
 class History(ContinueBaseModel):
     """A history of steps taken and their results"""
     timeline: List[HistoryNode]
     current_index: int
 
+    def to_chat_history(self) -> List[ChatMessage]:
+        msgs = []
+        for node in self.timeline:
+            if not node.step.hide:
+                msgs += [
+                    ChatMessage(role="assistant", content=msg)
+                    for msg in node.to_chat_messages()
+                ]
+        return msgs
+
     def add_node(self, node: HistoryNode):
         self.timeline.insert(self.current_index + 1, node)
         self.current_index += 1
@@ -113,6 +134,7 @@ class Step(ContinueBaseModel):
     description: Union[str, None] = None
 
     system_message: Union[str, None] = None
+    chat_context: List[ChatMessage] = []
 
     class Config:
         copy_on_model_validation = False
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
index c3f1d188..7661f0c4 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/core/policy.py
@@ -6,10 +6,11 @@ from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe
 from ..recipes.AddTransformRecipe.main import AddTransformRecipe
 from .main import Step, Validator, History, Policy
 from .observation import Observation, TracebackObservation, UserInputObservation
-from ..steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep, StarCoderEditHighlightedCodeStep, MessageStep, EmptyStep, SetupContinueWorkspaceStep
+from ..steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep, StarCoderEditHighlightedCodeStep, EmptyStep, SetupContinueWorkspaceStep
 from ..recipes.WritePytestsRecipe.main import WritePytestsRecipe
 from ..recipes.ContinueRecipeRecipe.main import ContinueStepStep
 from ..steps.comment_code import CommentCodeStep
+from ..steps.core.core import MessageStep
 
 
 class DemoPolicy(Policy):
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index ea90a13a..11127361 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -14,7 +14,7 @@ from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
 from ..libs.llm.openai import OpenAI
 from .observation import Observation
 from ..server.ide_protocol import AbstractIdeProtocolServer
-from .main import Context, ContinueCustomException, History, Step
+from .main import Context, ContinueCustomException, History, Step, ChatMessage, ChatMessageRole
 from ..steps.core.core import *
 
 
@@ -136,3 +136,11 @@ class ContinueSDK(AbstractContinueSDK):
 
     def raise_exception(self, message: str, title: str, with_step: Union[Step, None] = None):
         raise ContinueCustomException(message, title, with_step)
+
+    def add_chat_context(self, content: str, role: ChatMessageRole = "assistent"):
+        self.history.timeline[self.history.current_index].step.chat_context.append(
+            ChatMessage(content=content, role=role))
+
+    @property
+    def chat_context(self) -> List[ChatMessage]:
+        return self.history.to_chat_history()
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 6bae2222..24fd34be 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,4 +1,6 @@
-from typing import Union
+from typing import List, Union
+
+from ...core.main import ChatMessage
 from ...models.main import AbstractModel
 from pydantic import BaseModel
 
@@ -6,7 +8,7 @@ from pydantic import BaseModel
 class LLM(BaseModel):
     system_message: Union[str, None] = None
 
-    def complete(self, prompt: str, **kwargs):
+    def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
         """Return the completion of the text with the given temperature."""
         raise
 
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 734da160..1586c620 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,3 +1,5 @@
+from typing import List
+from ...core.main import ChatMessage
 from ..llm import LLM
 import requests
 
@@ -9,7 +11,7 @@ class HuggingFaceInferenceAPI(LLM):
     api_key: str
     model: str = "bigcode/starcoder"
 
-    def complete(self, prompt: str, **kwargs):
+    def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
         """Return the completion of the text with the given temperature."""
         API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
         headers = {
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 10801465..da8c5caf 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,6 +1,7 @@
 import asyncio
 import time
 from typing import Any, Dict, Generator, List, Union
+from ...core.main import ChatMessage
 import openai
 import aiohttp
 from ..llm import LLM
@@ -62,7 +63,7 @@ class OpenAI(LLM):
             for chunk in generator:
                 yield chunk.choices[0].text
 
-    def complete(self, prompt: str, **kwargs) -> str:
+    def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> str:
         t1 = time.time()
 
         self.completion_count += 1
@@ -70,15 +71,17 @@ class OpenAI(LLM):
                 "frequency_penalty": 0, "presence_penalty": 0, "stream": False} | kwargs
 
         if args["model"] == "gpt-3.5-turbo":
-            messages = [{
-                "role": "user",
-                "content": prompt
-            }]
+            messages = []
             if self.system_message:
-                messages.insert(0, {
+                messages.append({
                     "role": "system",
                     "content": self.system_message
                 })
+            message += [msg.dict() for msg in with_history]
+            messages.append({
+                "role": "user",
+                "content": prompt
+            })
             resp = openai.ChatCompletion.create(
                 messages=messages,
                 **args,
diff --git a/continuedev/src/continuedev/recipes/AddTransformRecipe/main.py b/continuedev/src/continuedev/recipes/AddTransformRecipe/main.py
index 5e05b587..e9a998e3 100644
--- a/continuedev/src/continuedev/recipes/AddTransformRecipe/main.py
+++ b/continuedev/src/continuedev/recipes/AddTransformRecipe/main.py
@@ -3,7 +3,7 @@ from textwrap import dedent
 from ...core.main import Step
 from ...core.sdk import ContinueSDK
 from ...steps.core.core import WaitForUserInputStep
-from ...steps.main import MessageStep
+from ...steps.core.core import MessageStep
 from .steps import SetUpChessPipelineStep, AddTransformStep
 
 
diff --git a/continuedev/src/continuedev/recipes/AddTransformRecipe/steps.py b/continuedev/src/continuedev/recipes/AddTransformRecipe/steps.py
index f7f5a43b..7bb0fc23 100644
--- a/continuedev/src/continuedev/recipes/AddTransformRecipe/steps.py
+++ b/continuedev/src/continuedev/recipes/AddTransformRecipe/steps.py
@@ -1,7 +1,9 @@
 import os
 from textwrap import dedent
 
-from ...steps.main import MessageStep
+from ...models.main import Range
+from ...models.filesystem import RangeInFile
+from ...steps.core.core import MessageStep
 from ...core.sdk import Models
 from ...core.observation import DictObservation
 from ...models.filesystem_edit import AddFile
@@ -26,7 +28,8 @@ class SetUpChessPipelineStep(Step):
             'source env/bin/activate',
             'pip install dlt',
             'dlt --non-interactive init chess duckdb',
-            'pip install -r requirements.txt'
+            'pip install -r requirements.txt',
+            'pip install pandas streamlit'  # Needed for the pipeline show step later
         ], name="Set up Python environment", description=dedent(f"""\
             Running the following commands:
             - `python3 -m venv env`: Create a Python virtual environment
@@ -44,7 +47,8 @@ class AddTransformStep(Step):
 
     async def run(self, sdk: ContinueSDK):
         source_name = 'chess'
-        filename = f'{source_name}.py'
+        filename = f'{source_name}_pipeline.py'
+        abs_filepath = os.path.join(sdk.ide.workspace_directory, filename)
 
         await sdk.run_step(MessageStep(message=dedent("""\
                 This step will customize your resource function with a transform of your choice:
@@ -52,6 +56,13 @@ class AddTransformStep(Step):
                 - Load the data into a local DuckDB instance
                 - Open up a Streamlit app for you to view the data"""), name="Write transformation function"))
 
+        # Open the file and highlight the function to be edited
+        await sdk.ide.setFileOpen(abs_filepath)
+        await sdk.ide.highlightCode(range_in_file=RangeInFile(
+            filepath=abs_filepath,
+            range=Range.from_shorthand(47, 0, 51, 0)
+        ))
+
         with open(os.path.join(os.path.dirname(__file__), 'dlt_transform_docs.md')) as f:
             dlt_transform_docs = f.read()
 
@@ -75,4 +86,4 @@ class AddTransformStep(Step):
         await sdk.run(f'python3 {filename}', name="Run the pipeline", description=f"Running `python3 {filename}` to load the data into a local DuckDB instance")
 
         # run a streamlit app to show the data
-        await sdk.run(f'dlt pipeline {source_name} show', name="Show data in a Streamlit app", description=f"Running `dlt pipeline {source_name} show` to show the data in a Streamlit app, where you can view and play with the data.")
+        await sdk.run(f'dlt pipeline {source_name}_pipeline show', name="Show data in a Streamlit app", description=f"Running `dlt pipeline {source_name} show` to show the data in a Streamlit app, where you can view and play with the data.")
diff --git a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py
index 39e1ba42..818168ba 100644
--- a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py
+++ b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py
@@ -3,7 +3,7 @@ from textwrap import dedent
 from ...core.main import Step
 from ...core.sdk import ContinueSDK
 from ...steps.core.core import WaitForUserInputStep
-from ...steps.main import MessageStep
+from ...steps.core.core import MessageStep
 from .steps import SetupPipelineStep, ValidatePipelineStep, RunQueryStep
 
 
diff --git a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py
index 3b9a8c85..ea40a058 100644
--- a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py
+++ b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py
@@ -5,7 +5,7 @@ import time
 
 from ...models.main import Range
 from ...models.filesystem import RangeInFile
-from ...steps.main import MessageStep
+from ...steps.core.core import MessageStep
 from ...core.sdk import Models
 from ...core.observation import DictObservation, InternalErrorObservation
 from ...models.filesystem_edit import AddFile, FileEdit
@@ -51,7 +51,7 @@ class SetupPipelineStep(Step):
 
         # editing the resource function to call the requested API
         resource_function_range = Range.from_shorthand(15, 0, 29, 0)
-        await sdk.ide.highlightCode(RangeInFile(filepath=os.path.join(await sdk.ide.getWorkspaceDirectory(), filename), range=resource_function_range), "#00ff0022")
+        await sdk.ide.highlightCode(RangeInFile(filepath=os.path.join(await sdk.ide.getWorkspaceDirectory(), filename), range=resource_function_range))
 
         # sdk.set_loading_message("Writing code to call the API...")
         await sdk.edit_file(
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index 5826f15f..f4ea1071 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -138,7 +138,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
             "sessionId": session_id
         })
 
-    async def highlightCode(self, range_in_file: RangeInFile, color: str):
+    async def highlightCode(self, range_in_file: RangeInFile, color: str = "#00ff0022"):
         await self._send_json("highlightCode", {
             "rangeInFile": range_in_file.dict(),
             "color": color
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index dfd765eb..5117d479 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -1,4 +1,5 @@
 # These steps are depended upon by ContinueSDK
+import os
 import subprocess
 from textwrap import dedent
 from typing import Coroutine, List, Union
@@ -23,6 +24,17 @@ class ReversibleStep(Step):
         raise NotImplementedError
 
 
+class MessageStep(Step):
+    name: str = "Message"
+    message: str
+
+    async def describe(self, models: Models) -> Coroutine[str, None, None]:
+        return self.message
+
+    async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
+        return TextObservation(text=self.message)
+
+
 class FileSystemEditStep(ReversibleStep):
     edit: FileSystemEdit
     _diff: Union[EditDiff, None] = None
@@ -38,6 +50,13 @@ class FileSystemEditStep(ReversibleStep):
         # Where and when should file saves happen?
 
 
+def output_contains_error(output: str) -> bool:
+    return "Traceback" in output or "SyntaxError" in output
+
+
+AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)"
+
+
 class ShellCommandsStep(Step):
     cmds: List[str]
     cwd: Union[str, None] = None
@@ -50,13 +69,26 @@ class ShellCommandsStep(Step):
             return f"Error when running shell commands:\n```\n{self._err_text}\n```"
 
         cmds_str = "\n".join(self.cmds)
-        return (await models.gpt35()).complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:")
+        return models.gpt35.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:")
 
     async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
         cwd = await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd
 
         for cmd in self.cmds:
             output = await sdk.ide.runCommand(cmd)
+            if output is not None and output_contains_error(output):
+                suggestion = sdk.models.gpt35.complete(dedent(f"""\
+                    While running the command `{cmd}`, the following error occurred:
+
+                    ```ascii
+                    {output}
+                    ```
+
+                    This is a brief summary of the error followed by a suggestion on how it can be fixed:"""), with_context=sdk.chat_context)
+
+                sdk.raise_exception(
+                    title="Error while running query", message=output, with_step=MessageStep(name=f"Suggestion to solve error {AI_ASSISTED_STRING}", message=suggestion)
+                )
 
         return TextObservation(text=output)
 
@@ -116,7 +148,7 @@ class Gpt35EditCodeStep(Step):
     _prompt_and_completion: str = ""
 
     async def describe(self, models: Models) -> Coroutine[str, None, None]:
-        return (await models.gpt35()).complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:")
+        return models.gpt35.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:")
 
     async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
         rif_with_contents = []
diff --git a/continuedev/src/continuedev/steps/main.py b/continuedev/src/continuedev/steps/main.py
index 81a1e3a9..24335b4f 100644
--- a/continuedev/src/continuedev/steps/main.py
+++ b/continuedev/src/continuedev/steps/main.py
@@ -212,7 +212,7 @@ class StarCoderEditHighlightedCodeStep(Step):
     _prompt_and_completion: str = ""
 
     async def describe(self, models: Models) -> Coroutine[str, None, None]:
-        return (await models.gpt35()).complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:")
+        return models.gpt35.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:")
 
     async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
         range_in_files = await sdk.ide.getHighlightedCode()
@@ -317,17 +317,6 @@ class SolveTracebackStep(Step):
         return None
 
 
-class MessageStep(Step):
-    name: str = "Message"
-    message: str
-
-    async def describe(self, models: Models) -> Coroutine[str, None, None]:
-        return self.message
-
-    async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
-        return TextObservation(text=self.message)
-
-
 class EmptyStep(Step):
     hide: bool = True
 
-- 
cgit v1.2.3-70-g09d2