diff options
32 files changed, 1750 insertions, 136 deletions
diff --git a/.vscode/launch.json b/.vscode/launch.json index 08061d13..9ccf4ce7 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -37,6 +37,17 @@ // What about a watch task? - type errors? }, { + "name": "Headless", + "type": "python", + "request": "launch", + "module": "continuedev.headless", + "args": ["--config", "continuedev/config.py"], + "justMyCode": false, + "subProcess": false + // Does it need a build task? + // What about a watch task? - type errors? + }, + { "name": "Extension (VSCode)", "type": "extensionHost", "request": "launch", diff --git a/continuedev/requirements.txt b/continuedev/requirements.txt index 0cf909d5..91120c59 100644 --- a/continuedev/requirements.txt +++ b/continuedev/requirements.txt @@ -22,4 +22,5 @@ socksio==1.0.0 ripgrepy==2.0.0 replicate==0.11.0 bs4==0.0.1 -redbaron==0.9.2
\ No newline at end of file +redbaron==0.9.2 +python-lsp-server==1.2.0
\ No newline at end of file diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index bae82739..8ac7241d 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -17,7 +17,10 @@ from ..libs.util.paths import getSavedContextGroupsPath from ..libs.util.queue import AsyncSubscriptionQueue from ..libs.util.strings import remove_quotes_and_escapes from ..libs.util.telemetry import posthog_logger -from ..libs.util.traceback_parsers import get_javascript_traceback, get_python_traceback +from ..libs.util.traceback.traceback_parsers import ( + get_javascript_traceback, + get_python_traceback, +) from ..models.filesystem import RangeInFileWithContents from ..models.filesystem_edit import FileEditWithFullContents from ..models.main import ContinueBaseModel @@ -32,6 +35,8 @@ from ..plugins.steps.core.core import ( ) from ..plugins.steps.on_traceback import DefaultOnTracebackStep from ..server.ide_protocol import AbstractIdeProtocolServer +from ..server.meilisearch_server import stop_meilisearch +from .config import ContinueConfig from .context import ContextManager from .main import ( Context, @@ -97,8 +102,12 @@ class Autopilot(ContinueBaseModel): started: bool = False - async def start(self, full_state: Optional[FullState] = None): - self.continue_sdk = await ContinueSDK.create(self) + async def start( + self, + full_state: Optional[FullState] = None, + config: Optional[ContinueConfig] = None, + ): + self.continue_sdk = await ContinueSDK.create(self, config=config) if override_policy := self.continue_sdk.config.policy_override: self.policy = override_policy @@ -134,6 +143,11 @@ class Autopilot(ContinueBaseModel): self.started = True + async def cleanup(self): + if self.continue_sdk.lsp is not None: + await self.continue_sdk.lsp.stop() + stop_meilisearch() + class Config: arbitrary_types_allowed = True keep_untouched = (cached_property,) diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 62e9c690..68b2b17d 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -59,3 +59,14 @@ class ContinueConfig(BaseModel): @validator("temperature", pre=True) def temperature_validator(cls, v): return max(0.0, min(1.0, v)) + + @staticmethod + def from_filepath(filepath: str) -> "ContinueConfig": + # Use importlib to load the config file config.py at the given path + import importlib.util + + spec = importlib.util.spec_from_file_location("config", filepath) + config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config) + + return config.config diff --git a/continuedev/src/continuedev/core/lsp.py b/continuedev/src/continuedev/core/lsp.py new file mode 100644 index 00000000..5c1f9989 --- /dev/null +++ b/continuedev/src/continuedev/core/lsp.py @@ -0,0 +1,310 @@ +import os +import socket +import subprocess +import threading +from typing import List, Optional + +from pydantic import BaseModel + +from ..libs.lspclient.json_rpc_endpoint import JsonRpcEndpoint +from ..libs.lspclient.lsp_client import LspClient +from ..libs.lspclient.lsp_endpoint import LspEndpoint +from ..libs.lspclient.lsp_structs import Position as LspPosition +from ..libs.lspclient.lsp_structs import SymbolInformation, TextDocumentIdentifier +from ..models.filesystem import RangeInFile +from ..models.main import Position, Range + + +class ReadPipe(threading.Thread): + def __init__(self, pipe): + threading.Thread.__init__(self) + self.pipe = pipe + + def run(self): + line = self.pipe.readline().decode("utf-8") + while line: + print(line) + line = self.pipe.readline().decode("utf-8") + + +class SocketFileWrapper: + def __init__(self, sockfile): + self.sockfile = sockfile + + def write(self, data): + if isinstance(data, bytes): + data = data.decode("utf-8").replace("\r\n", "\n") + return self.sockfile.write(data) + + def read(self, size=-1): + data = self.sockfile.read(size) + if isinstance(data, str): + data = data.replace("\n", "\r\n").encode("utf-8") + return data + + def readline(self, size=-1): + data = self.sockfile.readline(size) + if isinstance(data, str): + data = data.replace("\n", "\r\n").encode("utf-8") + return data + + def flush(self): + return self.sockfile.flush() + + def close(self): + return self.sockfile.close() + + +def create_json_rpc_endpoint(use_subprocess: Optional[str] = None): + if use_subprocess is None: + # Connect to the server + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect(("localhost", 8080)) + + # Create a file-like object from the socket + sockfile = s.makefile("rw") + wrapped_sockfile = SocketFileWrapper(sockfile) + return JsonRpcEndpoint(wrapped_sockfile, wrapped_sockfile), None + + else: + pyls_cmd = use_subprocess.split() + p = subprocess.Popen( + pyls_cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + read_pipe = ReadPipe(p.stderr) + read_pipe.start() + return JsonRpcEndpoint(p.stdin, p.stdout), p + + +def filename_to_uri(filename: str) -> str: + return f"file://{filename}" + + +def uri_to_filename(uri: str) -> str: + if uri.startswith("file://"): + return uri.lstrip("file://") + else: + return uri + + +def create_lsp_client(workspace_dir: str, use_subprocess: Optional[str] = None): + json_rpc_endpoint, process = create_json_rpc_endpoint(use_subprocess=use_subprocess) + lsp_endpoint = LspEndpoint(json_rpc_endpoint) + lsp_client = LspClient(lsp_endpoint) + capabilities = { + "textDocument": { + "codeAction": {"dynamicRegistration": True}, + "codeLens": {"dynamicRegistration": True}, + "colorProvider": {"dynamicRegistration": True}, + "completion": { + "completionItem": { + "commitCharactersSupport": True, + "documentationFormat": ["markdown", "plaintext"], + "snippetSupport": True, + }, + "completionItemKind": { + "valueSet": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + ] + }, + "contextSupport": True, + "dynamicRegistration": True, + }, + "definition": {"dynamicRegistration": True}, + "documentHighlight": {"dynamicRegistration": True}, + "documentLink": {"dynamicRegistration": True}, + "documentSymbol": { + "dynamicRegistration": True, + "symbolKind": { + "valueSet": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + ] + }, + }, + "formatting": {"dynamicRegistration": True}, + "hover": { + "contentFormat": ["markdown", "plaintext"], + "dynamicRegistration": True, + }, + "implementation": {"dynamicRegistration": True}, + "onTypeFormatting": {"dynamicRegistration": True}, + "publishDiagnostics": {"relatedInformation": True}, + "rangeFormatting": {"dynamicRegistration": True}, + "references": {"dynamicRegistration": True}, + "rename": {"dynamicRegistration": True}, + "signatureHelp": { + "dynamicRegistration": True, + "signatureInformation": { + "documentationFormat": ["markdown", "plaintext"] + }, + }, + "synchronization": { + "didSave": True, + "dynamicRegistration": True, + "willSave": True, + "willSaveWaitUntil": True, + }, + "typeDefinition": {"dynamicRegistration": True}, + }, + "workspace": { + "applyEdit": True, + "configuration": True, + "didChangeConfiguration": {"dynamicRegistration": True}, + "didChangeWatchedFiles": {"dynamicRegistration": True}, + "executeCommand": {"dynamicRegistration": True}, + "symbol": { + "dynamicRegistration": True, + "symbolKind": { + "valueSet": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + ] + }, + }, + "workspaceEdit": {"documentChanges": True}, + "workspaceFolders": True, + }, + } + root_uri = filename_to_uri(workspace_dir) + dir_name = os.path.basename(workspace_dir) + workspace_folders = [{"name": dir_name, "uri": root_uri}] + lsp_client.initialize( + None, + None, + root_uri, + None, + capabilities, + "off", + workspace_folders, + ) + lsp_client.initialized() + return lsp_client, process + + +class ContinueLSPClient(BaseModel): + workspace_dir: str + lsp_client: LspClient = None + use_subprocess: Optional[str] = None + lsp_process: Optional[subprocess.Popen] = None + + class Config: + arbitrary_types_allowed = True + + def dict(self, **kwargs): + original_dict = super().dict(**kwargs) + original_dict.pop("lsp_client", None) + return original_dict + + async def start(self): + self.lsp_client, self.lsp_process = create_lsp_client( + self.workspace_dir, use_subprocess=self.use_subprocess + ) + + async def stop(self): + self.lsp_client.shutdown() + self.lsp_client.exit() + if self.lsp_process is not None: + self.lsp_process.terminate() + self.lsp_process.wait() + self.lsp_process = None + + def goto_definition(self, position: Position, filename: str) -> List[RangeInFile]: + response = self.lsp_client.definition( + TextDocumentIdentifier(filename_to_uri(filename)), + LspPosition(position.line, position.character), + ) + return [ + RangeInFile( + filepath=uri_to_filename(x.uri), + range=Range.from_shorthand( + x.range.start.line, + x.range.start.character, + x.range.end.line, + x.range.end.character, + ), + ) + for x in response + ] + + def get_symbols(self, filepath: str) -> List[SymbolInformation]: + response = self.lsp_client.documentSymbol( + TextDocumentIdentifier(filename_to_uri(filepath)) + ) + + return response diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 9ff6612c..7dca600d 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -1,8 +1,9 @@ import os import traceback -from typing import Coroutine, Union +from typing import Coroutine, List, Optional, Union from ..libs.llm import LLM +from ..libs.util.create_async_task import create_async_task from ..libs.util.logging import logger from ..libs.util.paths import getConfigFilePath from ..libs.util.telemetry import posthog_logger @@ -16,11 +17,18 @@ from ..models.filesystem_edit import ( FileSystemEdit, ) from ..models.main import Range -from ..plugins.steps.core.core import * -from ..plugins.steps.core.core import DefaultModelEditCodeStep +from ..plugins.steps.core.core import ( + DefaultModelEditCodeStep, + FileSystemEditStep, + MessageStep, + RangeInFileWithContents, + ShellCommandsStep, + WaitForUserConfirmationStep, +) from ..server.ide_protocol import AbstractIdeProtocolServer from .abstract_sdk import AbstractContinueSDK from .config import ContinueConfig +from .lsp import ContinueLSPClient from .main import ( ChatMessage, Context, @@ -42,6 +50,7 @@ class ContinueSDK(AbstractContinueSDK): ide: AbstractIdeProtocolServer models: Models + lsp: Optional[ContinueLSPClient] = None context: Context config: ContinueConfig __autopilot: Autopilot @@ -52,13 +61,14 @@ class ContinueSDK(AbstractContinueSDK): self.context = autopilot.context @classmethod - async def create(cls, autopilot: Autopilot) -> "ContinueSDK": + async def create( + cls, autopilot: Autopilot, config: Optional[ContinueConfig] = None + ) -> "ContinueSDK": sdk = ContinueSDK(autopilot) autopilot.continue_sdk = sdk try: - config = sdk._load_config_dot_py() - sdk.config = config + sdk.config = config or sdk._load_config_dot_py() except Exception as e: logger.error(f"Failed to load config.py: {traceback.format_exception(e)}") @@ -78,9 +88,26 @@ class ContinueSDK(AbstractContinueSDK): ) await sdk.ide.setFileOpen(getConfigFilePath()) + # Start models sdk.models = sdk.config.models await sdk.models.start(sdk) + # Start LSP + async def start_lsp(): + try: + sdk.lsp = ContinueLSPClient( + workspace_dir=sdk.ide.workspace_directory, + use_subprocess="python3.10 -m pylsp", + ) + await sdk.lsp.start() + except: + logger.warning("Failed to start LSP client", exc_info=True) + sdk.lsp = None + + create_async_task( + start_lsp(), on_error=lambda e: logger.error("Failed to setup LSP: %s", e) + ) + # When the config is loaded, setup posthog logger posthog_logger.setup(sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry) @@ -207,19 +234,13 @@ class ContinueSDK(AbstractContinueSDK): _last_valid_config: ContinueConfig = None def _load_config_dot_py(self) -> ContinueConfig: - # Use importlib to load the config file config.py at the given path path = getConfigFilePath() - - import importlib.util - - spec = importlib.util.spec_from_file_location("config", path) - config = importlib.util.module_from_spec(spec) - spec.loader.exec_module(config) - self._last_valid_config = config.config + config = ContinueConfig.from_filepath(path) + self._last_valid_config = config logger.debug("Loaded Continue config file from %s", path) - return config.config + return config def get_code_context( self, only_editing: bool = False diff --git a/continuedev/src/continuedev/headless/__init__.py b/continuedev/src/continuedev/headless/__init__.py new file mode 100644 index 00000000..4e46409a --- /dev/null +++ b/continuedev/src/continuedev/headless/__init__.py @@ -0,0 +1,41 @@ +import asyncio +from typing import Optional, Union + +import typer + +from ..core.config import ContinueConfig +from ..server.session_manager import Session, session_manager +from .headless_ide import LocalIdeProtocol + +app = typer.Typer() + + +async def start_headless_session( + config: Optional[Union[str, ContinueConfig]] = None +) -> Session: + if config is not None: + if isinstance(config, str): + config: ContinueConfig = ContinueConfig.from_filepath(config) + + ide = LocalIdeProtocol() + return await session_manager.new_session(ide, config=config) + + +async def async_main(config: Optional[str] = None): + await start_headless_session(config=config) + + +@app.command() +def main( + config: Optional[str] = typer.Option( + None, help="The path to the configuration file" + ) +): + loop = asyncio.get_event_loop() + loop.run_until_complete(async_main(config)) + tasks = asyncio.all_tasks(loop) + loop.run_until_complete(asyncio.gather(*tasks)) + + +if __name__ == "__main__": + app() diff --git a/continuedev/src/continuedev/headless/headless_ide.py b/continuedev/src/continuedev/headless/headless_ide.py new file mode 100644 index 00000000..088da2c9 --- /dev/null +++ b/continuedev/src/continuedev/headless/headless_ide.py @@ -0,0 +1,181 @@ +import os +import subprocess +import uuid +from typing import Any, Callable, Coroutine, List, Optional + +from dotenv import load_dotenv +from fastapi import WebSocket + +from ..models.filesystem import ( + FileSystem, + RangeInFile, + RangeInFileWithContents, + RealFileSystem, +) +from ..models.filesystem_edit import EditDiff, FileEdit, FileSystemEdit +from ..server.ide_protocol import AbstractIdeProtocolServer + +load_dotenv() + + +def get_mac_address(): + mac_num = hex(uuid.getnode()).replace("0x", "").upper() + mac = "-".join(mac_num[i : i + 2] for i in range(0, 11, 2)) + return mac + + +class LocalIdeProtocol(AbstractIdeProtocolServer): + websocket: WebSocket = None + session_id: Optional[str] + workspace_directory: str = os.getcwd() + unique_id: str = get_mac_address() + + filesystem: FileSystem = RealFileSystem() + + async def handle_json(self, data: Any): + """Handle a json message""" + pass + + def showSuggestion(self, file_edit: FileEdit): + """Show a suggestion to the user""" + pass + + async def setFileOpen(self, filepath: str, open: bool = True): + """Set whether a file is open""" + pass + + async def showMessage(self, message: str): + """Show a message to the user""" + print(message) + + async def showVirtualFile(self, name: str, contents: str): + """Show a virtual file""" + pass + + async def setSuggestionsLocked(self, filepath: str, locked: bool = True): + """Set whether suggestions are locked""" + pass + + async def getSessionId(self): + """Get a new session ID""" + pass + + async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool: + """Show suggestions to the user and wait for a response""" + pass + + def onAcceptRejectSuggestion(self, accepted: bool): + """Called when the user accepts or rejects a suggestion""" + pass + + def onFileSystemUpdate(self, update: FileSystemEdit): + """Called when a file system update is received""" + pass + + def onCloseGUI(self, session_id: str): + """Called when a GUI is closed""" + pass + + def onOpenGUIRequest(self): + """Called when a GUI is requested to be opened""" + pass + + async def getOpenFiles(self) -> List[str]: + """Get a list of open files""" + pass + + async def getVisibleFiles(self) -> List[str]: + """Get a list of visible files""" + pass + + async def getHighlightedCode(self) -> List[RangeInFile]: + """Get a list of highlighted code""" + pass + + async def readFile(self, filepath: str) -> str: + """Read a file""" + return self.filesystem.read(filepath) + + async def readRangeInFile(self, range_in_file: RangeInFile) -> str: + """Read a range in a file""" + return self.filesystem.read_range_in_file(range_in_file) + + async def editFile(self, edit: FileEdit): + """Edit a file""" + self.filesystem.apply_file_edit(edit) + + async def applyFileSystemEdit(self, edit: FileSystemEdit) -> EditDiff: + """Apply a file edit""" + return self.filesystem.apply_edit(edit) + + async def saveFile(self, filepath: str): + """Save a file""" + pass + + async def getUserSecret(self, key: str): + """Get a user secret""" + return os.environ.get(key) + + async def highlightCode(self, range_in_file: RangeInFile, color: str): + """Highlight code""" + pass + + async def runCommand(self, command: str) -> str: + """Run a command using subprocess (don't pass, actually implement)""" + return subprocess.check_output(command, shell=True).decode("utf-8") + + def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]): + """Called when highlighted code is updated""" + pass + + def onDeleteAtIndex(self, index: int): + """Called when a step is deleted at a given index""" + pass + + async def showDiff(self, filepath: str, replacement: str, step_index: int): + """Show a diff""" + pass + + def subscribeToFilesCreated(self, callback: Callable[[List[str]], None]): + """Subscribe to files created event""" + pass + + def subscribeToFilesDeleted(self, callback: Callable[[List[str]], None]): + """Subscribe to files deleted event""" + pass + + def subscribeToFilesRenamed(self, callback: Callable[[List[str], List[str]], None]): + """Subscribe to files renamed event""" + pass + + def subscribeToFileSaved(self, callback: Callable[[str, str], None]): + """Subscribe to file saved event""" + pass + + def onFilesCreated(self, filepaths: List[str]): + """Called when files are created""" + pass + + def onFilesDeleted(self, filepaths: List[str]): + """Called when files are deleted""" + pass + + def onFilesRenamed(self, old_filepaths: List[str], new_filepaths: List[str]): + """Called when files are renamed""" + pass + + def onFileSaved(self, filepath: str, contents: str): + """Called when a file is saved""" + pass + + async def fileExists(self, filepath: str) -> Coroutine[Any, Any, str]: + """Check if a file exists""" + return self.filesystem.exists(filepath) + + async def getTerminalContents(self) -> Coroutine[Any, Any, str]: + return "" + + async def listDirectoryContents( + self, directory: str, recursive: bool = False + ) -> List[str]: + return self.filesystem.list_directory_contents(directory, recursive=recursive) diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py index 03c9cce4..1b91ec43 100644 --- a/continuedev/src/continuedev/libs/llm/together.py +++ b/continuedev/src/continuedev/libs/llm/together.py @@ -4,6 +4,7 @@ from typing import Callable, Optional import aiohttp from ..llm import LLM +from ..util.logging import logger from .prompts.chat import llama2_template_messages from .prompts.edit import simplified_edit_prompt @@ -59,7 +60,13 @@ class TogetherLLM(LLM): if chunk.strip() != "": if chunk.startswith("data: "): chunk = chunk[6:] - json_chunk = json.loads(chunk) + if chunk == "[DONE]": + break + try: + json_chunk = json.loads(chunk) + except Exception as e: + logger.warning(f"Invalid JSON chunk: {chunk}\n\n{e}") + continue if "choices" in json_chunk: yield json_chunk["choices"][0]["text"] diff --git a/continuedev/src/continuedev/libs/lspclient/json_rpc_endpoint.py b/continuedev/src/continuedev/libs/lspclient/json_rpc_endpoint.py new file mode 100644 index 00000000..80c51000 --- /dev/null +++ b/continuedev/src/continuedev/libs/lspclient/json_rpc_endpoint.py @@ -0,0 +1,82 @@ +from __future__ import print_function + +import json +import re +import threading + +JSON_RPC_REQ_FORMAT = "Content-Length: {json_string_len}\r\n\r\n{json_string}" +JSON_RPC_RES_REGEX = "Content-Length: ([0-9]*)\r\n" +# TODO: add content-type + + +class MyEncoder(json.JSONEncoder): + """ + Encodes an object in JSON + """ + + def default(self, o): + return o.__dict__ + + +class JsonRpcEndpoint(object): + """ + Thread safe JSON RPC endpoint implementation. Responsible to recieve and send JSON RPC messages, as described in the + protocol. More information can be found: https://www.jsonrpc.org/ + """ + + def __init__(self, stdin, stdout): + self.stdin = stdin + self.stdout = stdout + self.read_lock = threading.Lock() + self.write_lock = threading.Lock() + + @staticmethod + def __add_header(json_string): + """ + Adds a header for the given json string + + :param str json_string: The string + :return: the string with the header + """ + return JSON_RPC_REQ_FORMAT.format( + json_string_len=len(json_string), json_string=json_string + ) + + def send_request(self, message): + """ + Sends the given message. + + :param dict message: The message to send. + """ + json_string = json.dumps(message, cls=MyEncoder) + # print("sending:", json_string) + jsonrpc_req = self.__add_header(json_string) + with self.write_lock: + self.stdin.write(jsonrpc_req.encode()) + self.stdin.flush() + + def recv_response(self): + """ + Recives a message. + + :return: a message + """ + with self.read_lock: + line = self.stdout.readline() + if not line: + return None + # print(line) + line = line.decode() + # TODO: handle content type as well. + match = re.match(JSON_RPC_RES_REGEX, line) + if match is None or not match.groups(): + raise RuntimeError("Bad header: " + line) + size = int(match.groups()[0]) + line = self.stdout.readline() + if not line: + return None + line = line.decode() + # if line != "\r\n": + # raise RuntimeError("Bad header: missing newline") + jsonrpc_res = self.stdout.read(size + 2) + return json.loads(jsonrpc_res) diff --git a/continuedev/src/continuedev/libs/lspclient/lsp_client.py b/continuedev/src/continuedev/libs/lspclient/lsp_client.py new file mode 100644 index 00000000..fe2db6ad --- /dev/null +++ b/continuedev/src/continuedev/libs/lspclient/lsp_client.py @@ -0,0 +1,150 @@ +from .lsp_structs import Location, SignatureHelp, SymbolInformation + + +class LspClient(object): + def __init__(self, lsp_endpoint): + """ + Constructs a new LspClient instance. + + :param lsp_endpoint: TODO + """ + self.lsp_endpoint = lsp_endpoint + + def initialize( + self, + processId, + rootPath, + rootUri, + initializationOptions, + capabilities, + trace, + workspaceFolders, + ): + """ + The initialize request is sent as the first request from the client to the server. If the server receives a request or notification + before the initialize request it should act as follows: + + 1. For a request the response should be an error with code: -32002. The message can be picked by the server. + 2. Notifications should be dropped, except for the exit notification. This will allow the exit of a server without an initialize request. + + Until the server has responded to the initialize request with an InitializeResult, the client must not send any additional requests or + notifications to the server. In addition the server is not allowed to send any requests or notifications to the client until it has responded + with an InitializeResult, with the exception that during the initialize request the server is allowed to send the notifications window/showMessage, + window/logMessage and telemetry/event as well as the window/showMessageRequest request to the client. + + The initialize request may only be sent once. + + :param int processId: The process Id of the parent process that started the server. Is null if the process has not been started by another process. + If the parent process is not alive then the server should exit (see exit notification) its process. + :param str rootPath: The rootPath of the workspace. Is null if no folder is open. Deprecated in favour of rootUri. + :param DocumentUri rootUri: The rootUri of the workspace. Is null if no folder is open. If both `rootPath` and `rootUri` are set + `rootUri` wins. + :param any initializationOptions: User provided initialization options. + :param ClientCapabilities capabilities: The capabilities provided by the client (editor or tool). + :param Trace trace: The initial trace setting. If omitted trace is disabled ('off'). + :param list workspaceFolders: The workspace folders configured in the client when the server starts. This property is only available if the client supports workspace folders. + It can be `null` if the client supports workspace folders but none are configured. + """ + self.lsp_endpoint.start() + return self.lsp_endpoint.call_method( + "initialize", + processId=processId, + rootPath=rootPath, + rootUri=rootUri, + initializationOptions=initializationOptions, + capabilities=capabilities, + trace=trace, + workspaceFolders=workspaceFolders, + ) + + def initialized(self): + """ + The initialized notification is sent from the client to the server after the client received the result of the initialize request + but before the client is sending any other request or notification to the server. The server can use the initialized notification + for example to dynamically register capabilities. The initialized notification may only be sent once. + """ + self.lsp_endpoint.send_notification("initialized") + + def shutdown(self): + """ + The initialized notification is sent from the client to the server after the client received the result of the initialize request + but before the client is sending any other request or notification to the server. The server can use the initialized notification + for example to dynamically register capabilities. The initialized notification may only be sent once. + """ + self.lsp_endpoint.stop() + return self.lsp_endpoint.call_method("shutdown") + + def exit(self): + """ + The initialized notification is sent from the client to the server after the client received the result of the initialize request + but before the client is sending any other request or notification to the server. The server can use the initialized notification + for example to dynamically register capabilities. The initialized notification may only be sent once. + """ + self.lsp_endpoint.send_notification("exit") + + def didOpen(self, textDocument): + """ + The document open notification is sent from the client to the server to signal newly opened text documents. The document's truth is + now managed by the client and the server must not try to read the document's truth using the document's uri. Open in this sense + means it is managed by the client. It doesn't necessarily mean that its content is presented in an editor. An open notification must + not be sent more than once without a corresponding close notification send before. This means open and close notification must be + balanced and the max open count for a particular textDocument is one. Note that a server's ability to fulfill requests is independent + of whether a text document is open or closed. + + The DidOpenTextDocumentParams contain the language id the document is associated with. If the language Id of a document changes, the + client needs to send a textDocument/didClose to the server followed by a textDocument/didOpen with the new language id if the server + handles the new language id as well. + + :param TextDocumentItem textDocument: The initial trace setting. If omitted trace is disabled ('off'). + """ + return self.lsp_endpoint.send_notification( + "textDocument/didOpen", textDocument=textDocument + ) + + def documentSymbol(self, textDocument): + """ + The document symbol request is sent from the client to the server to return a flat list of all symbols found in a given text document. + Neither the symbol's location range nor the symbol's container name should be used to infer a hierarchy. + + :param TextDocumentItem textDocument: The text document. + """ + result_dict = self.lsp_endpoint.call_method( + "textDocument/documentSymbol", textDocument=textDocument + ) + return [SymbolInformation(**sym) for sym in result_dict] + + def definition(self, textDocument, position): + """ + The goto definition request is sent from the client to the server to resolve the definition location of a symbol at a given text document position. + + :param TextDocumentItem textDocument: The text document. + :param Position position: The position inside the text document.. + """ + result_dict = self.lsp_endpoint.call_method( + "textDocument/definition", textDocument=textDocument, position=position + ) + return [Location(**l) for l in result_dict] + + def typeDefinition(self, textDocument, position): + """ + The goto type definition request is sent from the client to the server to resolve the type definition location of a symbol at a given text document position. + + :param TextDocumentItem textDocument: The text document. + :param Position position: The position inside the text document.. + """ + result_dict = self.lsp_endpoint.call_method( + "textDocument/definition", textDocument=textDocument, position=position + ) + return [Location(**l) for l in result_dict] + + def signatureHelp(self, textDocument, position): + """ + The signature help request is sent from the client to the server to request signature information at a given cursor position. + + :param TextDocumentItem textDocument: The text document. + :param Position position: The position inside the text document.. + """ + result_dict = self.lsp_endpoint.call_method( + "textDocument/signatureHelp", textDocument=textDocument, position=position + ) + return SignatureHelp(**result_dict) diff --git a/continuedev/src/continuedev/libs/lspclient/lsp_endpoint.py b/continuedev/src/continuedev/libs/lspclient/lsp_endpoint.py new file mode 100644 index 00000000..14d2ca07 --- /dev/null +++ b/continuedev/src/continuedev/libs/lspclient/lsp_endpoint.py @@ -0,0 +1,71 @@ +from __future__ import print_function + +import threading + + +class LspEndpoint(threading.Thread): + def __init__(self, json_rpc_endpoint, default_callback=print, callbacks={}): + threading.Thread.__init__(self) + self.json_rpc_endpoint = json_rpc_endpoint + self.callbacks = callbacks + self.default_callback = default_callback + self.event_dict = {} + self.response_dict = {} + self.next_id = 0 + # self.daemon = True + self.shutdown_flag = False + + def handle_result(self, jsonrpc_res): + self.response_dict[jsonrpc_res["id"]] = jsonrpc_res + cond = self.event_dict[jsonrpc_res["id"]] + cond.acquire() + cond.notify() + cond.release() + + def stop(self): + self.shutdown_flag = True + + def run(self): + while not self.shutdown_flag: + jsonrpc_message = self.json_rpc_endpoint.recv_response() + + if jsonrpc_message is None: + print("server quit") + break + + # print("recieved message:", jsonrpc_message) + if "result" in jsonrpc_message or "error" in jsonrpc_message: + self.handle_result(jsonrpc_message) + elif "method" in jsonrpc_message: + if jsonrpc_message["method"] in self.callbacks: + self.callbacks[jsonrpc_message["method"]](jsonrpc_message) + else: + self.default_callback(jsonrpc_message) + else: + print("unknown jsonrpc message") + # print(jsonrpc_message) + + def send_message(self, method_name, params, id=None): + message_dict = {} + message_dict["jsonrpc"] = "2.0" + if id is not None: + message_dict["id"] = id + message_dict["method"] = method_name + message_dict["params"] = params + self.json_rpc_endpoint.send_request(message_dict) + + def call_method(self, method_name, **kwargs): + current_id = self.next_id + self.next_id += 1 + cond = threading.Condition() + self.event_dict[current_id] = cond + cond.acquire() + self.send_message(method_name, kwargs, current_id) + cond.wait() + cond.release() + # TODO: check if error, and throw an exception + response = self.response_dict[current_id] + return response["result"] + + def send_notification(self, method_name, **kwargs): + self.send_message(method_name, kwargs) diff --git a/continuedev/src/continuedev/libs/lspclient/lsp_structs.py b/continuedev/src/continuedev/libs/lspclient/lsp_structs.py new file mode 100644 index 00000000..2f0940d4 --- /dev/null +++ b/continuedev/src/continuedev/libs/lspclient/lsp_structs.py @@ -0,0 +1,316 @@ +def to_type(o, new_type): + ''' + Helper funciton that receives an object or a dict and convert it to a new given type. + + :param object|dict o: The object to convert + :param Type new_type: The type to convert to. + ''' + if new_type == type(o): + return o + else: + return new_type(**o) + + +class Position(object): + def __init__(self, line, character): + """ + Constructs a new Position instance. + + :param int line: Line position in a document (zero-based). + :param int character: Character offset on a line in a document (zero-based). + """ + self.line = line + self.character = character + + +class Range(object): + def __init__(self, start, end): + """ + Constructs a new Range instance. + + :param Position start: The range's start position. + :param Position end: The range's end position. + """ + self.start = to_type(start, Position) + self.end = to_type(end, Position) + + +class Location(object): + """ + Represents a location inside a resource, such as a line inside a text file. + """ + def __init__(self, uri, range): + """ + Constructs a new Range instance. + + :param str uri: Resource file. + :param Range range: The range inside the file + """ + self.uri = uri + self.range = to_type(range, Range) + + +class Diagnostic(object): + def __init__(self, range, severity, code, source, message, relatedInformation): + """ + Constructs a new Diagnostic instance. + :param Range range: The range at which the message applies.Resource file. + :param int severity: The diagnostic's severity. Can be omitted. If omitted it is up to the + client to interpret diagnostics as error, warning, info or hint. + :param str code: The diagnostic's code, which might appear in the user interface. + :param str source: A human-readable string describing the source of this + diagnostic, e.g. 'typescript' or 'super lint'. + :param str message: The diagnostic's message. + :param list relatedInformation: An array of related diagnostic information, e.g. when symbol-names within + a scope collide all definitions can be marked via this property. + """ + self.range = range + self.severity = severity + self.code = code + self.source = source + self.message = message + self.relatedInformation = relatedInformation + + +class DiagnosticSeverity(object): + Error = 1 + Warning = 2 # TODO: warning is known in python + Information = 3 + Hint = 4 + + +class DiagnosticRelatedInformation(object): + def __init__(self, location, message): + """ + Constructs a new Diagnostic instance. + :param Location location: The location of this related diagnostic information. + :param str message: The message of this related diagnostic information. + """ + self.location = location + self.message = message + + +class Command(object): + def __init__(self, title, command, arguments): + """ + Constructs a new Diagnostic instance. + :param str title: Title of the command, like `save`. + :param str command: The identifier of the actual command handler. + :param list argusments: Arguments that the command handler should be invoked with. + """ + self.title = title + self.command = command + self.arguments = arguments + + +class TextDocumentItem(object): + """ + An item to transfer a text document from the client to the server. + """ + def __init__(self, uri, languageId, version, text): + """ + Constructs a new Diagnostic instance. + + :param DocumentUri uri: Title of the command, like `save`. + :param str languageId: The identifier of the actual command handler. + :param int version: Arguments that the command handler should be invoked with. + :param str text: Arguments that the command handler should be invoked with. + """ + self.uri = uri + self.languageId = languageId + self.version = version + self.text = text + + +class TextDocumentIdentifier(object): + """ + Text documents are identified using a URI. On the protocol level, URIs are passed as strings. + """ + def __init__(self, uri): + """ + Constructs a new TextDocumentIdentifier instance. + + :param DocumentUri uri: The text document's URI. + """ + self.uri = uri + +class TextDocumentPositionParams(object): + """ + A parameter literal used in requests to pass a text document and a position inside that document. + """ + def __init__(self, textDocument, position): + """ + Constructs a new TextDocumentPositionParams instance. + + :param TextDocumentIdentifier textDocument: The text document. + :param Position position: The position inside the text document. + """ + self.textDocument = textDocument + self.position = position + + +class LANGUAGE_IDENTIFIER: + BAT="bat" + BIBTEX="bibtex" + CLOJURE="clojure" + COFFESCRIPT="coffeescript" + C="c" + CPP="cpp" + CSHARP="csharp" + CSS="css" + DIFF="diff" + DOCKERFILE="dockerfile" + FSHARP="fsharp" + GIT_COMMIT="git-commit" + GIT_REBASE="git-rebase" + GO="go" + GROOVY="groovy" + HANDLEBARS="handlebars" + HTML="html" + INI="ini" + JAVA="java" + JAVASCRIPT="javascript" + JSON="json" + LATEX="latex" + LESS="less" + LUA="lua" + MAKEFILE="makefile" + MARKDOWN="markdown" + OBJECTIVE_C="objective-c" + OBJECTIVE_CPP="objective-cpp" + Perl="perl" + PHP="php" + POWERSHELL="powershell" + PUG="jade" + PYTHON="python" + R="r" + RAZOR="razor" + RUBY="ruby" + RUST="rust" + SASS="sass" + SCSS="scss" + ShaderLab="shaderlab" + SHELL_SCRIPT="shellscript" + SQL="sql" + SWIFT="swift" + TYPE_SCRIPT="typescript" + TEX="tex" + VB="vb" + XML="xml" + XSL="xsl" + YAML="yaml" + + +class SymbolKind(object): + File = 1 + Module = 2 + Namespace = 3 + Package = 4 + Class = 5 + Method = 6 + Property = 7 + Field = 8 + Constructor = 9 + Enum = 10 + Interface = 11 + Function = 12 + Variable = 13 + Constant = 14 + String = 15 + Number = 16 + Boolean = 17 + Array = 18 + Object = 19 + Key = 20 + Null = 21 + EnumMember = 22 + Struct = 23 + Event = 24 + Operator = 25 + TypeParameter = 26 + + +class SymbolInformation(object): + """ + Represents information about programming constructs like variables, classes, interfaces etc. + """ + def __init__(self, name, kind, location, containerName, deprecated=False): + """ + Constructs a new SymbolInformation instance. + + :param str name: The name of this symbol. + :param int kind: The kind of this symbol. + :param bool Location: The location of this symbol. The location's range is used by a tool + to reveal the location in the editor. If the symbol is selected in the + tool the range's start information is used to position the cursor. So + the range usually spans more then the actual symbol's name and does + normally include things like visibility modifiers. + + The range doesn't have to denote a node range in the sense of a abstract + syntax tree. It can therefore not be used to re-construct a hierarchy of + the symbols. + :param str containerName: The name of the symbol containing this symbol. This information is for + user interface purposes (e.g. to render a qualifier in the user interface + if necessary). It can't be used to re-infer a hierarchy for the document + symbols. + :param bool deprecated: Indicates if this symbol is deprecated. + """ + self.name = name + self.kind = kind + self.deprecated = deprecated + self.location = to_type(location, Location) + self.containerName = containerName + + +class ParameterInformation(object): + """ + Represents a parameter of a callable-signature. A parameter can + have a label and a doc-comment. + """ + def __init__(self, label, documentation=""): + """ + Constructs a new ParameterInformation instance. + + :param str label: The label of this parameter. Will be shown in the UI. + :param str documentation: The human-readable doc-comment of this parameter. Will be shown in the UI but can be omitted. + """ + self.label = label + self.documentation = documentation + + +class SignatureInformation(object): + """ + Represents the signature of something callable. A signature + can have a label, like a function-name, a doc-comment, and + a set of parameters. + """ + def __init__(self, label, documentation="", parameters=[]): + """ + Constructs a new SignatureInformation instance. + + :param str label: The label of this signature. Will be shown in the UI. + :param str documentation: The human-readable doc-comment of this signature. Will be shown in the UI but can be omitted. + :param ParameterInformation[] parameters: The parameters of this signature. + """ + self.label = label + self.documentation = documentation + self.parameters = [to_type(parameter, ParameterInformation) for parameter in parameters] + + +class SignatureHelp(object): + """ + Signature help represents the signature of something + callable. There can be multiple signature but only one + active and only one active parameter. + """ + def __init__(self, signatures, activeSignature=0, activeParameter=0): + """ + Constructs a new SignatureHelp instance. + + :param SignatureInformation[] signatures: One or more signatures. + :param int activeSignature: + :param int activeParameter: + """ + self.signatures = [to_type(signature, SignatureInformation) for signature in signatures] + self.activeSignature = activeSignature + self.activeParameter = activeParameter
\ No newline at end of file diff --git a/continuedev/src/continuedev/libs/util/filter_files.py b/continuedev/src/continuedev/libs/util/filter_files.py new file mode 100644 index 00000000..6ebaa274 --- /dev/null +++ b/continuedev/src/continuedev/libs/util/filter_files.py @@ -0,0 +1,33 @@ +import fnmatch +from typing import List + +DEFAULT_IGNORE_DIRS = [ + ".git", + ".vscode", + ".idea", + ".vs", + ".venv", + "env", + ".env", + "node_modules", + "dist", + "build", + "target", + "out", + "bin", + ".pytest_cache", + ".vscode-test", + ".continue", + "__pycache__", +] + +DEFAULT_IGNORE_PATTERNS = DEFAULT_IGNORE_DIRS + list( + filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS) +) + + +def should_filter_path( + path: str, ignore_patterns: List[str] = DEFAULT_IGNORE_PATTERNS +) -> bool: + """Returns whether a file should be filtered""" + return any(fnmatch.fnmatch(path, pattern) for pattern in ignore_patterns) diff --git a/continuedev/src/continuedev/libs/util/traceback/traceback_parsers.py b/continuedev/src/continuedev/libs/util/traceback/traceback_parsers.py new file mode 100644 index 00000000..58a4f728 --- /dev/null +++ b/continuedev/src/continuedev/libs/util/traceback/traceback_parsers.py @@ -0,0 +1,56 @@ +from boltons import tbutils + +from ....models.main import Traceback + +PYTHON_TRACEBACK_PREFIX = "Traceback (most recent call last):" + + +def get_python_traceback(output: str) -> str: + if PYTHON_TRACEBACK_PREFIX in output: + tb_string = output.split(PYTHON_TRACEBACK_PREFIX)[-1] + + # Then need to remove any lines below the traceback. Do this by noticing that + # the last line of the traceback is the first (other than they prefix) that doesn't begin with whitespace + lines = list(filter(lambda x: x.strip() != "", tb_string.splitlines())) + for i in range(len(lines) - 1): + if not lines[i].startswith(" "): + tb_string = "\n".join(lines[: i + 1]) + break + + return PYTHON_TRACEBACK_PREFIX + "\n" + tb_string + elif "SyntaxError" in output: + return "SyntaxError" + output.split("SyntaxError")[-1] + else: + return None + + +def get_javascript_traceback(output: str) -> str: + lines = output.splitlines() + first_line = None + for i in range(len(lines) - 1): + segs = lines[i].split(":") + if ( + len(segs) > 1 + and segs[0] != "" + and segs[1].startswith(" ") + and lines[i + 1].strip().startswith("at") + ): + first_line = lines[i] + break + + if first_line is not None: + return "\n".join(lines[lines.index(first_line) :]) + else: + return None + + +def parse_python_traceback(tb_string: str) -> Traceback: + # Remove anchor lines - tbutils doesn't always get them right + tb_string = "\n".join( + filter( + lambda x: x.strip().replace("~", "").replace("^", "") != "", + tb_string.splitlines(), + ) + ) + exc = tbutils.ParsedException.from_string(tb_string) + return Traceback.from_tbutil_parsed_exc(exc) diff --git a/continuedev/src/continuedev/libs/util/traceback_parsers.py b/continuedev/src/continuedev/libs/util/traceback_parsers.py deleted file mode 100644 index 2b164c0f..00000000 --- a/continuedev/src/continuedev/libs/util/traceback_parsers.py +++ /dev/null @@ -1,30 +0,0 @@ -PYTHON_TRACEBACK_PREFIX = "Traceback (most recent call last):" - - -def get_python_traceback(output: str) -> str: - if PYTHON_TRACEBACK_PREFIX in output: - return PYTHON_TRACEBACK_PREFIX + output.split(PYTHON_TRACEBACK_PREFIX)[-1] - elif "SyntaxError" in output: - return "SyntaxError" + output.split("SyntaxError")[-1] - else: - return None - - -def get_javascript_traceback(output: str) -> str: - lines = output.splitlines() - first_line = None - for i in range(len(lines) - 1): - segs = lines[i].split(":") - if ( - len(segs) > 1 - and segs[0] != "" - and segs[1].startswith(" ") - and lines[i + 1].strip().startswith("at") - ): - first_line = lines[i] - break - - if first_line is not None: - return "\n".join(lines[lines.index(first_line) :]) - else: - return None diff --git a/continuedev/src/continuedev/models/filesystem.py b/continuedev/src/continuedev/models/filesystem.py index de426282..514337bf 100644 --- a/continuedev/src/continuedev/models/filesystem.py +++ b/continuedev/src/continuedev/models/filesystem.py @@ -136,6 +136,11 @@ class FileSystem(AbstractModel): """Apply edit to filesystem, calculate the reverse edit, and return and EditDiff""" raise NotImplementedError + @abstractmethod + def list_directory_contents(self, path: str, recursive: bool = False) -> List[str]: + """List the contents of a directory""" + raise NotImplementedError + @classmethod def read_range_in_str(self, s: str, r: Range) -> str: lines = s.split("\n")[r.start.line : r.end.line + 1] @@ -312,6 +317,18 @@ class RealFileSystem(FileSystem): self.write(edit.filepath, new_content) return diff + def list_directory_contents(self, path: str, recursive: bool = False) -> List[str]: + """List the contents of a directory""" + if recursive: + # Walk + paths = [] + for root, dirs, files in os.walk(path): + for f in files: + paths.append(os.path.join(root, f)) + + return paths + return list(map(lambda x: os.path.join(path, x), os.listdir(path))) + class VirtualFileSystem(FileSystem): """A simulated filesystem from a mapping of filepath to file contents.""" @@ -363,5 +380,16 @@ class VirtualFileSystem(FileSystem): self.write(edit.filepath, new_content) return EditDiff(edit=edit, original=original) + def list_directory_contents(self, path: str, recursive: bool = False) -> List[str]: + """List the contents of a directory""" + if recursive: + for filepath in self.files: + if filepath.startswith(path): + yield filepath + + for filepath in self.files: + if filepath.startswith(path) and "/" not in filepath[len(path) :]: + yield filepath + # TODO: Uniform errors thrown by any FileSystem subclass. diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py index d442a415..a8b4aab0 100644 --- a/continuedev/src/continuedev/models/main.py +++ b/continuedev/src/continuedev/models/main.py @@ -50,6 +50,11 @@ class Position(BaseModel): return sum(map(len, lines[: self.line])) + self.character +class PositionInFile(BaseModel): + position: Position + filepath: str + + class Range(BaseModel): """A range in a file. 0-indexed.""" diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index 9ac28aa2..06d22348 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -5,6 +5,7 @@ from typing import List from ...core.context import ContextProvider from ...core.main import ContextItem, ContextItemDescription, ContextItemId from ...core.sdk import ContinueSDK +from ...libs.util.filter_files import DEFAULT_IGNORE_PATTERNS from .util import remove_meilisearch_disallowed_chars MAX_SIZE_IN_CHARS = 25_000 @@ -18,36 +19,13 @@ async def get_file_contents(filepath: str, sdk: ContinueSDK) -> str: return None -DEFAULT_IGNORE_DIRS = [ - ".git", - ".vscode", - ".idea", - ".vs", - ".venv", - "env", - ".env", - "node_modules", - "dist", - "build", - "target", - "out", - "bin", - ".pytest_cache", - ".vscode-test", - ".continue", - "__pycache__", -] - - class FileContextProvider(ContextProvider): """ The FileContextProvider is a ContextProvider that allows you to search files in the open workspace. """ title = "file" - ignore_patterns: List[str] = DEFAULT_IGNORE_DIRS + list( - filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS) - ) + ignore_patterns: List[str] = DEFAULT_IGNORE_PATTERNS async def start(self, *args): await super().start(*args) diff --git a/continuedev/src/continuedev/plugins/policies/commit.py b/continuedev/src/continuedev/plugins/policies/commit.py new file mode 100644 index 00000000..2fa43676 --- /dev/null +++ b/continuedev/src/continuedev/plugins/policies/commit.py @@ -0,0 +1,77 @@ +# An agent that makes a full commit in the background +# Plans +# Write code +# Reviews code +# Cleans up + +# It's important that agents are configurable, because people need to be able to specify +# which hooks they want to run. Specific linter, run tests, etc. +# And all of this can be easily specified in the Policy. + + +from textwrap import dedent +from typing import Literal + +from ...core.config import ContinueConfig +from ...core.main import History, Policy, Step +from ...core.observation import TextObservation +from ...core.sdk import ContinueSDK + + +class PlanStep(Step): + user_input: str + + _prompt = dedent( + """\ + You were given the following instructions: "{user_input}". + + Create a plan for how you will complete the task. + + Here are relevant files: + + {relevant_files} + + Your plan will include: + 1. A high-level description of how you are going to accomplish the task + 2. A list of which files you will edit + 3. A description of what you will change in each file + """ + ) + + async def run(self, sdk: ContinueSDK): + plan = await sdk.models.default.complete( + self._prompt.format( + {"user_input": self.user_input, "relevant_files": "TODO"} + ) + ) + return TextObservation(text=plan) + + +class WriteCommitStep(Step): + async def run(self, sdk: ContinueSDK): + pass + + +class ReviewCodeStep(Step): + async def run(self, sdk: ContinueSDK): + pass + + +class CleanupStep(Step): + async def run(self, sdk: ContinueSDK): + pass + + +class CommitPolicy(Policy): + user_input: str + + current_step: Literal["plan", "write", "review", "cleanup"] = "plan" + + def next(self, config: ContinueConfig, history: History) -> Step: + if history.get_current() is None: + return ( + PlanStep(user_input=self.user_input) + >> WriteCommitStep() + >> ReviewCodeStep() + >> CleanupStep() + ) diff --git a/continuedev/src/continuedev/plugins/policies/headless.py b/continuedev/src/continuedev/plugins/policies/headless.py new file mode 100644 index 00000000..56ebe31f --- /dev/null +++ b/continuedev/src/continuedev/plugins/policies/headless.py @@ -0,0 +1,18 @@ +from ...core.config import ContinueConfig +from ...core.main import History, Policy, Step +from ...core.observation import TextObservation +from ...plugins.steps.core.core import ShellCommandsStep +from ...plugins.steps.on_traceback import DefaultOnTracebackStep + + +class HeadlessPolicy(Policy): + command: str + + def next(self, config: ContinueConfig, history: History) -> Step: + if history.get_current() is None: + return ShellCommandsStep(cmds=[self.command]) + observation = history.get_current().observation + if isinstance(observation, TextObservation): + return DefaultOnTracebackStep(output=observation.text) + + return None diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index 5325a918..740c75bc 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -1,5 +1,6 @@ # These steps are depended upon by ContinueSDK import difflib +import subprocess import traceback from textwrap import dedent from typing import Any, Coroutine, List, Optional, Union @@ -112,52 +113,22 @@ class ShellCommandsStep(Step): ) async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: - await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd - - for cmd in self.cmds: - output = await sdk.ide.runCommand(cmd) - if ( - self.handle_error - and output is not None - and output_contains_error(output) - ): - suggestion = await sdk.models.medium.complete( - dedent( - f"""\ - While running the command `{cmd}`, the following error occurred: - - ```ascii - {output} - ``` - - This is a brief summary of the error followed by a suggestion on how it can be fixed:""" - ), - with_history=await sdk.get_chat_context(), - ) - - sdk.raise_exception( - title="Error while running query", - message=output, - with_step=MessageStep( - name=f"Suggestion to solve error {AI_ASSISTED_STRING}", - message=f"{suggestion}\n\nYou can click the retry button on the failed step to try again.", - ), - ) - - return TextObservation(text=output) - - # process = subprocess.Popen( - # '/bin/bash', stdin=subprocess.PIPE, stdout=subprocess.PIPE, cwd=cwd) + process = subprocess.Popen( + "/bin/bash", + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + cwd=self.cwd or sdk.ide.workspace_directory, + ) - # stdin_input = "\n".join(self.cmds) - # out, err = process.communicate(stdin_input.encode()) + stdin_input = "\n".join(self.cmds) + out, err = process.communicate(stdin_input.encode()) - # # If it fails, return the error - # if err is not None and err != "": - # self._err_text = err - # return TextObservation(text=err) + # If it fails, return the error + if err is not None and err != "": + self._err_text = err + return TextObservation(text=err) - # return None + return None class DefaultModelEditCodeStep(Step): diff --git a/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py b/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py index 1d135b3e..7ceefe9b 100644 --- a/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py +++ b/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py @@ -8,6 +8,10 @@ class ImplementAbstractMethodStep(Step): class_name: str async def run(self, sdk: ContinueSDK): + if sdk.lsp is None: + self.description = "Language Server Protocol is not enabled" + return + implementations = await sdk.lsp.go_to_implementations(self.class_name) for implementation in implementations: diff --git a/continuedev/src/continuedev/plugins/steps/on_traceback.py b/continuedev/src/continuedev/plugins/steps/on_traceback.py index 63bae805..6b75d726 100644 --- a/continuedev/src/continuedev/plugins/steps/on_traceback.py +++ b/continuedev/src/continuedev/plugins/steps/on_traceback.py @@ -1,15 +1,28 @@ import os +from textwrap import dedent +from typing import Dict, List, Optional, Tuple from ...core.main import ChatMessage, Step from ...core.sdk import ContinueSDK -from ...libs.util.traceback_parsers import ( +from ...libs.util.filter_files import should_filter_path +from ...libs.util.traceback.traceback_parsers import ( get_javascript_traceback, get_python_traceback, + parse_python_traceback, ) +from ...models.filesystem import RangeInFile +from ...models.main import Range, Traceback, TracebackFrame from .chat import SimpleChatStep from .core.core import UserInputStep +def extract_traceback_str(output: str) -> str: + tb = output.strip() + for tb_parser in [get_python_traceback, get_javascript_traceback]: + if parsed_tb := tb_parser(tb): + return parsed_tb + + class DefaultOnTracebackStep(Step): output: str name: str = "Help With Traceback" @@ -38,11 +51,11 @@ class DefaultOnTracebackStep(Step): # And this function is where you can get arbitrarily fancy about adding context async def run(self, sdk: ContinueSDK): - tb = self.output.strip() - for tb_parser in [get_python_traceback, get_javascript_traceback]: - if parsed_tb := tb_parser(tb): - tb = parsed_tb - break + if get_python_traceback(self.output) is not None: + await sdk.run_step(SolvePythonTracebackStep(output=self.output)) + return + + tb = extract_traceback_str(self.output) tb_first_last_lines = ( ("\n".join(tb.split("\n")[:3]) + "\n...\n" + "\n".join(tb.split("\n")[-3:])) @@ -57,3 +70,139 @@ class DefaultOnTracebackStep(Step): ) ) await sdk.run_step(SimpleChatStep(name="Help With Traceback")) + + +def filter_frames(frames: List[TracebackFrame]) -> List[TracebackFrame]: + """Filter out frames that are not relevant to the user's code.""" + return list(filter(lambda x: should_filter_path(x.filepath), frames)) + + +def find_external_call( + frames: List[TracebackFrame], +) -> Optional[Tuple[TracebackFrame, TracebackFrame]]: + """Moving up from the bottom of the stack, if the frames are not user code, then find the last frame before it becomes user code.""" + if not should_filter_path(frames[-1].filepath): + # No external call, error comes directly from user code + return None + + for i in range(len(frames) - 2, -1, -1): + if not should_filter_path(frames[i].filepath): + return frames[i], frames[i + 1] + + +def get_func_source_for_frame(frame: Dict) -> str: + """Get the source for the function called in the frame.""" + pass + + +async def fetch_docs_for_external_call(external_call: Dict, next_frame: Dict) -> str: + """Fetch docs for the external call.""" + pass + + +class SolvePythonTracebackStep(Step): + output: str + name: str = "Solve Traceback" + hide: bool = True + + async def external_call_prompt( + self, sdk: ContinueSDK, external_call: Tuple[Dict, Dict], tb_string: str + ) -> str: + external_call, next_frame = external_call + source_line = external_call["source_line"] + external_func_source = get_func_source_for_frame(next_frame) + docs = await fetch_docs_for_external_call(external_call, next_frame) + + prompt = dedent( + f"""\ + I got the following error: + + {tb_string} + + I tried to call an external library like this: + + ```python + {source_line} + ``` + + This is the definition of the function I tried to call: + + ```python + {external_func_source} + ``` + + Here's the documentation for the external library I tried to call: + + {docs} + + Explain how to fix the error. + """ + ) + + return prompt + + async def normal_traceback_prompt( + self, sdk: ContinueSDK, tb: Traceback, tb_string: str + ) -> str: + function_bodies = await get_functions_from_traceback(tb, sdk) + + prompt = ( + "Here are the functions from the traceback (most recent call last):\n\n" + ) + for i, function_body in enumerate(function_bodies): + prompt += f'File "{tb.frames[i].filepath}", line {tb.frames[i].lineno}, in {tb.frames[i].function}\n\n```python\n{function_body or tb.frames[i].code}\n```\n\n' + + prompt += ( + "Here is the traceback:\n\n```\n" + + tb_string + + "\n```\n\nExplain how to fix the error." + ) + + return prompt + + async def run(self, sdk: ContinueSDK): + tb_string = get_python_traceback(self.output) + tb = parse_python_traceback(tb_string) + + if external_call := find_external_call(tb.frames): + prompt = await self.external_call_prompt(sdk, external_call, tb_string) + else: + prompt = await self.normal_traceback_prompt(sdk, tb, tb_string) + + await sdk.run_step( + UserInputStep( + description="Solving stack trace", + user_input=prompt, + ) + ) + await sdk.run_step(SimpleChatStep(name="Help With Traceback")) + + +async def get_function_body(frame: TracebackFrame, sdk: ContinueSDK) -> Optional[str]: + """Get the function body from the traceback frame.""" + if sdk.lsp is None: + return None + + document_symbols = sdk.lsp.get_symbols(frame.filepath) + for symbol in document_symbols: + if symbol.name == frame.function: + r = symbol.location.range + return await sdk.ide.readRangeInFile( + RangeInFile( + filepath=frame.filepath, + range=Range.from_shorthand( + r.start.line, r.start.character, r.end.line, r.end.character + ), + ) + ) + return None + + +async def get_functions_from_traceback(tb: Traceback, sdk: ContinueSDK) -> List[str]: + """Get the function bodies from the traceback.""" + function_bodies = [] + for frame in tb.frames: + if frame.function: + function_bodies.append(await get_function_body(frame, sdk)) + + return function_bodies diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index 11099494..6aae8cc5 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -117,10 +117,15 @@ async def poll_meilisearch_running(frequency: int = 0.1) -> bool: await asyncio.sleep(frequency) +meilisearch_process = None + + async def start_meilisearch(): """ Starts the MeiliSearch server, wait for it. """ + global meilisearch_process + serverPath = getServerFolderPath() # Check if MeiliSearch is installed, if not download @@ -130,7 +135,7 @@ async def start_meilisearch(): if not await check_meilisearch_running() or not was_already_installed: logger.debug("Starting MeiliSearch...") binary_name = "meilisearch" if os.name == "nt" else "./meilisearch" - subprocess.Popen( + meilisearch_process = subprocess.Popen( [binary_name, "--no-analytics"], cwd=serverPath, stdout=subprocess.DEVNULL, @@ -139,3 +144,14 @@ async def start_meilisearch(): start_new_session=True, shell=True, ) + + +def stop_meilisearch(): + """ + Stops the MeiliSearch server. + """ + global meilisearch_process + if meilisearch_process is not None: + meilisearch_process.terminate() + meilisearch_process.wait() + meilisearch_process = None diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index f0c33929..f0080104 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -8,6 +8,7 @@ from fastapi import APIRouter, WebSocket from fastapi.websockets import WebSocketState from ..core.autopilot import Autopilot +from ..core.config import ContinueConfig from ..core.main import FullState from ..libs.util.create_async_task import create_async_task from ..libs.util.logging import logger @@ -57,7 +58,10 @@ class SessionManager: return self.sessions[session_id] async def new_session( - self, ide: AbstractIdeProtocolServer, session_id: Optional[str] = None + self, + ide: AbstractIdeProtocolServer, + session_id: Optional[str] = None, + config: Optional[ContinueConfig] = None, ) -> Session: logger.debug(f"New session: {session_id}") @@ -85,7 +89,7 @@ class SessionManager: # Start the autopilot (must be after session is added to sessions) and the policy try: - await autopilot.start(full_state=full_state) + await autopilot.start(full_state=full_state, config=config) except Exception as e: await ide.on_error(e) diff --git a/continuedev/src/continuedev/tests/llm_test.py b/continuedev/src/continuedev/tests/llm_test.py index 91ddd33f..f4aea1fb 100644 --- a/continuedev/src/continuedev/tests/llm_test.py +++ b/continuedev/src/continuedev/tests/llm_test.py @@ -13,10 +13,11 @@ from continuedev.libs.llm.openai import OpenAI from continuedev.libs.llm.together import TogetherLLM from continuedev.libs.util.count_tokens import DEFAULT_ARGS from continuedev.tests.util.openai_mock import start_openai +from continuedev.tests.util.prompts import tokyo_test_pair load_dotenv() -TEST_PROMPT = "Output a single word, that being the capital of Japan:" + SPEND_MONEY = True @@ -65,9 +66,9 @@ class TestBaseLLM: if self.llm.__class__.__name__ == "LLM": pytest.skip("Skipping abstract LLM") - resp = await self.llm.complete(TEST_PROMPT, temperature=0.0) + resp = await self.llm.complete(tokyo_test_pair[0], temperature=0.0) assert isinstance(resp, str) - assert resp.strip().lower() == "tokyo" + assert resp.strip().lower() == tokyo_test_pair[1] @pytest.mark.skipif(SPEND_MONEY is False, reason="Costs money") @async_test @@ -79,7 +80,9 @@ class TestBaseLLM: role = None async for chunk in self.llm.stream_chat( messages=[ - ChatMessage(role="user", content=TEST_PROMPT, summary=TEST_PROMPT) + ChatMessage( + role="user", content=tokyo_test_pair[0], summary=tokyo_test_pair[0] + ) ], temperature=0.0, ): @@ -90,7 +93,7 @@ class TestBaseLLM: role = chunk["role"] assert role == "assistant" - assert completion.strip().lower() == "tokyo" + assert completion.strip().lower() == tokyo_test_pair[1] @pytest.mark.skipif(SPEND_MONEY is False, reason="Costs money") @async_test @@ -99,11 +102,13 @@ class TestBaseLLM: pytest.skip("Skipping abstract LLM") completion = "" - async for chunk in self.llm.stream_complete(TEST_PROMPT, temperature=0.0): + async for chunk in self.llm.stream_complete( + tokyo_test_pair[0], temperature=0.0 + ): assert isinstance(chunk, str) completion += chunk - assert completion.strip().lower() == "tokyo" + assert completion.strip().lower() == tokyo_test_pair[1] class TestOpenAI(TestBaseLLM): @@ -129,7 +134,7 @@ class TestOpenAI(TestBaseLLM): "Output a single word, that being the capital of Japan:" ) assert isinstance(resp, str) - assert resp.strip().lower() == "tokyo" + assert resp.strip().lower() == tokyo_test_pair[1] class TestGGML(TestBaseLLM): diff --git a/continuedev/src/continuedev/tests/step_test.py b/continuedev/src/continuedev/tests/step_test.py new file mode 100644 index 00000000..a4131e61 --- /dev/null +++ b/continuedev/src/continuedev/tests/step_test.py @@ -0,0 +1,43 @@ +import pytest + +from continuedev.core.config import ContinueConfig +from continuedev.headless import start_headless_session +from continuedev.plugins.steps.core.core import UserInputStep +from continuedev.tests.util.prompts import tokyo_test_pair + +TEST_CONFIG = ContinueConfig() + + +@pytest.mark.asyncio +async def test_step(): + session = await start_headless_session(config=TEST_CONFIG) + + await session.autopilot.run_from_step(UserInputStep(user_input=tokyo_test_pair[0])) + + full_state = await session.autopilot.get_full_state() + assert ( + full_state.history.timeline[-1].step.description.strip().lower() + == tokyo_test_pair[1] + ) + + await session.autopilot.cleanup() + + +# TODO: Test other properties of full_state after the UserInputStep. Also test with other config properties and models, etc... +# so we are sure that UserInputStep works in many cases. One example of a thing to check is that the step following UserInputStep +# is a SimpleChatStep, and that it is not hidden, and that the properties of the node (full_state.history.timeline[-1]) all look good +# (basically, run it, see what you get, then codify this in assertions) + +# TODO: Write tests for other steps: +# - DefaultOnTracebackStep +# - EditHighlightedCodeStep (note that this won't test the rendering in IDE) + +# NOTE: Avoid expensive prompts - not too big a deal, but don't have it generate an entire 100 line file +# If you want to not have llm_test.py spend money, change SPEND_MONEY to False at the +# top of that file, but make sure to put it back to True before committing + +# NOTE: Headless mode uses continuedev.src.continuedev.headless.headless_ide instead of +# VS Code, so many of the methods just pass, or might not act exactly how you expect. +# See the file for reference + +# NOTE: If this is too short or pointless a task, let me know and I'll set up testing of ContextProviders diff --git a/continuedev/src/continuedev/tests/util/config.py b/continuedev/src/continuedev/tests/util/config.py new file mode 100644 index 00000000..73d3aeff --- /dev/null +++ b/continuedev/src/continuedev/tests/util/config.py @@ -0,0 +1,19 @@ +from continuedev.src.continuedev.core.config import ContinueConfig +from continuedev.src.continuedev.core.models import Models +from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI + +config = ContinueConfig( + allow_anonymous_telemetry=False, + models=Models( + default=MaybeProxyOpenAI(api_key="", model="gpt-4"), + medium=MaybeProxyOpenAI( + api_key="", + model="gpt-3.5-turbo", + ), + ), + system_message=None, + temperature=0.5, + custom_commands=[], + slash_commands=[], + context_providers=[], +) diff --git a/continuedev/src/continuedev/tests/util/prompts.py b/continuedev/src/continuedev/tests/util/prompts.py new file mode 100644 index 00000000..f60c926c --- /dev/null +++ b/continuedev/src/continuedev/tests/util/prompts.py @@ -0,0 +1 @@ +tokyo_test_pair = ("Output a single word, that being the capital of Japan:", "tokyo")
\ No newline at end of file @@ -12,8 +12,8 @@ a = Analysis( datas=[ ('continuedev', 'continuedev'), (certifi.where(), 'ca_bundle') - ] + copy_metadata('replicate'), - hiddenimports=['anthropic', 'github', 'ripgrepy', 'bs4', 'redbaron', 'replicate'], + ], + hiddenimports=['anthropic', 'github', 'ripgrepy', 'bs4', 'redbaron', 'python-lsp-server', 'replicate'] + copy_metadata('replicate'), hookspath=[], hooksconfig={}, runtime_hooks=[], diff --git a/test.py b/test.py new file mode 100644 index 00000000..5bd57e0e --- /dev/null +++ b/test.py @@ -0,0 +1,21 @@ +import unittest + + +def sort_numbers(numbers): + for i in range(len(numbers)): + for j in range(i + 1, len(numbers)): + if numbers[i] > numbers[j]: + numbers[i], numbers[j] = numbers[j], numbers[i] + return numbers[:-1] # Error here: We're not returning the last number + + +class TestSortNumbers(unittest.TestCase): + def test_sort_numbers(self): + self.assertEqual(sort_numbers([3, 2, 1]), [1, 2, 3]) # This test will fail + self.assertEqual( + sort_numbers([4, 2, 5, 1, 3]), [1, 2, 3, 4, 5] + ) # This test will fail + + +if __name__ == "__main__": + unittest.main() |