diff options
Diffstat (limited to 'continuedev')
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 41 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/main.py | 2 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 7 | ||||
| -rw-r--r-- | continuedev/src/continuedev/models/filesystem.py | 23 | ||||
| -rw-r--r-- | continuedev/src/continuedev/models/main.py | 11 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/gui.py | 7 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/ide.py | 10 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/ide_protocol.py | 6 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/chat.py | 2 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/core/core.py | 3 | 
10 files changed, 101 insertions, 11 deletions
| diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 3c7fbdef..1a77ca64 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -2,8 +2,10 @@ from functools import cached_property  import traceback  import time  from typing import Any, Callable, Coroutine, Dict, List - +import os  from aiohttp import ClientPayloadError + +from ..models.filesystem import RangeInFileWithContents  from ..models.filesystem_edit import FileEditWithFullContents  from ..libs.llm import LLM  from .observation import Observation, InternalErrorObservation @@ -59,7 +61,13 @@ class Autopilot(ContinueBaseModel):          keep_untouched = (cached_property,)      def get_full_state(self) -> FullState: -        return FullState(history=self.history, active=self._active, user_input_queue=self._main_user_input_queue, default_model=self.continue_sdk.config.default_model) +        return FullState( +            history=self.history, +            active=self._active, +            user_input_queue=self._main_user_input_queue, +            default_model=self.continue_sdk.config.default_model, +            highlighted_ranges=self._highlighted_ranges +        )      async def get_available_slash_commands(self) -> List[Dict]:          return list(map(lambda x: {"name": x.name, "description": x.description}, self.continue_sdk.config.slash_commands)) or [] @@ -124,6 +132,31 @@ class Autopilot(ContinueBaseModel):                          tb_step.step_name, {"output": output, **tb_step.params})                      await self._run_singular_step(step) +    _highlighted_ranges: List[RangeInFileWithContents] = [] + +    async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]): +        workspace_path = self.continue_sdk.ide.workspace_directory +        for rif in range_in_files: +            rif.filepath = os.path.relpath(rif.filepath, workspace_path) + +        old_ranges = self._highlighted_ranges + range_in_files +        new_ranges = [] + +        while len(old_ranges) > 0: +            old_range = old_ranges.pop(0) +            found_overlap = False +            for i in range(len(new_ranges)): +                if old_range.filepath == new_ranges[i].filepath and old_range.range.overlaps_with(new_ranges[i].range): +                    new_ranges[i] = old_range.union(new_ranges[i]) +                    found_overlap = True +                    break + +            if not found_overlap: +                new_ranges.append(old_range) + +        self._highlighted_ranges = new_ranges +        await self.update_subscribers() +      _step_depth: int = 0      async def retry_at_index(self, index: int): @@ -135,6 +168,10 @@ class Autopilot(ContinueBaseModel):          self.history.timeline[index].deleted = True          await self.update_subscribers() +    async def delete_context_item_at_index(self, index: int): +        self._highlighted_ranges.pop(index) +        await self.update_subscribers() +      async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]:          # Allow config to set disallowed steps          if step.__class__.__name__ in self.continue_sdk.config.disallowed_steps: diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 4c6f4dc2..2d84801c 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -2,6 +2,7 @@ import json  from textwrap import dedent  from typing import Callable, Coroutine, Dict, Generator, List, Literal, Tuple, Union +from ..models.filesystem import RangeInFileWithContents  from ..models.main import ContinueBaseModel  from pydantic import validator  from .observation import Observation @@ -201,6 +202,7 @@ class FullState(ContinueBaseModel):      active: bool      user_input_queue: List[str]      default_model: str +    highlighted_ranges: List[RangeInFileWithContents]  class ContinueSDK: diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index cfe2e436..50a14bed 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -179,7 +179,7 @@ class ContinueSDK(AbstractContinueSDK):      async def get_chat_context(self) -> List[ChatMessage]:          history_context = self.history.to_chat_history() -        highlighted_code = await self.ide.getHighlightedCode() +        highlighted_code = self.__autopilot._highlighted_ranges          preface = "The following code is highlighted" @@ -190,11 +190,10 @@ class ContinueSDK(AbstractContinueSDK):              if len(files) > 0:                  content = await self.ide.readFile(files[0])                  highlighted_code = [ -                    RangeInFile.from_entire_file(files[0], content)] +                    RangeInFileWithContents.from_entire_file(files[0], content)]          for rif in highlighted_code: -            code = await self.ide.readRangeInFile(rif) -            msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{code}\n```", +            msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{rif.contents}\n```",                                role="system", summary=f"{preface}: {rif.filepath}")              # Don't insert after latest user message or function call diff --git a/continuedev/src/continuedev/models/filesystem.py b/continuedev/src/continuedev/models/filesystem.py index b709dd21..df0b15d7 100644 --- a/continuedev/src/continuedev/models/filesystem.py +++ b/continuedev/src/continuedev/models/filesystem.py @@ -23,11 +23,34 @@ class RangeInFile(BaseModel):  class RangeInFileWithContents(RangeInFile): +    """A range in a file with the contents of the range."""      contents: str      def __hash__(self):          return hash((self.filepath, self.range, self.contents)) +    def union(self, other: "RangeInFileWithContents") -> "RangeInFileWithContents": +        assert self.filepath == other.filepath +        # Use a placeholder variable for self and swap it with other if other comes before self +        first = self +        second = other +        if other.range.start < self.range.start: +            first = other +            second = self + +        assert first.filepath == second.filepath + +        # Calculate union of contents +        num_overlapping_lines = first.range.end.line - second.range.start.line + 1 +        union_lines = first.contents.splitlines()[:-num_overlapping_lines] + \ +            second.contents.splitlines() + +        return RangeInFileWithContents( +            filepath=first.filepath, +            range=first.range.union(second.range), +            contents="\n".join(union_lines) +        ) +      @staticmethod      def from_entire_file(filepath: str, content: str) -> "RangeInFileWithContents":          lines = content.splitlines() diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py index d5f6e650..fa736772 100644 --- a/continuedev/src/continuedev/models/main.py +++ b/continuedev/src/continuedev/models/main.py @@ -43,12 +43,23 @@ class Position(BaseModel):      def from_end_of_file(contents: str) -> "Position":          return Position.from_index(contents, len(contents)) +    def to_index(self, string: str) -> int: +        """Convert line and character to index in string""" +        lines = string.splitlines() +        return sum(map(len, lines[:self.line])) + self.character +  class Range(BaseModel):      """A range in a file. 0-indexed."""      start: Position      end: Position +    def __lt__(self, other: "Range") -> bool: +        return self.start < other.start or (self.start == other.start and self.end < other.end) + +    def __eq__(self, other: "Range") -> bool: +        return self.start == other.start and self.end == other.end +      def __hash__(self):          return hash((self.start, self.end)) diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index c0178920..9a33fb6c 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -83,6 +83,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):                  self.on_clear_history()              elif message_type == "delete_at_index":                  self.on_delete_at_index(data["index"]) +            elif message_type == "delete_context_item_at_index": +                self.on_delete_context_item_at_index(data["index"])          except Exception as e:              print(e) @@ -127,6 +129,11 @@ class GUIProtocolServer(AbstractGUIProtocolServer):      def on_delete_at_index(self, index: int):          asyncio.create_task(self.session.autopilot.delete_at_index(index)) +    def on_delete_context_item_at_index(self, index: int): +        asyncio.create_task( +            self.session.autopilot.delete_context_item_at_index(index) +        ) +  @router.websocket("/ws")  async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 1f790991..e2685493 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -9,7 +9,7 @@ from uvicorn.main import Server  from ..libs.util.telemetry import capture_event  from ..libs.util.queue import AsyncSubscriptionQueue -from ..models.filesystem import FileSystem, RangeInFile, EditDiff, RealFileSystem +from ..models.filesystem import FileSystem, RangeInFile, EditDiff, RangeInFileWithContents, RealFileSystem  from ..models.filesystem_edit import AddDirectory, AddFile, DeleteDirectory, DeleteFile, FileSystemEdit, FileEdit, FileEditWithFullContents, RenameDirectory, RenameFile, SequentialFileSystemEdit  from pydantic import BaseModel  from .gui import SessionManager, session_manager @@ -140,6 +140,9 @@ class IdeProtocolServer(AbstractIdeProtocolServer):              fileEdits = list(                  map(lambda d: FileEditWithFullContents.parse_obj(d), data["fileEdits"]))              self.onFileEdits(fileEdits) +        elif message_type == "highlightedCodePush": +            self.onHighlightedCodeUpdate( +                [RangeInFileWithContents(**rif) for rif in data["highlightedCode"]])          elif message_type == "commandOutput":              output = data["output"]              self.onCommandOutput(output) @@ -234,6 +237,11 @@ class IdeProtocolServer(AbstractIdeProtocolServer):              asyncio.create_task(                  session.autopilot.handle_command_output(output)) +    def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]): +        for _, session in self.session_manager.sessions.items(): +            asyncio.create_task( +                session.autopilot.handle_highlighted_code(range_in_files)) +      # Request information. Session doesn't matter.      async def getOpenFiles(self) -> List[str]:          resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles") diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index 7faf5563..de2eea27 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod, abstractproperty  from ..models.main import Traceback  from ..models.filesystem_edit import FileEdit, FileSystemEdit, EditDiff -from ..models.filesystem import RangeInFile +from ..models.filesystem import RangeInFile, RangeInFileWithContents  class AbstractIdeProtocolServer(ABC): @@ -91,6 +91,10 @@ class AbstractIdeProtocolServer(ABC):      async def runCommand(self, command: str) -> str:          """Run a command""" +    @abstractmethod +    def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]): +        """Called when highlighted code is updated""" +      @abstractproperty      def workspace_directory(self) -> str:          """Get the workspace directory""" diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py index 34a97a17..9d556655 100644 --- a/continuedev/src/continuedev/steps/chat.py +++ b/continuedev/src/continuedev/steps/chat.py @@ -126,7 +126,7 @@ class RunTerminalCommandStep(Step):  class ViewDirectoryTreeStep(Step):      name: str = "View Directory Tree" -    description: str = "View the directory tree to learn which folder and files exist." +    description: str = "View the directory tree to learn which folder and files exist. You should always do this before adding new files."      async def describe(self, models: Models) -> Coroutine[Any, Any, Coroutine[str, None, None]]:          return f"Viewed the directory tree." diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py index 8f59bc4d..4eb2445c 100644 --- a/continuedev/src/continuedev/steps/core/core.py +++ b/continuedev/src/continuedev/steps/core/core.py @@ -318,8 +318,7 @@ class DefaultModelEditCodeStep(Step):              nonlocal current_block_start, current_line_in_file, original_lines, original_lines_below_previous_blocks, current_block_lines, indices_of_last_matched_lines, LINES_TO_MATCH_BEFORE_ENDING_BLOCK, offset_from_blocks              # Highlight the line to show progress -            # - len(current_block_lines) -            line_to_highlight = current_line_in_file +            line_to_highlight = current_line_in_file - len(current_block_lines)              await sdk.ide.highlightCode(RangeInFile(filepath=rif.filepath, range=Range.from_shorthand(                  line_to_highlight, 0, line_to_highlight, 0)), "#FFFFFF22" if len(current_block_lines) == 0 else "#00FF0022") | 
