diff options
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/core/main.py | 3 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/policy.py | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 3 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/ide.py | 11 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/ide_protocol.py | 6 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/chat.py | 10 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/core/core.py | 5 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/custom_command.py | 11 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/on_traceback.py | 8 | 
9 files changed, 41 insertions, 20 deletions
| diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 403e5417..4ea17f20 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -38,7 +38,8 @@ class ChatMessage(ContinueBaseModel):                  del d[key]          if not with_functions: -            d["role"] = "assistant" +            if d["role"] == "function": +                d["role"] = "assistant"              if "name" in d:                  del d["name"]              if "function_call" in d: diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index 6ee2d03f..b8363df2 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -46,7 +46,7 @@ def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]:              slash_command = parse_slash_command(custom_cmd.prompt, config)              if slash_command is not None:                  return slash_command -            return CustomCommandStep(name=custom_cmd.name, description=custom_cmd.description, prompt=custom_cmd.prompt, user_input=after_command) +            return CustomCommandStep(name=custom_cmd.name, description=custom_cmd.description, prompt=custom_cmd.prompt, user_input=after_command, slash_command=command_name)      return None @@ -82,6 +82,6 @@ class DemoPolicy(Policy):              if custom_command is not None:                  return custom_command -            return SimpleChatStep(user_input=user_input) +            return SimpleChatStep()          return None diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 49513013..ed670799 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -192,7 +192,8 @@ class ContinueSDK(AbstractContinueSDK):      async def get_chat_context(self) -> List[ChatMessage]:          history_context = self.history.to_chat_history() -        highlighted_code = self.__autopilot._highlighted_ranges +        highlighted_code = [ +            hr.range for hr in self.__autopilot._highlighted_ranges]          preface = "The following code is highlighted" diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 1d51758e..e4a6266a 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -152,6 +152,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer):              self.onAcceptRejectDiff(data["accepted"])          elif message_type == "mainUserInput":              self.onMainUserInput(data["input"]) +        elif message_type == "deleteAtIndex": +            self.onDeleteAtIndex(data["index"])          elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory", "getUserSecret", "runCommand", "uniqueId"]:              self.sub_queue.post(message_type, data)          else: @@ -164,10 +166,11 @@ class IdeProtocolServer(AbstractIdeProtocolServer):              "edit": file_edit.dict()          }) -    async def showDiff(self, filepath: str, replacement: str): +    async def showDiff(self, filepath: str, replacement: str, step_index: int):          await self._send_json("showDiff", {              "filepath": filepath, -            "replacement": replacement +            "replacement": replacement, +            "step_index": step_index          })      async def setFileOpen(self, filepath: str, open: bool = True): @@ -245,6 +248,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer):          for _, session in self.session_manager.sessions.items():              session.autopilot.handle_manual_edits(edits) +    def onDeleteAtIndex(self, index: int): +        for _, session in self.session_manager.sessions.items(): +            asyncio.create_task(session.autopilot.delete_at_index(index)) +      def onCommandOutput(self, output: str):          # Send the output to ALL autopilots.          # Maybe not ideal behavior diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index 2e1f78d7..dfdca504 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -96,7 +96,11 @@ class AbstractIdeProtocolServer(ABC):          """Called when highlighted code is updated"""      @abstractmethod -    async def showDiff(self, filepath: str, replacement: str): +    def onDeleteAtIndex(self, index: int): +        """Called when a step is deleted at a given index""" + +    @abstractmethod +    async def showDiff(self, filepath: str, replacement: str, step_index: int):          """Show a diff"""      @abstractproperty diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py index c26f8ff9..a10319d8 100644 --- a/continuedev/src/continuedev/steps/chat.py +++ b/continuedev/src/continuedev/steps/chat.py @@ -19,19 +19,15 @@ openai.api_key = OPENAI_API_KEY  class SimpleChatStep(Step): -    user_input: str      name: str = "Generating Response..."      manage_own_chat_context: bool = True      description: str = "" +    messages: List[ChatMessage] = None      async def run(self, sdk: ContinueSDK): -        if self.user_input.strip() == "": -            self.user_input = "Explain this code's function is a concise list of markdown bullets." -            self.description = "" -        await sdk.update_ui() -          completion = "" -        async for chunk in sdk.models.gpt4.stream_chat(await sdk.get_chat_context()): +        messages = self.messages or await sdk.get_chat_context() +        async for chunk in sdk.models.gpt4.stream_chat(messages, temperature=0.5):              if sdk.current_step_was_deleted():                  return diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py index f22297ae..10853828 100644 --- a/continuedev/src/continuedev/steps/core/core.py +++ b/continuedev/src/continuedev/steps/core/core.py @@ -305,7 +305,10 @@ class DefaultModelEditCodeStep(Step):              full_suffix_lines = full_file_contents_lines[rif.range.end.line:]              new_file_contents = "\n".join(                  full_prefix_lines) + "\n" + completion + "\n" + "\n".join(full_suffix_lines) -            await sdk.ide.showDiff(rif.filepath, new_file_contents) + +            step_index = sdk.history.current_index + +            await sdk.ide.showDiff(rif.filepath, new_file_contents, step_index)          # Important state variables          # ------------------------- diff --git a/continuedev/src/continuedev/steps/custom_command.py b/continuedev/src/continuedev/steps/custom_command.py index 9d675091..5a56efb0 100644 --- a/continuedev/src/continuedev/steps/custom_command.py +++ b/continuedev/src/continuedev/steps/custom_command.py @@ -1,5 +1,6 @@  from ..core.main import Step  from ..core.sdk import ContinueSDK +from ..steps.core.core import UserInputStep  from ..steps.chat import ChatWithFunctions, SimpleChatStep @@ -7,6 +8,7 @@ class CustomCommandStep(Step):      name: str      prompt: str      user_input: str +    slash_command: str      hide: bool = True      async def describe(self): @@ -14,4 +16,11 @@ class CustomCommandStep(Step):      async def run(self, sdk: ContinueSDK):          prompt_user_input = f"Task: {self.prompt}. Additional info: {self.user_input}" -        await sdk.run_step(SimpleChatStep(user_input=prompt_user_input)) +        messages = await sdk.get_chat_context() +        # Find the last chat message with this slash command and replace it with the user input +        for i in range(len(messages) - 1, -1, -1): +            if messages[i].role == "user" and messages[i].content.startswith(self.slash_command): +                messages[i] = messages[i].copy( +                    update={"content": prompt_user_input}) +                break +        await sdk.run_step(SimpleChatStep(messages=messages)) diff --git a/continuedev/src/continuedev/steps/on_traceback.py b/continuedev/src/continuedev/steps/on_traceback.py index 3f8c5a76..efb4c703 100644 --- a/continuedev/src/continuedev/steps/on_traceback.py +++ b/continuedev/src/continuedev/steps/on_traceback.py @@ -1,4 +1,6 @@  import os + +from .core.core import UserInputStep  from ..core.main import ChatMessage, Step  from ..core.sdk import ContinueSDK  from .chat import SimpleChatStep @@ -21,7 +23,5 @@ class DefaultOnTracebackStep(Step):                          content=f"The contents of {seg}:\n```\n{file_contents}\n```",                          summary=""                      )) - -        await sdk.run_step(SimpleChatStep( -            name="Help With Traceback", -            user_input=f"""I got the following error, can you please help explain how to fix it?\n\n{self.output}""")) +        await sdk.run_step(UserInputStep(user_input=f"""I got the following error, can you please help explain how to fix it?\n\n{self.output}""")) +        await sdk.run_step(SimpleChatStep(name="Help With Traceback")) | 
