diff options
author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-09-16 22:08:01 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-16 22:08:01 -0700 |
commit | 7a86f6a41b16d94f676bf327d35fb768854becb4 (patch) | |
tree | d28d3fd8cc994452447ef19d23e5167ffc2c12c5 /continuedev | |
parent | dfbae3f6add30b47d2bd0ba34be89af60d9ab660 (diff) | |
download | sncontinue-7a86f6a41b16d94f676bf327d35fb768854becb4.tar.gz sncontinue-7a86f6a41b16d94f676bf327d35fb768854becb4.tar.bz2 sncontinue-7a86f6a41b16d94f676bf327d35fb768854becb4.zip |
Refactor helper (#481)
* feat: :sparkles: add stop_tokens option to LLM
* work on refactoring in headless mode
* feat: :sparkles: headless mode refactors
* chore: :fire: remove test.py
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/src/continuedev/__main__.py | 31 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/lsp.py | 137 | ||||
-rw-r--r-- | continuedev/src/continuedev/headless/__init__.py | 19 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/__init__.py | 10 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/prompts/edit.py | 8 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/util/strings.py | 13 | ||||
-rw-r--r-- | continuedev/src/continuedev/models/filesystem.py | 3 | ||||
-rw-r--r-- | continuedev/src/continuedev/models/main.py | 11 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/context_providers/file.py | 3 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/refactor.py | 106 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/steps_on_startup.py | 5 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/meilisearch_server.py | 19 |
12 files changed, 303 insertions, 62 deletions
diff --git a/continuedev/src/continuedev/__main__.py b/continuedev/src/continuedev/__main__.py index 6dd03046..efd62d64 100644 --- a/continuedev/src/continuedev/__main__.py +++ b/continuedev/src/continuedev/__main__.py @@ -1,16 +1,31 @@ -import argparse +import asyncio +from typing import Optional +import typer + +from .headless import start_headless_session from .server.main import run_server +app = typer.Typer() -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("-p", "--port", help="server port", type=int, default=65432) - parser.add_argument("--host", help="server host", type=str, default="127.0.0.1") - args = parser.parse_args() - run_server(port=args.port, host=args.host) +@app.command() +def main( + port: int = typer.Option(65432, help="server port"), + host: str = typer.Option("127.0.0.1", help="server host"), + config: Optional[str] = typer.Option( + None, help="The path to the configuration file" + ), + headless: bool = typer.Option(False, help="Run in headless mode"), +): + if headless: + loop = asyncio.get_event_loop() + loop.run_until_complete(start_headless_session(config=config)) + tasks = asyncio.all_tasks(loop) + loop.run_until_complete(asyncio.gather(*tasks)) + else: + run_server(port=port, host=host) if __name__ == "__main__": - main() + app() diff --git a/continuedev/src/continuedev/core/lsp.py b/continuedev/src/continuedev/core/lsp.py index 181eea2e..0c906b22 100644 --- a/continuedev/src/continuedev/core/lsp.py +++ b/continuedev/src/continuedev/core/lsp.py @@ -1,6 +1,6 @@ import asyncio import threading -from typing import List, Optional +from typing import List, Literal, Optional import aiohttp from pydantic import BaseModel @@ -9,6 +9,7 @@ from pylsp.python_lsp import PythonLSPServer, start_ws_lang_server from ..libs.util.logging import logger from ..models.filesystem import RangeInFile from ..models.main import Position, Range +from ..server.meilisearch_server import kill_proc def filepath_to_uri(filename: str) -> str: @@ -17,7 +18,7 @@ def filepath_to_uri(filename: str) -> str: def uri_to_filepath(uri: str) -> str: if uri.startswith("file://"): - return uri.lstrip("file://") + return uri[7:] else: return uri @@ -26,6 +27,9 @@ PORT = 8099 class LSPClient: + ready: bool = False + lock: asyncio.Lock = asyncio.Lock() + def __init__(self, host: str, port: int, workspace_paths: List[str]): self.host = host self.port = port @@ -37,12 +41,18 @@ class LSPClient: print("Connecting") self.ws = await self.session.ws_connect(f"ws://{self.host}:{self.port}/") print("Connected") + self.ready = True async def send(self, data): await self.ws.send_json(data) async def recv(self): - return await self.ws.receive_json() + await self.lock.acquire() + + try: + return await self.ws.receive_json() + finally: + self.lock.release() async def close(self): await self.ws.close() @@ -237,9 +247,27 @@ class LSPClient: textDocument={"uri": filepath_to_uri(filepath)}, ) + async def find_references( + self, filepath: str, position: Position, include_declaration: bool = False + ): + return await self.call_method( + "textDocument/references", + textDocument={"uri": filepath_to_uri(filepath)}, + position=position.dict(), + context={"includeDeclaration": include_declaration}, + ) + + async def folding_range(self, filepath: str): + response = await self.call_method( + "textDocument/foldingRange", + textDocument={"uri": filepath_to_uri(filepath)}, + ) + return response["result"] + async def start_language_server() -> threading.Thread: try: + kill_proc(PORT) thread = threading.Thread( target=start_ws_lang_server, args=(PORT, False, PythonLSPServer), @@ -262,12 +290,23 @@ class DocumentSymbol(BaseModel): location: RangeInFile +class FoldingRange(BaseModel): + range: Range + kind: Optional[Literal["comment", "imports", "region"]] = None + + class ContinueLSPClient(BaseModel): workspace_dir: str lsp_client: LSPClient = None lsp_thread: Optional[threading.Thread] = None + @property + def ready(self): + if self.lsp_client is None: + return False + return self.lsp_client.ready + class Config: arbitrary_types_allowed = True @@ -287,6 +326,17 @@ class ContinueLSPClient(BaseModel): if self.lsp_thread: self.lsp_thread.join() + def location_to_range_in_file(self, location): + return RangeInFile( + filepath=uri_to_filepath(location["uri"]), + range=Range.from_shorthand( + location["range"]["start"]["line"], + location["range"]["start"]["character"], + location["range"]["end"]["line"], + location["range"]["end"]["character"], + ), + ) + async def goto_definition( self, position: Position, filename: str ) -> List[RangeInFile]: @@ -294,18 +344,17 @@ class ContinueLSPClient(BaseModel): filename, position, ) - return [ - RangeInFile( - filepath=uri_to_filepath(x.uri), - range=Range.from_shorthand( - x.range.start.line, - x.range.start.character, - x.range.end.line, - x.range.end.character, - ), - ) - for x in response - ] + return [self.location_to_range_in_file(x) for x in response] + + async def find_references( + self, position: Position, filename: str, include_declaration: bool = False + ) -> List[RangeInFile]: + response = await self.lsp_client.find_references( + filename, + position, + include_declaration=include_declaration, + ) + return [self.location_to_range_in_file(x) for x in response["result"]] async def document_symbol(self, filepath: str) -> List: response = await self.lsp_client.document_symbol(filepath) @@ -314,15 +363,55 @@ class ContinueLSPClient(BaseModel): name=x["name"], containerName=x["containerName"], kind=x["kind"], - location=RangeInFile( - filepath=uri_to_filepath(x["location"]["uri"]), - range=Range.from_shorthand( - x["location"]["range"]["start"]["line"], - x["location"]["range"]["start"]["character"], - x["location"]["range"]["end"]["line"], - x["location"]["range"]["end"]["character"], - ), - ), + location=self.location_to_range_in_file(x["location"]), ) for x in response["result"] ] + + async def folding_range(self, filepath: str) -> List[FoldingRange]: + response = await self.lsp_client.folding_range(filepath) + + return [ + FoldingRange( + range=Range.from_shorthand( + x["startLine"], + x.get("startCharacter", 0), + x["endLine"] if "endCharacter" in x else x["endLine"] + 1, + x.get("endCharacter", 0), + ), + kind=x.get("kind"), + ) + for x in response + ] + + async def get_enclosing_folding_range_of_position( + self, position: Position, filepath: str + ) -> Optional[FoldingRange]: + ranges = await self.folding_range(filepath) + + max_start_position = Position(line=0, character=0) + max_range = None + for r in ranges: + if r.range.contains(position): + if r.range.start > max_start_position: + max_start_position = r.range.start + max_range = r + + return max_range + + async def get_enclosing_folding_range( + self, range_in_file: RangeInFile + ) -> Optional[FoldingRange]: + ranges = await self.folding_range(range_in_file.filepath) + + max_start_position = Position(line=0, character=0) + max_range = None + for r in ranges: + if r.range.contains(range_in_file.range.start) and r.range.contains( + range_in_file.range.end + ): + if r.range.start > max_start_position: + max_start_position = r.range.start + max_range = r + + return max_range diff --git a/continuedev/src/continuedev/headless/__init__.py b/continuedev/src/continuedev/headless/__init__.py index 4e46409a..27722ee7 100644 --- a/continuedev/src/continuedev/headless/__init__.py +++ b/continuedev/src/continuedev/headless/__init__.py @@ -4,6 +4,7 @@ from typing import Optional, Union import typer from ..core.config import ContinueConfig +from ..core.main import Step from ..server.session_manager import Session, session_manager from .headless_ide import LocalIdeProtocol @@ -21,21 +22,11 @@ async def start_headless_session( return await session_manager.new_session(ide, config=config) -async def async_main(config: Optional[str] = None): - await start_headless_session(config=config) +def run_step_headless(step: Step): + config = ContinueConfig() + config.steps_on_startup = [step] - -@app.command() -def main( - config: Optional[str] = typer.Option( - None, help="The path to the configuration file" - ) -): loop = asyncio.get_event_loop() - loop.run_until_complete(async_main(config)) + loop.run_until_complete(start_headless_session(config=config)) tasks = asyncio.all_tasks(loop) loop.run_until_complete(asyncio.gather(*tasks)) - - -if __name__ == "__main__": - app() diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index baeb9d1a..b2eecab6 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -68,6 +68,10 @@ class LLM(ContinueBaseModel): ..., description="The name of the model to be used (e.g. gpt-4, codellama)" ) + stop_tokens: Optional[List[str]] = Field( + None, description="Tokens that will stop the completion." + ) + timeout: Optional[int] = Field( 300, description="Set the timeout for each request to the LLM. If you are running a local LLM that takes a while to respond, you might want to set this to avoid timeouts.", @@ -204,7 +208,7 @@ class LLM(ContinueBaseModel): top_k=top_k, presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, - stop=stop, + stop=stop or self.stop_tokens, max_tokens=max_tokens, functions=functions, ) @@ -251,7 +255,7 @@ class LLM(ContinueBaseModel): top_k=top_k, presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, - stop=stop, + stop=stop or self.stop_tokens, max_tokens=max_tokens, functions=functions, ) @@ -296,7 +300,7 @@ class LLM(ContinueBaseModel): top_k=top_k, presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, - stop=stop, + stop=stop or self.stop_tokens, max_tokens=max_tokens, functions=functions, ) diff --git a/continuedev/src/continuedev/libs/llm/prompts/edit.py b/continuedev/src/continuedev/libs/llm/prompts/edit.py index 7da5a192..eaa694c5 100644 --- a/continuedev/src/continuedev/libs/llm/prompts/edit.py +++ b/continuedev/src/continuedev/libs/llm/prompts/edit.py @@ -4,10 +4,10 @@ simplified_edit_prompt = dedent( """\ Consider the following code: ``` - {{code_to_edit}} + {{{code_to_edit}}} ``` Edit the code to perfectly satisfy the following user request: - {{user_input}} + {{{user_input}}} Output nothing except for the code. No code block, no English explanation, no start/end tags.""" ) @@ -15,11 +15,11 @@ simplest_edit_prompt = dedent( """\ Here is the code before editing: ``` - {{code_to_edit}} + {{{code_to_edit}}} ``` Here is the edit requested: - "{{user_input}}" + "{{{user_input}}}" Here is the code after editing:""" ) diff --git a/continuedev/src/continuedev/libs/util/strings.py b/continuedev/src/continuedev/libs/util/strings.py index d33c46c4..f2b6035f 100644 --- a/continuedev/src/continuedev/libs/util/strings.py +++ b/continuedev/src/continuedev/libs/util/strings.py @@ -25,6 +25,19 @@ def dedent_and_get_common_whitespace(s: str) -> Tuple[str, str]: return "\n".join(map(lambda x: x.lstrip(lcp), lines)), lcp +def strip_code_block(s: str) -> str: + """ + Strips the code block from a string, if it has one. + """ + if s.startswith("```\n") and s.endswith("\n```"): + return s[4:-4] + elif s.startswith("```") and s.endswith("```"): + return s[3:-3] + elif s.startswith("`") and s.endswith("`"): + return s[1:-1] + return s + + def remove_quotes_and_escapes(output: str) -> str: """ Clean up the output of the completion API, removing unnecessary escapes and quotes diff --git a/continuedev/src/continuedev/models/filesystem.py b/continuedev/src/continuedev/models/filesystem.py index 3b056a2f..27244c4b 100644 --- a/continuedev/src/continuedev/models/filesystem.py +++ b/continuedev/src/continuedev/models/filesystem.py @@ -31,6 +31,9 @@ class RangeInFile(BaseModel): range = Range.from_entire_file(content) return RangeInFile(filepath=filepath, range=range) + def translated(self, lines: int): + return RangeInFile(filepath=self.filepath, range=self.range.translated(lines)) + class RangeInFileWithContents(RangeInFile): """A range in a file with the contents of the range.""" diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py index 880fbfef..34c557e0 100644 --- a/continuedev/src/continuedev/models/main.py +++ b/continuedev/src/continuedev/models/main.py @@ -105,6 +105,17 @@ class Range(BaseModel): end=Position(line=self.end.line + 1, character=0), ) + def translated(self, lines: int): + return Range( + start=Position( + line=self.start.line + lines, character=self.start.character + ), + end=Position(line=self.end.line + lines, character=self.end.character), + ) + + def contains(self, position: Position) -> bool: + return self.start <= position and position <= self.end + @staticmethod def from_indices(string: str, start_index: int, end_index: int) -> "Range": return Range( diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index c4a61193..f4fbaf03 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -14,8 +14,7 @@ MAX_SIZE_IN_CHARS = 50_000 async def get_file_contents(filepath: str, sdk: ContinueSDK) -> str: try: return (await sdk.ide.readFile(filepath))[:MAX_SIZE_IN_CHARS] - except Exception as e: - print(f"Failed to read file: {e}") + except Exception as _: return None diff --git a/continuedev/src/continuedev/plugins/steps/refactor.py b/continuedev/src/continuedev/plugins/steps/refactor.py new file mode 100644 index 00000000..cfbce662 --- /dev/null +++ b/continuedev/src/continuedev/plugins/steps/refactor.py @@ -0,0 +1,106 @@ +import asyncio +from typing import List + +from ...core.main import Step +from ...core.models import Models +from ...core.sdk import ContinueSDK +from ...libs.llm.prompts.edit import simplified_edit_prompt +from ...libs.util.strings import remove_quotes_and_escapes, strip_code_block +from ...libs.util.templating import render_prompt_template +from ...models.filesystem import RangeInFile +from ...models.filesystem_edit import FileEdit +from ...models.main import PositionInFile, Range + + +class RefactorReferencesStep(Step): + name: str = "Refactor references of a symbol" + user_input: str + symbol_location: PositionInFile + + async def describe(self, models: Models): + return f"Renamed all instances of `{self.function_name}` to `{self.new_function_name}` in `{self.filepath}`" + + async def run(self, sdk: ContinueSDK): + while sdk.lsp is None or not sdk.lsp.ready: + await asyncio.sleep(0.1) + + references = await sdk.lsp.find_references( + self.symbol_location.position, self.symbol_location.filepath, False + ) + await sdk.run_step( + ParallelEditStep(user_input=self.user_input, range_in_files=references) + ) + + +class ParallelEditStep(Step): + name: str = "Edit multiple ranges in parallel" + user_input: str + range_in_files: List[RangeInFile] + + hide: bool = True + + async def single_edit(self, sdk: ContinueSDK, range_in_file: RangeInFile): + # TODO: Can use folding info to get a more intuitively shaped range + expanded_range = await sdk.lsp.get_enclosing_folding_range(range_in_file) + if ( + expanded_range is None + or expanded_range.range.start.line != range_in_file.range.start.line + ): + expanded_range = Range.from_shorthand( + range_in_file.range.start.line, 0, range_in_file.range.end.line + 1, 0 + ) + else: + expanded_range = expanded_range.range + + new_rif = RangeInFile( + filepath=range_in_file.filepath, + range=expanded_range, + ) + code_to_edit = await sdk.ide.readRangeInFile(range_in_file=new_rif) + + # code_to_edit, common_whitespace = dedent_and_get_common_whitespace(code_to_edit) + + prompt = render_prompt_template( + simplified_edit_prompt, + history=[], + other_data={ + "code_to_edit": code_to_edit, + "user_input": self.user_input, + }, + ) + print(prompt + "\n\n-------------------\n\n") + + new_code = await sdk.models.edit.complete(prompt=prompt) + new_code = strip_code_block(remove_quotes_and_escapes(new_code)) + "\n" + # new_code = ( + # "\n".join([common_whitespace + line for line in new_code.split("\n")]) + # + "\n" + # ) + + print(new_code + "\n\n-------------------\n\n") + + await sdk.ide.applyFileSystemEdit( + FileEdit( + filepath=range_in_file.filepath, + range=expanded_range, + replacement=new_code, + ) + ) + + async def edit_file(self, sdk: ContinueSDK, filepath: str): + ranges_in_file = [ + range_in_file + for range_in_file in self.range_in_files + if range_in_file.filepath == filepath + ] + # Sort in reverse order so that we don't mess up the ranges + ranges_in_file.sort(key=lambda x: x.range.start.line, reverse=True) + for i in range(len(ranges_in_file)): + await self.single_edit(sdk=sdk, range_in_file=ranges_in_file[i]) + + async def run(self, sdk: ContinueSDK): + tasks = [] + for filepath in set([rif.filepath for rif in self.range_in_files]): + tasks.append(self.edit_file(sdk=sdk, filepath=filepath)) + + await asyncio.gather(*tasks) diff --git a/continuedev/src/continuedev/plugins/steps/steps_on_startup.py b/continuedev/src/continuedev/plugins/steps/steps_on_startup.py index d0058ffc..58d56703 100644 --- a/continuedev/src/continuedev/plugins/steps/steps_on_startup.py +++ b/continuedev/src/continuedev/plugins/steps/steps_on_startup.py @@ -12,5 +12,8 @@ class StepsOnStartupStep(Step): steps_on_startup = sdk.config.steps_on_startup for step_type in steps_on_startup: - step = step_type() + if isinstance(step_type, Step): + step = step_type + else: + step = step_type() await sdk.run_step(step) diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index f3734470..40d46b18 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -77,14 +77,13 @@ async def ensure_meilisearch_installed() -> bool: except: pass existing_paths.remove(meilisearchPath) - + await download_meilisearch() # Clear the existing directories for p in existing_paths: shutil.rmtree(p, ignore_errors=True) - return False return True @@ -160,17 +159,25 @@ def stop_meilisearch(): import psutil + def kill_proc(port): for proc in psutil.process_iter(): try: - for conns in proc.connections(kind='inet'): + for conns in proc.connections(kind="inet"): if conns.laddr.port == port: - proc.send_signal(psutil.signal.SIGTERM) # or SIGKILL + proc.send_signal(psutil.signal.SIGTERM) # or SIGKILL except psutil.AccessDenied: - logger.warning(f"Failed to kill process on port {port}") + logger.warning(f"Failed to kill process on port {port} (access denied)") + return + except psutil.ZombieProcess: + logger.warning(f"Failed to kill process on port {port} (zombie process)") + return + except psutil.NoSuchProcess: + logger.warning(f"Failed to kill process on port {port} (no such process)") + return async def restart_meilisearch(): stop_meilisearch() kill_proc(7700) - await start_meilisearch()
\ No newline at end of file + await start_meilisearch() |