diff options
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/__main__.py | 31 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/lsp.py | 137 | ||||
| -rw-r--r-- | continuedev/src/continuedev/headless/__init__.py | 19 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/prompts/edit.py | 8 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/util/strings.py | 13 | ||||
| -rw-r--r-- | continuedev/src/continuedev/models/filesystem.py | 3 | ||||
| -rw-r--r-- | continuedev/src/continuedev/models/main.py | 11 | ||||
| -rw-r--r-- | continuedev/src/continuedev/plugins/context_providers/file.py | 3 | ||||
| -rw-r--r-- | continuedev/src/continuedev/plugins/steps/refactor.py | 106 | ||||
| -rw-r--r-- | continuedev/src/continuedev/plugins/steps/steps_on_startup.py | 5 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/meilisearch_server.py | 19 | 
11 files changed, 296 insertions, 59 deletions
| diff --git a/continuedev/src/continuedev/__main__.py b/continuedev/src/continuedev/__main__.py index 6dd03046..efd62d64 100644 --- a/continuedev/src/continuedev/__main__.py +++ b/continuedev/src/continuedev/__main__.py @@ -1,16 +1,31 @@ -import argparse +import asyncio +from typing import Optional +import typer + +from .headless import start_headless_session  from .server.main import run_server +app = typer.Typer() -def main(): -    parser = argparse.ArgumentParser() -    parser.add_argument("-p", "--port", help="server port", type=int, default=65432) -    parser.add_argument("--host", help="server host", type=str, default="127.0.0.1") -    args = parser.parse_args() -    run_server(port=args.port, host=args.host) +@app.command() +def main( +    port: int = typer.Option(65432, help="server port"), +    host: str = typer.Option("127.0.0.1", help="server host"), +    config: Optional[str] = typer.Option( +        None, help="The path to the configuration file" +    ), +    headless: bool = typer.Option(False, help="Run in headless mode"), +): +    if headless: +        loop = asyncio.get_event_loop() +        loop.run_until_complete(start_headless_session(config=config)) +        tasks = asyncio.all_tasks(loop) +        loop.run_until_complete(asyncio.gather(*tasks)) +    else: +        run_server(port=port, host=host)  if __name__ == "__main__": -    main() +    app() diff --git a/continuedev/src/continuedev/core/lsp.py b/continuedev/src/continuedev/core/lsp.py index 181eea2e..0c906b22 100644 --- a/continuedev/src/continuedev/core/lsp.py +++ b/continuedev/src/continuedev/core/lsp.py @@ -1,6 +1,6 @@  import asyncio  import threading -from typing import List, Optional +from typing import List, Literal, Optional  import aiohttp  from pydantic import BaseModel @@ -9,6 +9,7 @@ from pylsp.python_lsp import PythonLSPServer, start_ws_lang_server  from ..libs.util.logging import logger  from ..models.filesystem import RangeInFile  from ..models.main import Position, Range +from ..server.meilisearch_server import kill_proc  def filepath_to_uri(filename: str) -> str: @@ -17,7 +18,7 @@ def filepath_to_uri(filename: str) -> str:  def uri_to_filepath(uri: str) -> str:      if uri.startswith("file://"): -        return uri.lstrip("file://") +        return uri[7:]      else:          return uri @@ -26,6 +27,9 @@ PORT = 8099  class LSPClient: +    ready: bool = False +    lock: asyncio.Lock = asyncio.Lock() +      def __init__(self, host: str, port: int, workspace_paths: List[str]):          self.host = host          self.port = port @@ -37,12 +41,18 @@ class LSPClient:          print("Connecting")          self.ws = await self.session.ws_connect(f"ws://{self.host}:{self.port}/")          print("Connected") +        self.ready = True      async def send(self, data):          await self.ws.send_json(data)      async def recv(self): -        return await self.ws.receive_json() +        await self.lock.acquire() + +        try: +            return await self.ws.receive_json() +        finally: +            self.lock.release()      async def close(self):          await self.ws.close() @@ -237,9 +247,27 @@ class LSPClient:              textDocument={"uri": filepath_to_uri(filepath)},          ) +    async def find_references( +        self, filepath: str, position: Position, include_declaration: bool = False +    ): +        return await self.call_method( +            "textDocument/references", +            textDocument={"uri": filepath_to_uri(filepath)}, +            position=position.dict(), +            context={"includeDeclaration": include_declaration}, +        ) + +    async def folding_range(self, filepath: str): +        response = await self.call_method( +            "textDocument/foldingRange", +            textDocument={"uri": filepath_to_uri(filepath)}, +        ) +        return response["result"] +  async def start_language_server() -> threading.Thread:      try: +        kill_proc(PORT)          thread = threading.Thread(              target=start_ws_lang_server,              args=(PORT, False, PythonLSPServer), @@ -262,12 +290,23 @@ class DocumentSymbol(BaseModel):      location: RangeInFile +class FoldingRange(BaseModel): +    range: Range +    kind: Optional[Literal["comment", "imports", "region"]] = None + +  class ContinueLSPClient(BaseModel):      workspace_dir: str      lsp_client: LSPClient = None      lsp_thread: Optional[threading.Thread] = None +    @property +    def ready(self): +        if self.lsp_client is None: +            return False +        return self.lsp_client.ready +      class Config:          arbitrary_types_allowed = True @@ -287,6 +326,17 @@ class ContinueLSPClient(BaseModel):          if self.lsp_thread:              self.lsp_thread.join() +    def location_to_range_in_file(self, location): +        return RangeInFile( +            filepath=uri_to_filepath(location["uri"]), +            range=Range.from_shorthand( +                location["range"]["start"]["line"], +                location["range"]["start"]["character"], +                location["range"]["end"]["line"], +                location["range"]["end"]["character"], +            ), +        ) +      async def goto_definition(          self, position: Position, filename: str      ) -> List[RangeInFile]: @@ -294,18 +344,17 @@ class ContinueLSPClient(BaseModel):              filename,              position,          ) -        return [ -            RangeInFile( -                filepath=uri_to_filepath(x.uri), -                range=Range.from_shorthand( -                    x.range.start.line, -                    x.range.start.character, -                    x.range.end.line, -                    x.range.end.character, -                ), -            ) -            for x in response -        ] +        return [self.location_to_range_in_file(x) for x in response] + +    async def find_references( +        self, position: Position, filename: str, include_declaration: bool = False +    ) -> List[RangeInFile]: +        response = await self.lsp_client.find_references( +            filename, +            position, +            include_declaration=include_declaration, +        ) +        return [self.location_to_range_in_file(x) for x in response["result"]]      async def document_symbol(self, filepath: str) -> List:          response = await self.lsp_client.document_symbol(filepath) @@ -314,15 +363,55 @@ class ContinueLSPClient(BaseModel):                  name=x["name"],                  containerName=x["containerName"],                  kind=x["kind"], -                location=RangeInFile( -                    filepath=uri_to_filepath(x["location"]["uri"]), -                    range=Range.from_shorthand( -                        x["location"]["range"]["start"]["line"], -                        x["location"]["range"]["start"]["character"], -                        x["location"]["range"]["end"]["line"], -                        x["location"]["range"]["end"]["character"], -                    ), -                ), +                location=self.location_to_range_in_file(x["location"]),              )              for x in response["result"]          ] + +    async def folding_range(self, filepath: str) -> List[FoldingRange]: +        response = await self.lsp_client.folding_range(filepath) + +        return [ +            FoldingRange( +                range=Range.from_shorthand( +                    x["startLine"], +                    x.get("startCharacter", 0), +                    x["endLine"] if "endCharacter" in x else x["endLine"] + 1, +                    x.get("endCharacter", 0), +                ), +                kind=x.get("kind"), +            ) +            for x in response +        ] + +    async def get_enclosing_folding_range_of_position( +        self, position: Position, filepath: str +    ) -> Optional[FoldingRange]: +        ranges = await self.folding_range(filepath) + +        max_start_position = Position(line=0, character=0) +        max_range = None +        for r in ranges: +            if r.range.contains(position): +                if r.range.start > max_start_position: +                    max_start_position = r.range.start +                    max_range = r + +        return max_range + +    async def get_enclosing_folding_range( +        self, range_in_file: RangeInFile +    ) -> Optional[FoldingRange]: +        ranges = await self.folding_range(range_in_file.filepath) + +        max_start_position = Position(line=0, character=0) +        max_range = None +        for r in ranges: +            if r.range.contains(range_in_file.range.start) and r.range.contains( +                range_in_file.range.end +            ): +                if r.range.start > max_start_position: +                    max_start_position = r.range.start +                    max_range = r + +        return max_range diff --git a/continuedev/src/continuedev/headless/__init__.py b/continuedev/src/continuedev/headless/__init__.py index 4e46409a..27722ee7 100644 --- a/continuedev/src/continuedev/headless/__init__.py +++ b/continuedev/src/continuedev/headless/__init__.py @@ -4,6 +4,7 @@ from typing import Optional, Union  import typer  from ..core.config import ContinueConfig +from ..core.main import Step  from ..server.session_manager import Session, session_manager  from .headless_ide import LocalIdeProtocol @@ -21,21 +22,11 @@ async def start_headless_session(      return await session_manager.new_session(ide, config=config) -async def async_main(config: Optional[str] = None): -    await start_headless_session(config=config) +def run_step_headless(step: Step): +    config = ContinueConfig() +    config.steps_on_startup = [step] - -@app.command() -def main( -    config: Optional[str] = typer.Option( -        None, help="The path to the configuration file" -    ) -):      loop = asyncio.get_event_loop() -    loop.run_until_complete(async_main(config)) +    loop.run_until_complete(start_headless_session(config=config))      tasks = asyncio.all_tasks(loop)      loop.run_until_complete(asyncio.gather(*tasks)) - - -if __name__ == "__main__": -    app() diff --git a/continuedev/src/continuedev/libs/llm/prompts/edit.py b/continuedev/src/continuedev/libs/llm/prompts/edit.py index 7da5a192..eaa694c5 100644 --- a/continuedev/src/continuedev/libs/llm/prompts/edit.py +++ b/continuedev/src/continuedev/libs/llm/prompts/edit.py @@ -4,10 +4,10 @@ simplified_edit_prompt = dedent(      """\              Consider the following code:              ``` -            {{code_to_edit}} +            {{{code_to_edit}}}              ```              Edit the code to perfectly satisfy the following user request: -            {{user_input}} +            {{{user_input}}}              Output nothing except for the code. No code block, no English explanation, no start/end tags."""  ) @@ -15,11 +15,11 @@ simplest_edit_prompt = dedent(      """\              Here is the code before editing:              ``` -            {{code_to_edit}} +            {{{code_to_edit}}}              ```              Here is the edit requested: -            "{{user_input}}" +            "{{{user_input}}}"              Here is the code after editing:"""  ) diff --git a/continuedev/src/continuedev/libs/util/strings.py b/continuedev/src/continuedev/libs/util/strings.py index d33c46c4..f2b6035f 100644 --- a/continuedev/src/continuedev/libs/util/strings.py +++ b/continuedev/src/continuedev/libs/util/strings.py @@ -25,6 +25,19 @@ def dedent_and_get_common_whitespace(s: str) -> Tuple[str, str]:      return "\n".join(map(lambda x: x.lstrip(lcp), lines)), lcp +def strip_code_block(s: str) -> str: +    """ +    Strips the code block from a string, if it has one. +    """ +    if s.startswith("```\n") and s.endswith("\n```"): +        return s[4:-4] +    elif s.startswith("```") and s.endswith("```"): +        return s[3:-3] +    elif s.startswith("`") and s.endswith("`"): +        return s[1:-1] +    return s + +  def remove_quotes_and_escapes(output: str) -> str:      """      Clean up the output of the completion API, removing unnecessary escapes and quotes diff --git a/continuedev/src/continuedev/models/filesystem.py b/continuedev/src/continuedev/models/filesystem.py index 3b056a2f..27244c4b 100644 --- a/continuedev/src/continuedev/models/filesystem.py +++ b/continuedev/src/continuedev/models/filesystem.py @@ -31,6 +31,9 @@ class RangeInFile(BaseModel):          range = Range.from_entire_file(content)          return RangeInFile(filepath=filepath, range=range) +    def translated(self, lines: int): +        return RangeInFile(filepath=self.filepath, range=self.range.translated(lines)) +  class RangeInFileWithContents(RangeInFile):      """A range in a file with the contents of the range.""" diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py index 880fbfef..34c557e0 100644 --- a/continuedev/src/continuedev/models/main.py +++ b/continuedev/src/continuedev/models/main.py @@ -105,6 +105,17 @@ class Range(BaseModel):              end=Position(line=self.end.line + 1, character=0),          ) +    def translated(self, lines: int): +        return Range( +            start=Position( +                line=self.start.line + lines, character=self.start.character +            ), +            end=Position(line=self.end.line + lines, character=self.end.character), +        ) + +    def contains(self, position: Position) -> bool: +        return self.start <= position and position <= self.end +      @staticmethod      def from_indices(string: str, start_index: int, end_index: int) -> "Range":          return Range( diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index c4a61193..f4fbaf03 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -14,8 +14,7 @@ MAX_SIZE_IN_CHARS = 50_000  async def get_file_contents(filepath: str, sdk: ContinueSDK) -> str:      try:          return (await sdk.ide.readFile(filepath))[:MAX_SIZE_IN_CHARS] -    except Exception as e: -        print(f"Failed to read file: {e}") +    except Exception as _:          return None diff --git a/continuedev/src/continuedev/plugins/steps/refactor.py b/continuedev/src/continuedev/plugins/steps/refactor.py new file mode 100644 index 00000000..cfbce662 --- /dev/null +++ b/continuedev/src/continuedev/plugins/steps/refactor.py @@ -0,0 +1,106 @@ +import asyncio +from typing import List + +from ...core.main import Step +from ...core.models import Models +from ...core.sdk import ContinueSDK +from ...libs.llm.prompts.edit import simplified_edit_prompt +from ...libs.util.strings import remove_quotes_and_escapes, strip_code_block +from ...libs.util.templating import render_prompt_template +from ...models.filesystem import RangeInFile +from ...models.filesystem_edit import FileEdit +from ...models.main import PositionInFile, Range + + +class RefactorReferencesStep(Step): +    name: str = "Refactor references of a symbol" +    user_input: str +    symbol_location: PositionInFile + +    async def describe(self, models: Models): +        return f"Renamed all instances of `{self.function_name}` to `{self.new_function_name}` in `{self.filepath}`" + +    async def run(self, sdk: ContinueSDK): +        while sdk.lsp is None or not sdk.lsp.ready: +            await asyncio.sleep(0.1) + +        references = await sdk.lsp.find_references( +            self.symbol_location.position, self.symbol_location.filepath, False +        ) +        await sdk.run_step( +            ParallelEditStep(user_input=self.user_input, range_in_files=references) +        ) + + +class ParallelEditStep(Step): +    name: str = "Edit multiple ranges in parallel" +    user_input: str +    range_in_files: List[RangeInFile] + +    hide: bool = True + +    async def single_edit(self, sdk: ContinueSDK, range_in_file: RangeInFile): +        # TODO: Can use folding info to get a more intuitively shaped range +        expanded_range = await sdk.lsp.get_enclosing_folding_range(range_in_file) +        if ( +            expanded_range is None +            or expanded_range.range.start.line != range_in_file.range.start.line +        ): +            expanded_range = Range.from_shorthand( +                range_in_file.range.start.line, 0, range_in_file.range.end.line + 1, 0 +            ) +        else: +            expanded_range = expanded_range.range + +        new_rif = RangeInFile( +            filepath=range_in_file.filepath, +            range=expanded_range, +        ) +        code_to_edit = await sdk.ide.readRangeInFile(range_in_file=new_rif) + +        # code_to_edit, common_whitespace = dedent_and_get_common_whitespace(code_to_edit) + +        prompt = render_prompt_template( +            simplified_edit_prompt, +            history=[], +            other_data={ +                "code_to_edit": code_to_edit, +                "user_input": self.user_input, +            }, +        ) +        print(prompt + "\n\n-------------------\n\n") + +        new_code = await sdk.models.edit.complete(prompt=prompt) +        new_code = strip_code_block(remove_quotes_and_escapes(new_code)) + "\n" +        # new_code = ( +        #     "\n".join([common_whitespace + line for line in new_code.split("\n")]) +        #     + "\n" +        # ) + +        print(new_code + "\n\n-------------------\n\n") + +        await sdk.ide.applyFileSystemEdit( +            FileEdit( +                filepath=range_in_file.filepath, +                range=expanded_range, +                replacement=new_code, +            ) +        ) + +    async def edit_file(self, sdk: ContinueSDK, filepath: str): +        ranges_in_file = [ +            range_in_file +            for range_in_file in self.range_in_files +            if range_in_file.filepath == filepath +        ] +        # Sort in reverse order so that we don't mess up the ranges +        ranges_in_file.sort(key=lambda x: x.range.start.line, reverse=True) +        for i in range(len(ranges_in_file)): +            await self.single_edit(sdk=sdk, range_in_file=ranges_in_file[i]) + +    async def run(self, sdk: ContinueSDK): +        tasks = [] +        for filepath in set([rif.filepath for rif in self.range_in_files]): +            tasks.append(self.edit_file(sdk=sdk, filepath=filepath)) + +        await asyncio.gather(*tasks) diff --git a/continuedev/src/continuedev/plugins/steps/steps_on_startup.py b/continuedev/src/continuedev/plugins/steps/steps_on_startup.py index d0058ffc..58d56703 100644 --- a/continuedev/src/continuedev/plugins/steps/steps_on_startup.py +++ b/continuedev/src/continuedev/plugins/steps/steps_on_startup.py @@ -12,5 +12,8 @@ class StepsOnStartupStep(Step):          steps_on_startup = sdk.config.steps_on_startup          for step_type in steps_on_startup: -            step = step_type() +            if isinstance(step_type, Step): +                step = step_type +            else: +                step = step_type()              await sdk.run_step(step) diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index f3734470..40d46b18 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -77,14 +77,13 @@ async def ensure_meilisearch_installed() -> bool:              except:                  pass              existing_paths.remove(meilisearchPath) -             +              await download_meilisearch()          # Clear the existing directories          for p in existing_paths:              shutil.rmtree(p, ignore_errors=True) -          return False      return True @@ -160,17 +159,25 @@ def stop_meilisearch():  import psutil +  def kill_proc(port):      for proc in psutil.process_iter():          try: -            for conns in proc.connections(kind='inet'): +            for conns in proc.connections(kind="inet"):                  if conns.laddr.port == port: -                    proc.send_signal(psutil.signal.SIGTERM) # or SIGKILL +                    proc.send_signal(psutil.signal.SIGTERM)  # or SIGKILL          except psutil.AccessDenied: -            logger.warning(f"Failed to kill process on port {port}") +            logger.warning(f"Failed to kill process on port {port} (access denied)") +            return +        except psutil.ZombieProcess: +            logger.warning(f"Failed to kill process on port {port} (zombie process)") +            return +        except psutil.NoSuchProcess: +            logger.warning(f"Failed to kill process on port {port} (no such process)") +            return  async def restart_meilisearch():      stop_meilisearch()      kill_proc(7700) -    await start_meilisearch()
\ No newline at end of file +    await start_meilisearch() | 
