summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
authorNate Sesti <33237525+sestinj@users.noreply.github.com>2023-09-16 22:08:01 -0700
committerGitHub <noreply@github.com>2023-09-16 22:08:01 -0700
commit7a86f6a41b16d94f676bf327d35fb768854becb4 (patch)
treed28d3fd8cc994452447ef19d23e5167ffc2c12c5 /continuedev
parentdfbae3f6add30b47d2bd0ba34be89af60d9ab660 (diff)
downloadsncontinue-7a86f6a41b16d94f676bf327d35fb768854becb4.tar.gz
sncontinue-7a86f6a41b16d94f676bf327d35fb768854becb4.tar.bz2
sncontinue-7a86f6a41b16d94f676bf327d35fb768854becb4.zip
Refactor helper (#481)
* feat: :sparkles: add stop_tokens option to LLM * work on refactoring in headless mode * feat: :sparkles: headless mode refactors * chore: :fire: remove test.py
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/__main__.py31
-rw-r--r--continuedev/src/continuedev/core/lsp.py137
-rw-r--r--continuedev/src/continuedev/headless/__init__.py19
-rw-r--r--continuedev/src/continuedev/libs/llm/__init__.py10
-rw-r--r--continuedev/src/continuedev/libs/llm/prompts/edit.py8
-rw-r--r--continuedev/src/continuedev/libs/util/strings.py13
-rw-r--r--continuedev/src/continuedev/models/filesystem.py3
-rw-r--r--continuedev/src/continuedev/models/main.py11
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/file.py3
-rw-r--r--continuedev/src/continuedev/plugins/steps/refactor.py106
-rw-r--r--continuedev/src/continuedev/plugins/steps/steps_on_startup.py5
-rw-r--r--continuedev/src/continuedev/server/meilisearch_server.py19
12 files changed, 303 insertions, 62 deletions
diff --git a/continuedev/src/continuedev/__main__.py b/continuedev/src/continuedev/__main__.py
index 6dd03046..efd62d64 100644
--- a/continuedev/src/continuedev/__main__.py
+++ b/continuedev/src/continuedev/__main__.py
@@ -1,16 +1,31 @@
-import argparse
+import asyncio
+from typing import Optional
+import typer
+
+from .headless import start_headless_session
from .server.main import run_server
+app = typer.Typer()
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument("-p", "--port", help="server port", type=int, default=65432)
- parser.add_argument("--host", help="server host", type=str, default="127.0.0.1")
- args = parser.parse_args()
- run_server(port=args.port, host=args.host)
+@app.command()
+def main(
+ port: int = typer.Option(65432, help="server port"),
+ host: str = typer.Option("127.0.0.1", help="server host"),
+ config: Optional[str] = typer.Option(
+ None, help="The path to the configuration file"
+ ),
+ headless: bool = typer.Option(False, help="Run in headless mode"),
+):
+ if headless:
+ loop = asyncio.get_event_loop()
+ loop.run_until_complete(start_headless_session(config=config))
+ tasks = asyncio.all_tasks(loop)
+ loop.run_until_complete(asyncio.gather(*tasks))
+ else:
+ run_server(port=port, host=host)
if __name__ == "__main__":
- main()
+ app()
diff --git a/continuedev/src/continuedev/core/lsp.py b/continuedev/src/continuedev/core/lsp.py
index 181eea2e..0c906b22 100644
--- a/continuedev/src/continuedev/core/lsp.py
+++ b/continuedev/src/continuedev/core/lsp.py
@@ -1,6 +1,6 @@
import asyncio
import threading
-from typing import List, Optional
+from typing import List, Literal, Optional
import aiohttp
from pydantic import BaseModel
@@ -9,6 +9,7 @@ from pylsp.python_lsp import PythonLSPServer, start_ws_lang_server
from ..libs.util.logging import logger
from ..models.filesystem import RangeInFile
from ..models.main import Position, Range
+from ..server.meilisearch_server import kill_proc
def filepath_to_uri(filename: str) -> str:
@@ -17,7 +18,7 @@ def filepath_to_uri(filename: str) -> str:
def uri_to_filepath(uri: str) -> str:
if uri.startswith("file://"):
- return uri.lstrip("file://")
+ return uri[7:]
else:
return uri
@@ -26,6 +27,9 @@ PORT = 8099
class LSPClient:
+ ready: bool = False
+ lock: asyncio.Lock = asyncio.Lock()
+
def __init__(self, host: str, port: int, workspace_paths: List[str]):
self.host = host
self.port = port
@@ -37,12 +41,18 @@ class LSPClient:
print("Connecting")
self.ws = await self.session.ws_connect(f"ws://{self.host}:{self.port}/")
print("Connected")
+ self.ready = True
async def send(self, data):
await self.ws.send_json(data)
async def recv(self):
- return await self.ws.receive_json()
+ await self.lock.acquire()
+
+ try:
+ return await self.ws.receive_json()
+ finally:
+ self.lock.release()
async def close(self):
await self.ws.close()
@@ -237,9 +247,27 @@ class LSPClient:
textDocument={"uri": filepath_to_uri(filepath)},
)
+ async def find_references(
+ self, filepath: str, position: Position, include_declaration: bool = False
+ ):
+ return await self.call_method(
+ "textDocument/references",
+ textDocument={"uri": filepath_to_uri(filepath)},
+ position=position.dict(),
+ context={"includeDeclaration": include_declaration},
+ )
+
+ async def folding_range(self, filepath: str):
+ response = await self.call_method(
+ "textDocument/foldingRange",
+ textDocument={"uri": filepath_to_uri(filepath)},
+ )
+ return response["result"]
+
async def start_language_server() -> threading.Thread:
try:
+ kill_proc(PORT)
thread = threading.Thread(
target=start_ws_lang_server,
args=(PORT, False, PythonLSPServer),
@@ -262,12 +290,23 @@ class DocumentSymbol(BaseModel):
location: RangeInFile
+class FoldingRange(BaseModel):
+ range: Range
+ kind: Optional[Literal["comment", "imports", "region"]] = None
+
+
class ContinueLSPClient(BaseModel):
workspace_dir: str
lsp_client: LSPClient = None
lsp_thread: Optional[threading.Thread] = None
+ @property
+ def ready(self):
+ if self.lsp_client is None:
+ return False
+ return self.lsp_client.ready
+
class Config:
arbitrary_types_allowed = True
@@ -287,6 +326,17 @@ class ContinueLSPClient(BaseModel):
if self.lsp_thread:
self.lsp_thread.join()
+ def location_to_range_in_file(self, location):
+ return RangeInFile(
+ filepath=uri_to_filepath(location["uri"]),
+ range=Range.from_shorthand(
+ location["range"]["start"]["line"],
+ location["range"]["start"]["character"],
+ location["range"]["end"]["line"],
+ location["range"]["end"]["character"],
+ ),
+ )
+
async def goto_definition(
self, position: Position, filename: str
) -> List[RangeInFile]:
@@ -294,18 +344,17 @@ class ContinueLSPClient(BaseModel):
filename,
position,
)
- return [
- RangeInFile(
- filepath=uri_to_filepath(x.uri),
- range=Range.from_shorthand(
- x.range.start.line,
- x.range.start.character,
- x.range.end.line,
- x.range.end.character,
- ),
- )
- for x in response
- ]
+ return [self.location_to_range_in_file(x) for x in response]
+
+ async def find_references(
+ self, position: Position, filename: str, include_declaration: bool = False
+ ) -> List[RangeInFile]:
+ response = await self.lsp_client.find_references(
+ filename,
+ position,
+ include_declaration=include_declaration,
+ )
+ return [self.location_to_range_in_file(x) for x in response["result"]]
async def document_symbol(self, filepath: str) -> List:
response = await self.lsp_client.document_symbol(filepath)
@@ -314,15 +363,55 @@ class ContinueLSPClient(BaseModel):
name=x["name"],
containerName=x["containerName"],
kind=x["kind"],
- location=RangeInFile(
- filepath=uri_to_filepath(x["location"]["uri"]),
- range=Range.from_shorthand(
- x["location"]["range"]["start"]["line"],
- x["location"]["range"]["start"]["character"],
- x["location"]["range"]["end"]["line"],
- x["location"]["range"]["end"]["character"],
- ),
- ),
+ location=self.location_to_range_in_file(x["location"]),
)
for x in response["result"]
]
+
+ async def folding_range(self, filepath: str) -> List[FoldingRange]:
+ response = await self.lsp_client.folding_range(filepath)
+
+ return [
+ FoldingRange(
+ range=Range.from_shorthand(
+ x["startLine"],
+ x.get("startCharacter", 0),
+ x["endLine"] if "endCharacter" in x else x["endLine"] + 1,
+ x.get("endCharacter", 0),
+ ),
+ kind=x.get("kind"),
+ )
+ for x in response
+ ]
+
+ async def get_enclosing_folding_range_of_position(
+ self, position: Position, filepath: str
+ ) -> Optional[FoldingRange]:
+ ranges = await self.folding_range(filepath)
+
+ max_start_position = Position(line=0, character=0)
+ max_range = None
+ for r in ranges:
+ if r.range.contains(position):
+ if r.range.start > max_start_position:
+ max_start_position = r.range.start
+ max_range = r
+
+ return max_range
+
+ async def get_enclosing_folding_range(
+ self, range_in_file: RangeInFile
+ ) -> Optional[FoldingRange]:
+ ranges = await self.folding_range(range_in_file.filepath)
+
+ max_start_position = Position(line=0, character=0)
+ max_range = None
+ for r in ranges:
+ if r.range.contains(range_in_file.range.start) and r.range.contains(
+ range_in_file.range.end
+ ):
+ if r.range.start > max_start_position:
+ max_start_position = r.range.start
+ max_range = r
+
+ return max_range
diff --git a/continuedev/src/continuedev/headless/__init__.py b/continuedev/src/continuedev/headless/__init__.py
index 4e46409a..27722ee7 100644
--- a/continuedev/src/continuedev/headless/__init__.py
+++ b/continuedev/src/continuedev/headless/__init__.py
@@ -4,6 +4,7 @@ from typing import Optional, Union
import typer
from ..core.config import ContinueConfig
+from ..core.main import Step
from ..server.session_manager import Session, session_manager
from .headless_ide import LocalIdeProtocol
@@ -21,21 +22,11 @@ async def start_headless_session(
return await session_manager.new_session(ide, config=config)
-async def async_main(config: Optional[str] = None):
- await start_headless_session(config=config)
+def run_step_headless(step: Step):
+ config = ContinueConfig()
+ config.steps_on_startup = [step]
-
-@app.command()
-def main(
- config: Optional[str] = typer.Option(
- None, help="The path to the configuration file"
- )
-):
loop = asyncio.get_event_loop()
- loop.run_until_complete(async_main(config))
+ loop.run_until_complete(start_headless_session(config=config))
tasks = asyncio.all_tasks(loop)
loop.run_until_complete(asyncio.gather(*tasks))
-
-
-if __name__ == "__main__":
- app()
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index baeb9d1a..b2eecab6 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -68,6 +68,10 @@ class LLM(ContinueBaseModel):
..., description="The name of the model to be used (e.g. gpt-4, codellama)"
)
+ stop_tokens: Optional[List[str]] = Field(
+ None, description="Tokens that will stop the completion."
+ )
+
timeout: Optional[int] = Field(
300,
description="Set the timeout for each request to the LLM. If you are running a local LLM that takes a while to respond, you might want to set this to avoid timeouts.",
@@ -204,7 +208,7 @@ class LLM(ContinueBaseModel):
top_k=top_k,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
- stop=stop,
+ stop=stop or self.stop_tokens,
max_tokens=max_tokens,
functions=functions,
)
@@ -251,7 +255,7 @@ class LLM(ContinueBaseModel):
top_k=top_k,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
- stop=stop,
+ stop=stop or self.stop_tokens,
max_tokens=max_tokens,
functions=functions,
)
@@ -296,7 +300,7 @@ class LLM(ContinueBaseModel):
top_k=top_k,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
- stop=stop,
+ stop=stop or self.stop_tokens,
max_tokens=max_tokens,
functions=functions,
)
diff --git a/continuedev/src/continuedev/libs/llm/prompts/edit.py b/continuedev/src/continuedev/libs/llm/prompts/edit.py
index 7da5a192..eaa694c5 100644
--- a/continuedev/src/continuedev/libs/llm/prompts/edit.py
+++ b/continuedev/src/continuedev/libs/llm/prompts/edit.py
@@ -4,10 +4,10 @@ simplified_edit_prompt = dedent(
"""\
Consider the following code:
```
- {{code_to_edit}}
+ {{{code_to_edit}}}
```
Edit the code to perfectly satisfy the following user request:
- {{user_input}}
+ {{{user_input}}}
Output nothing except for the code. No code block, no English explanation, no start/end tags."""
)
@@ -15,11 +15,11 @@ simplest_edit_prompt = dedent(
"""\
Here is the code before editing:
```
- {{code_to_edit}}
+ {{{code_to_edit}}}
```
Here is the edit requested:
- "{{user_input}}"
+ "{{{user_input}}}"
Here is the code after editing:"""
)
diff --git a/continuedev/src/continuedev/libs/util/strings.py b/continuedev/src/continuedev/libs/util/strings.py
index d33c46c4..f2b6035f 100644
--- a/continuedev/src/continuedev/libs/util/strings.py
+++ b/continuedev/src/continuedev/libs/util/strings.py
@@ -25,6 +25,19 @@ def dedent_and_get_common_whitespace(s: str) -> Tuple[str, str]:
return "\n".join(map(lambda x: x.lstrip(lcp), lines)), lcp
+def strip_code_block(s: str) -> str:
+ """
+ Strips the code block from a string, if it has one.
+ """
+ if s.startswith("```\n") and s.endswith("\n```"):
+ return s[4:-4]
+ elif s.startswith("```") and s.endswith("```"):
+ return s[3:-3]
+ elif s.startswith("`") and s.endswith("`"):
+ return s[1:-1]
+ return s
+
+
def remove_quotes_and_escapes(output: str) -> str:
"""
Clean up the output of the completion API, removing unnecessary escapes and quotes
diff --git a/continuedev/src/continuedev/models/filesystem.py b/continuedev/src/continuedev/models/filesystem.py
index 3b056a2f..27244c4b 100644
--- a/continuedev/src/continuedev/models/filesystem.py
+++ b/continuedev/src/continuedev/models/filesystem.py
@@ -31,6 +31,9 @@ class RangeInFile(BaseModel):
range = Range.from_entire_file(content)
return RangeInFile(filepath=filepath, range=range)
+ def translated(self, lines: int):
+ return RangeInFile(filepath=self.filepath, range=self.range.translated(lines))
+
class RangeInFileWithContents(RangeInFile):
"""A range in a file with the contents of the range."""
diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py
index 880fbfef..34c557e0 100644
--- a/continuedev/src/continuedev/models/main.py
+++ b/continuedev/src/continuedev/models/main.py
@@ -105,6 +105,17 @@ class Range(BaseModel):
end=Position(line=self.end.line + 1, character=0),
)
+ def translated(self, lines: int):
+ return Range(
+ start=Position(
+ line=self.start.line + lines, character=self.start.character
+ ),
+ end=Position(line=self.end.line + lines, character=self.end.character),
+ )
+
+ def contains(self, position: Position) -> bool:
+ return self.start <= position and position <= self.end
+
@staticmethod
def from_indices(string: str, start_index: int, end_index: int) -> "Range":
return Range(
diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py
index c4a61193..f4fbaf03 100644
--- a/continuedev/src/continuedev/plugins/context_providers/file.py
+++ b/continuedev/src/continuedev/plugins/context_providers/file.py
@@ -14,8 +14,7 @@ MAX_SIZE_IN_CHARS = 50_000
async def get_file_contents(filepath: str, sdk: ContinueSDK) -> str:
try:
return (await sdk.ide.readFile(filepath))[:MAX_SIZE_IN_CHARS]
- except Exception as e:
- print(f"Failed to read file: {e}")
+ except Exception as _:
return None
diff --git a/continuedev/src/continuedev/plugins/steps/refactor.py b/continuedev/src/continuedev/plugins/steps/refactor.py
new file mode 100644
index 00000000..cfbce662
--- /dev/null
+++ b/continuedev/src/continuedev/plugins/steps/refactor.py
@@ -0,0 +1,106 @@
+import asyncio
+from typing import List
+
+from ...core.main import Step
+from ...core.models import Models
+from ...core.sdk import ContinueSDK
+from ...libs.llm.prompts.edit import simplified_edit_prompt
+from ...libs.util.strings import remove_quotes_and_escapes, strip_code_block
+from ...libs.util.templating import render_prompt_template
+from ...models.filesystem import RangeInFile
+from ...models.filesystem_edit import FileEdit
+from ...models.main import PositionInFile, Range
+
+
+class RefactorReferencesStep(Step):
+ name: str = "Refactor references of a symbol"
+ user_input: str
+ symbol_location: PositionInFile
+
+ async def describe(self, models: Models):
+ return f"Renamed all instances of `{self.function_name}` to `{self.new_function_name}` in `{self.filepath}`"
+
+ async def run(self, sdk: ContinueSDK):
+ while sdk.lsp is None or not sdk.lsp.ready:
+ await asyncio.sleep(0.1)
+
+ references = await sdk.lsp.find_references(
+ self.symbol_location.position, self.symbol_location.filepath, False
+ )
+ await sdk.run_step(
+ ParallelEditStep(user_input=self.user_input, range_in_files=references)
+ )
+
+
+class ParallelEditStep(Step):
+ name: str = "Edit multiple ranges in parallel"
+ user_input: str
+ range_in_files: List[RangeInFile]
+
+ hide: bool = True
+
+ async def single_edit(self, sdk: ContinueSDK, range_in_file: RangeInFile):
+ # TODO: Can use folding info to get a more intuitively shaped range
+ expanded_range = await sdk.lsp.get_enclosing_folding_range(range_in_file)
+ if (
+ expanded_range is None
+ or expanded_range.range.start.line != range_in_file.range.start.line
+ ):
+ expanded_range = Range.from_shorthand(
+ range_in_file.range.start.line, 0, range_in_file.range.end.line + 1, 0
+ )
+ else:
+ expanded_range = expanded_range.range
+
+ new_rif = RangeInFile(
+ filepath=range_in_file.filepath,
+ range=expanded_range,
+ )
+ code_to_edit = await sdk.ide.readRangeInFile(range_in_file=new_rif)
+
+ # code_to_edit, common_whitespace = dedent_and_get_common_whitespace(code_to_edit)
+
+ prompt = render_prompt_template(
+ simplified_edit_prompt,
+ history=[],
+ other_data={
+ "code_to_edit": code_to_edit,
+ "user_input": self.user_input,
+ },
+ )
+ print(prompt + "\n\n-------------------\n\n")
+
+ new_code = await sdk.models.edit.complete(prompt=prompt)
+ new_code = strip_code_block(remove_quotes_and_escapes(new_code)) + "\n"
+ # new_code = (
+ # "\n".join([common_whitespace + line for line in new_code.split("\n")])
+ # + "\n"
+ # )
+
+ print(new_code + "\n\n-------------------\n\n")
+
+ await sdk.ide.applyFileSystemEdit(
+ FileEdit(
+ filepath=range_in_file.filepath,
+ range=expanded_range,
+ replacement=new_code,
+ )
+ )
+
+ async def edit_file(self, sdk: ContinueSDK, filepath: str):
+ ranges_in_file = [
+ range_in_file
+ for range_in_file in self.range_in_files
+ if range_in_file.filepath == filepath
+ ]
+ # Sort in reverse order so that we don't mess up the ranges
+ ranges_in_file.sort(key=lambda x: x.range.start.line, reverse=True)
+ for i in range(len(ranges_in_file)):
+ await self.single_edit(sdk=sdk, range_in_file=ranges_in_file[i])
+
+ async def run(self, sdk: ContinueSDK):
+ tasks = []
+ for filepath in set([rif.filepath for rif in self.range_in_files]):
+ tasks.append(self.edit_file(sdk=sdk, filepath=filepath))
+
+ await asyncio.gather(*tasks)
diff --git a/continuedev/src/continuedev/plugins/steps/steps_on_startup.py b/continuedev/src/continuedev/plugins/steps/steps_on_startup.py
index d0058ffc..58d56703 100644
--- a/continuedev/src/continuedev/plugins/steps/steps_on_startup.py
+++ b/continuedev/src/continuedev/plugins/steps/steps_on_startup.py
@@ -12,5 +12,8 @@ class StepsOnStartupStep(Step):
steps_on_startup = sdk.config.steps_on_startup
for step_type in steps_on_startup:
- step = step_type()
+ if isinstance(step_type, Step):
+ step = step_type
+ else:
+ step = step_type()
await sdk.run_step(step)
diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py
index f3734470..40d46b18 100644
--- a/continuedev/src/continuedev/server/meilisearch_server.py
+++ b/continuedev/src/continuedev/server/meilisearch_server.py
@@ -77,14 +77,13 @@ async def ensure_meilisearch_installed() -> bool:
except:
pass
existing_paths.remove(meilisearchPath)
-
+
await download_meilisearch()
# Clear the existing directories
for p in existing_paths:
shutil.rmtree(p, ignore_errors=True)
-
return False
return True
@@ -160,17 +159,25 @@ def stop_meilisearch():
import psutil
+
def kill_proc(port):
for proc in psutil.process_iter():
try:
- for conns in proc.connections(kind='inet'):
+ for conns in proc.connections(kind="inet"):
if conns.laddr.port == port:
- proc.send_signal(psutil.signal.SIGTERM) # or SIGKILL
+ proc.send_signal(psutil.signal.SIGTERM) # or SIGKILL
except psutil.AccessDenied:
- logger.warning(f"Failed to kill process on port {port}")
+ logger.warning(f"Failed to kill process on port {port} (access denied)")
+ return
+ except psutil.ZombieProcess:
+ logger.warning(f"Failed to kill process on port {port} (zombie process)")
+ return
+ except psutil.NoSuchProcess:
+ logger.warning(f"Failed to kill process on port {port} (no such process)")
+ return
async def restart_meilisearch():
stop_meilisearch()
kill_proc(7700)
- await start_meilisearch() \ No newline at end of file
+ await start_meilisearch()