diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-09-16 22:08:23 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-09-16 22:08:23 -0700 |
commit | 6a8c48f72c697865d8889fe528aba1ad930011b5 (patch) | |
tree | d28d3fd8cc994452447ef19d23e5167ffc2c12c5 | |
parent | 874e63c89d45e14253925e9e85dda12bac629829 (diff) | |
parent | 7a86f6a41b16d94f676bf327d35fb768854becb4 (diff) | |
download | sncontinue-6a8c48f72c697865d8889fe528aba1ad930011b5.tar.gz sncontinue-6a8c48f72c697865d8889fe528aba1ad930011b5.tar.bz2 sncontinue-6a8c48f72c697865d8889fe528aba1ad930011b5.zip |
Merge branch 'main' of https://github.com/continuedev/continue
-rw-r--r-- | continuedev/src/continuedev/__main__.py | 31 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/lsp.py | 137 | ||||
-rw-r--r-- | continuedev/src/continuedev/headless/__init__.py | 19 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/prompts/edit.py | 8 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/util/strings.py | 13 | ||||
-rw-r--r-- | continuedev/src/continuedev/models/filesystem.py | 3 | ||||
-rw-r--r-- | continuedev/src/continuedev/models/main.py | 11 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/context_providers/file.py | 3 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/refactor.py | 106 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/steps_on_startup.py | 5 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/meilisearch_server.py | 19 | ||||
-rw-r--r-- | docs/docs/walkthroughs/headless-mode.md | 37 |
12 files changed, 319 insertions, 73 deletions
diff --git a/continuedev/src/continuedev/__main__.py b/continuedev/src/continuedev/__main__.py index 6dd03046..efd62d64 100644 --- a/continuedev/src/continuedev/__main__.py +++ b/continuedev/src/continuedev/__main__.py @@ -1,16 +1,31 @@ -import argparse +import asyncio +from typing import Optional +import typer + +from .headless import start_headless_session from .server.main import run_server +app = typer.Typer() -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("-p", "--port", help="server port", type=int, default=65432) - parser.add_argument("--host", help="server host", type=str, default="127.0.0.1") - args = parser.parse_args() - run_server(port=args.port, host=args.host) +@app.command() +def main( + port: int = typer.Option(65432, help="server port"), + host: str = typer.Option("127.0.0.1", help="server host"), + config: Optional[str] = typer.Option( + None, help="The path to the configuration file" + ), + headless: bool = typer.Option(False, help="Run in headless mode"), +): + if headless: + loop = asyncio.get_event_loop() + loop.run_until_complete(start_headless_session(config=config)) + tasks = asyncio.all_tasks(loop) + loop.run_until_complete(asyncio.gather(*tasks)) + else: + run_server(port=port, host=host) if __name__ == "__main__": - main() + app() diff --git a/continuedev/src/continuedev/core/lsp.py b/continuedev/src/continuedev/core/lsp.py index 181eea2e..0c906b22 100644 --- a/continuedev/src/continuedev/core/lsp.py +++ b/continuedev/src/continuedev/core/lsp.py @@ -1,6 +1,6 @@ import asyncio import threading -from typing import List, Optional +from typing import List, Literal, Optional import aiohttp from pydantic import BaseModel @@ -9,6 +9,7 @@ from pylsp.python_lsp import PythonLSPServer, start_ws_lang_server from ..libs.util.logging import logger from ..models.filesystem import RangeInFile from ..models.main import Position, Range +from ..server.meilisearch_server import kill_proc def filepath_to_uri(filename: str) -> str: @@ -17,7 +18,7 @@ def filepath_to_uri(filename: str) -> str: def uri_to_filepath(uri: str) -> str: if uri.startswith("file://"): - return uri.lstrip("file://") + return uri[7:] else: return uri @@ -26,6 +27,9 @@ PORT = 8099 class LSPClient: + ready: bool = False + lock: asyncio.Lock = asyncio.Lock() + def __init__(self, host: str, port: int, workspace_paths: List[str]): self.host = host self.port = port @@ -37,12 +41,18 @@ class LSPClient: print("Connecting") self.ws = await self.session.ws_connect(f"ws://{self.host}:{self.port}/") print("Connected") + self.ready = True async def send(self, data): await self.ws.send_json(data) async def recv(self): - return await self.ws.receive_json() + await self.lock.acquire() + + try: + return await self.ws.receive_json() + finally: + self.lock.release() async def close(self): await self.ws.close() @@ -237,9 +247,27 @@ class LSPClient: textDocument={"uri": filepath_to_uri(filepath)}, ) + async def find_references( + self, filepath: str, position: Position, include_declaration: bool = False + ): + return await self.call_method( + "textDocument/references", + textDocument={"uri": filepath_to_uri(filepath)}, + position=position.dict(), + context={"includeDeclaration": include_declaration}, + ) + + async def folding_range(self, filepath: str): + response = await self.call_method( + "textDocument/foldingRange", + textDocument={"uri": filepath_to_uri(filepath)}, + ) + return response["result"] + async def start_language_server() -> threading.Thread: try: + kill_proc(PORT) thread = threading.Thread( target=start_ws_lang_server, args=(PORT, False, PythonLSPServer), @@ -262,12 +290,23 @@ class DocumentSymbol(BaseModel): location: RangeInFile +class FoldingRange(BaseModel): + range: Range + kind: Optional[Literal["comment", "imports", "region"]] = None + + class ContinueLSPClient(BaseModel): workspace_dir: str lsp_client: LSPClient = None lsp_thread: Optional[threading.Thread] = None + @property + def ready(self): + if self.lsp_client is None: + return False + return self.lsp_client.ready + class Config: arbitrary_types_allowed = True @@ -287,6 +326,17 @@ class ContinueLSPClient(BaseModel): if self.lsp_thread: self.lsp_thread.join() + def location_to_range_in_file(self, location): + return RangeInFile( + filepath=uri_to_filepath(location["uri"]), + range=Range.from_shorthand( + location["range"]["start"]["line"], + location["range"]["start"]["character"], + location["range"]["end"]["line"], + location["range"]["end"]["character"], + ), + ) + async def goto_definition( self, position: Position, filename: str ) -> List[RangeInFile]: @@ -294,18 +344,17 @@ class ContinueLSPClient(BaseModel): filename, position, ) - return [ - RangeInFile( - filepath=uri_to_filepath(x.uri), - range=Range.from_shorthand( - x.range.start.line, - x.range.start.character, - x.range.end.line, - x.range.end.character, - ), - ) - for x in response - ] + return [self.location_to_range_in_file(x) for x in response] + + async def find_references( + self, position: Position, filename: str, include_declaration: bool = False + ) -> List[RangeInFile]: + response = await self.lsp_client.find_references( + filename, + position, + include_declaration=include_declaration, + ) + return [self.location_to_range_in_file(x) for x in response["result"]] async def document_symbol(self, filepath: str) -> List: response = await self.lsp_client.document_symbol(filepath) @@ -314,15 +363,55 @@ class ContinueLSPClient(BaseModel): name=x["name"], containerName=x["containerName"], kind=x["kind"], - location=RangeInFile( - filepath=uri_to_filepath(x["location"]["uri"]), - range=Range.from_shorthand( - x["location"]["range"]["start"]["line"], - x["location"]["range"]["start"]["character"], - x["location"]["range"]["end"]["line"], - x["location"]["range"]["end"]["character"], - ), - ), + location=self.location_to_range_in_file(x["location"]), ) for x in response["result"] ] + + async def folding_range(self, filepath: str) -> List[FoldingRange]: + response = await self.lsp_client.folding_range(filepath) + + return [ + FoldingRange( + range=Range.from_shorthand( + x["startLine"], + x.get("startCharacter", 0), + x["endLine"] if "endCharacter" in x else x["endLine"] + 1, + x.get("endCharacter", 0), + ), + kind=x.get("kind"), + ) + for x in response + ] + + async def get_enclosing_folding_range_of_position( + self, position: Position, filepath: str + ) -> Optional[FoldingRange]: + ranges = await self.folding_range(filepath) + + max_start_position = Position(line=0, character=0) + max_range = None + for r in ranges: + if r.range.contains(position): + if r.range.start > max_start_position: + max_start_position = r.range.start + max_range = r + + return max_range + + async def get_enclosing_folding_range( + self, range_in_file: RangeInFile + ) -> Optional[FoldingRange]: + ranges = await self.folding_range(range_in_file.filepath) + + max_start_position = Position(line=0, character=0) + max_range = None + for r in ranges: + if r.range.contains(range_in_file.range.start) and r.range.contains( + range_in_file.range.end + ): + if r.range.start > max_start_position: + max_start_position = r.range.start + max_range = r + + return max_range diff --git a/continuedev/src/continuedev/headless/__init__.py b/continuedev/src/continuedev/headless/__init__.py index 4e46409a..27722ee7 100644 --- a/continuedev/src/continuedev/headless/__init__.py +++ b/continuedev/src/continuedev/headless/__init__.py @@ -4,6 +4,7 @@ from typing import Optional, Union import typer from ..core.config import ContinueConfig +from ..core.main import Step from ..server.session_manager import Session, session_manager from .headless_ide import LocalIdeProtocol @@ -21,21 +22,11 @@ async def start_headless_session( return await session_manager.new_session(ide, config=config) -async def async_main(config: Optional[str] = None): - await start_headless_session(config=config) +def run_step_headless(step: Step): + config = ContinueConfig() + config.steps_on_startup = [step] - -@app.command() -def main( - config: Optional[str] = typer.Option( - None, help="The path to the configuration file" - ) -): loop = asyncio.get_event_loop() - loop.run_until_complete(async_main(config)) + loop.run_until_complete(start_headless_session(config=config)) tasks = asyncio.all_tasks(loop) loop.run_until_complete(asyncio.gather(*tasks)) - - -if __name__ == "__main__": - app() diff --git a/continuedev/src/continuedev/libs/llm/prompts/edit.py b/continuedev/src/continuedev/libs/llm/prompts/edit.py index 7da5a192..eaa694c5 100644 --- a/continuedev/src/continuedev/libs/llm/prompts/edit.py +++ b/continuedev/src/continuedev/libs/llm/prompts/edit.py @@ -4,10 +4,10 @@ simplified_edit_prompt = dedent( """\ Consider the following code: ``` - {{code_to_edit}} + {{{code_to_edit}}} ``` Edit the code to perfectly satisfy the following user request: - {{user_input}} + {{{user_input}}} Output nothing except for the code. No code block, no English explanation, no start/end tags.""" ) @@ -15,11 +15,11 @@ simplest_edit_prompt = dedent( """\ Here is the code before editing: ``` - {{code_to_edit}} + {{{code_to_edit}}} ``` Here is the edit requested: - "{{user_input}}" + "{{{user_input}}}" Here is the code after editing:""" ) diff --git a/continuedev/src/continuedev/libs/util/strings.py b/continuedev/src/continuedev/libs/util/strings.py index d33c46c4..f2b6035f 100644 --- a/continuedev/src/continuedev/libs/util/strings.py +++ b/continuedev/src/continuedev/libs/util/strings.py @@ -25,6 +25,19 @@ def dedent_and_get_common_whitespace(s: str) -> Tuple[str, str]: return "\n".join(map(lambda x: x.lstrip(lcp), lines)), lcp +def strip_code_block(s: str) -> str: + """ + Strips the code block from a string, if it has one. + """ + if s.startswith("```\n") and s.endswith("\n```"): + return s[4:-4] + elif s.startswith("```") and s.endswith("```"): + return s[3:-3] + elif s.startswith("`") and s.endswith("`"): + return s[1:-1] + return s + + def remove_quotes_and_escapes(output: str) -> str: """ Clean up the output of the completion API, removing unnecessary escapes and quotes diff --git a/continuedev/src/continuedev/models/filesystem.py b/continuedev/src/continuedev/models/filesystem.py index 3b056a2f..27244c4b 100644 --- a/continuedev/src/continuedev/models/filesystem.py +++ b/continuedev/src/continuedev/models/filesystem.py @@ -31,6 +31,9 @@ class RangeInFile(BaseModel): range = Range.from_entire_file(content) return RangeInFile(filepath=filepath, range=range) + def translated(self, lines: int): + return RangeInFile(filepath=self.filepath, range=self.range.translated(lines)) + class RangeInFileWithContents(RangeInFile): """A range in a file with the contents of the range.""" diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py index 880fbfef..34c557e0 100644 --- a/continuedev/src/continuedev/models/main.py +++ b/continuedev/src/continuedev/models/main.py @@ -105,6 +105,17 @@ class Range(BaseModel): end=Position(line=self.end.line + 1, character=0), ) + def translated(self, lines: int): + return Range( + start=Position( + line=self.start.line + lines, character=self.start.character + ), + end=Position(line=self.end.line + lines, character=self.end.character), + ) + + def contains(self, position: Position) -> bool: + return self.start <= position and position <= self.end + @staticmethod def from_indices(string: str, start_index: int, end_index: int) -> "Range": return Range( diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index c4a61193..f4fbaf03 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -14,8 +14,7 @@ MAX_SIZE_IN_CHARS = 50_000 async def get_file_contents(filepath: str, sdk: ContinueSDK) -> str: try: return (await sdk.ide.readFile(filepath))[:MAX_SIZE_IN_CHARS] - except Exception as e: - print(f"Failed to read file: {e}") + except Exception as _: return None diff --git a/continuedev/src/continuedev/plugins/steps/refactor.py b/continuedev/src/continuedev/plugins/steps/refactor.py new file mode 100644 index 00000000..cfbce662 --- /dev/null +++ b/continuedev/src/continuedev/plugins/steps/refactor.py @@ -0,0 +1,106 @@ +import asyncio +from typing import List + +from ...core.main import Step +from ...core.models import Models +from ...core.sdk import ContinueSDK +from ...libs.llm.prompts.edit import simplified_edit_prompt +from ...libs.util.strings import remove_quotes_and_escapes, strip_code_block +from ...libs.util.templating import render_prompt_template +from ...models.filesystem import RangeInFile +from ...models.filesystem_edit import FileEdit +from ...models.main import PositionInFile, Range + + +class RefactorReferencesStep(Step): + name: str = "Refactor references of a symbol" + user_input: str + symbol_location: PositionInFile + + async def describe(self, models: Models): + return f"Renamed all instances of `{self.function_name}` to `{self.new_function_name}` in `{self.filepath}`" + + async def run(self, sdk: ContinueSDK): + while sdk.lsp is None or not sdk.lsp.ready: + await asyncio.sleep(0.1) + + references = await sdk.lsp.find_references( + self.symbol_location.position, self.symbol_location.filepath, False + ) + await sdk.run_step( + ParallelEditStep(user_input=self.user_input, range_in_files=references) + ) + + +class ParallelEditStep(Step): + name: str = "Edit multiple ranges in parallel" + user_input: str + range_in_files: List[RangeInFile] + + hide: bool = True + + async def single_edit(self, sdk: ContinueSDK, range_in_file: RangeInFile): + # TODO: Can use folding info to get a more intuitively shaped range + expanded_range = await sdk.lsp.get_enclosing_folding_range(range_in_file) + if ( + expanded_range is None + or expanded_range.range.start.line != range_in_file.range.start.line + ): + expanded_range = Range.from_shorthand( + range_in_file.range.start.line, 0, range_in_file.range.end.line + 1, 0 + ) + else: + expanded_range = expanded_range.range + + new_rif = RangeInFile( + filepath=range_in_file.filepath, + range=expanded_range, + ) + code_to_edit = await sdk.ide.readRangeInFile(range_in_file=new_rif) + + # code_to_edit, common_whitespace = dedent_and_get_common_whitespace(code_to_edit) + + prompt = render_prompt_template( + simplified_edit_prompt, + history=[], + other_data={ + "code_to_edit": code_to_edit, + "user_input": self.user_input, + }, + ) + print(prompt + "\n\n-------------------\n\n") + + new_code = await sdk.models.edit.complete(prompt=prompt) + new_code = strip_code_block(remove_quotes_and_escapes(new_code)) + "\n" + # new_code = ( + # "\n".join([common_whitespace + line for line in new_code.split("\n")]) + # + "\n" + # ) + + print(new_code + "\n\n-------------------\n\n") + + await sdk.ide.applyFileSystemEdit( + FileEdit( + filepath=range_in_file.filepath, + range=expanded_range, + replacement=new_code, + ) + ) + + async def edit_file(self, sdk: ContinueSDK, filepath: str): + ranges_in_file = [ + range_in_file + for range_in_file in self.range_in_files + if range_in_file.filepath == filepath + ] + # Sort in reverse order so that we don't mess up the ranges + ranges_in_file.sort(key=lambda x: x.range.start.line, reverse=True) + for i in range(len(ranges_in_file)): + await self.single_edit(sdk=sdk, range_in_file=ranges_in_file[i]) + + async def run(self, sdk: ContinueSDK): + tasks = [] + for filepath in set([rif.filepath for rif in self.range_in_files]): + tasks.append(self.edit_file(sdk=sdk, filepath=filepath)) + + await asyncio.gather(*tasks) diff --git a/continuedev/src/continuedev/plugins/steps/steps_on_startup.py b/continuedev/src/continuedev/plugins/steps/steps_on_startup.py index d0058ffc..58d56703 100644 --- a/continuedev/src/continuedev/plugins/steps/steps_on_startup.py +++ b/continuedev/src/continuedev/plugins/steps/steps_on_startup.py @@ -12,5 +12,8 @@ class StepsOnStartupStep(Step): steps_on_startup = sdk.config.steps_on_startup for step_type in steps_on_startup: - step = step_type() + if isinstance(step_type, Step): + step = step_type + else: + step = step_type() await sdk.run_step(step) diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index f3734470..40d46b18 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -77,14 +77,13 @@ async def ensure_meilisearch_installed() -> bool: except: pass existing_paths.remove(meilisearchPath) - + await download_meilisearch() # Clear the existing directories for p in existing_paths: shutil.rmtree(p, ignore_errors=True) - return False return True @@ -160,17 +159,25 @@ def stop_meilisearch(): import psutil + def kill_proc(port): for proc in psutil.process_iter(): try: - for conns in proc.connections(kind='inet'): + for conns in proc.connections(kind="inet"): if conns.laddr.port == port: - proc.send_signal(psutil.signal.SIGTERM) # or SIGKILL + proc.send_signal(psutil.signal.SIGTERM) # or SIGKILL except psutil.AccessDenied: - logger.warning(f"Failed to kill process on port {port}") + logger.warning(f"Failed to kill process on port {port} (access denied)") + return + except psutil.ZombieProcess: + logger.warning(f"Failed to kill process on port {port} (zombie process)") + return + except psutil.NoSuchProcess: + logger.warning(f"Failed to kill process on port {port} (no such process)") + return async def restart_meilisearch(): stop_meilisearch() kill_proc(7700) - await start_meilisearch()
\ No newline at end of file + await start_meilisearch() diff --git a/docs/docs/walkthroughs/headless-mode.md b/docs/docs/walkthroughs/headless-mode.md index 92382d77..d4f90264 100644 --- a/docs/docs/walkthroughs/headless-mode.md +++ b/docs/docs/walkthroughs/headless-mode.md @@ -5,23 +5,32 @@ To use headless mode: 1. `pip install continuedev` (using a virtual environment is recommended) -2. Create a config file (see the [`ContinueConfig` Reference](../reference/config.md) for all options) that includes the [Policy](../customization/other-configuration.md#custom-policies) you want to run -3. Import `continuedev` and call `start_headless_session` with either the path to your config file, or an instance of `ContinueConfig` +2. Import `continuedev` and call `run_step_headless` with the `Step` you would like to run Example: +Say you have the following file (`/path/to/file.py`): + ```python -from continuedev.headless import start_headless_session -from continuedev.core.config import ContinueConfig -from continuedev.core.models import Models -import asyncio - -config = ContinueConfig( - models=Models(...), - override_policy=MyPolicy() -) -asyncio.run(start_headless_session(config)) +def say_hello(name: str): + print(f"Hello, {name}") +``` + +and this function is imported and used in multiple places throughout your codebase. But the name parameter is new, and you need to change the function call everywhere it is used. You can use the script below to edit all usages of the function in your codebase: -# Alternatively, pass the path to a config file -asyncio.run(start_headless_session("/path/to/config.py")) +```python +from continuedev.headless import run_step_headless +from continuedev.models.main import Position, PositionInFile +from continuedev.plugins.steps.refactor import RefactorReferencesStep + +step = RefactorReferencesStep( + user_input="", + symbol_location=PositionInFile( + filepath="/path/to/file.py", + position=Position(line=0, character=5), + ), +) +run_step_headless(step=step) ``` + +Here we use Continue's built-in `RefactorReferencesStep`. By passing it the location (filepath and position) of the symbol (function, variable, etc.) that we want to update, Continue will automatically find all references to that symbol and prompt an LLM to make the edit requested in the `user_input` field. |