diff options
32 files changed, 1750 insertions, 136 deletions
| diff --git a/.vscode/launch.json b/.vscode/launch.json index 08061d13..9ccf4ce7 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -37,6 +37,17 @@        // What about a watch task? - type errors?      },      { +      "name": "Headless", +      "type": "python", +      "request": "launch", +      "module": "continuedev.headless", +      "args": ["--config", "continuedev/config.py"], +      "justMyCode": false, +      "subProcess": false +      // Does it need a build task? +      // What about a watch task? - type errors? +    }, +    {        "name": "Extension (VSCode)",        "type": "extensionHost",        "request": "launch", diff --git a/continuedev/requirements.txt b/continuedev/requirements.txt index 0cf909d5..91120c59 100644 --- a/continuedev/requirements.txt +++ b/continuedev/requirements.txt @@ -22,4 +22,5 @@ socksio==1.0.0  ripgrepy==2.0.0  replicate==0.11.0  bs4==0.0.1 -redbaron==0.9.2
\ No newline at end of file +redbaron==0.9.2 +python-lsp-server==1.2.0
\ No newline at end of file diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index bae82739..8ac7241d 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -17,7 +17,10 @@ from ..libs.util.paths import getSavedContextGroupsPath  from ..libs.util.queue import AsyncSubscriptionQueue  from ..libs.util.strings import remove_quotes_and_escapes  from ..libs.util.telemetry import posthog_logger -from ..libs.util.traceback_parsers import get_javascript_traceback, get_python_traceback +from ..libs.util.traceback.traceback_parsers import ( +    get_javascript_traceback, +    get_python_traceback, +)  from ..models.filesystem import RangeInFileWithContents  from ..models.filesystem_edit import FileEditWithFullContents  from ..models.main import ContinueBaseModel @@ -32,6 +35,8 @@ from ..plugins.steps.core.core import (  )  from ..plugins.steps.on_traceback import DefaultOnTracebackStep  from ..server.ide_protocol import AbstractIdeProtocolServer +from ..server.meilisearch_server import stop_meilisearch +from .config import ContinueConfig  from .context import ContextManager  from .main import (      Context, @@ -97,8 +102,12 @@ class Autopilot(ContinueBaseModel):      started: bool = False -    async def start(self, full_state: Optional[FullState] = None): -        self.continue_sdk = await ContinueSDK.create(self) +    async def start( +        self, +        full_state: Optional[FullState] = None, +        config: Optional[ContinueConfig] = None, +    ): +        self.continue_sdk = await ContinueSDK.create(self, config=config)          if override_policy := self.continue_sdk.config.policy_override:              self.policy = override_policy @@ -134,6 +143,11 @@ class Autopilot(ContinueBaseModel):          self.started = True +    async def cleanup(self): +        if self.continue_sdk.lsp is not None: +            await self.continue_sdk.lsp.stop() +        stop_meilisearch() +      class Config:          arbitrary_types_allowed = True          keep_untouched = (cached_property,) diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 62e9c690..68b2b17d 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -59,3 +59,14 @@ class ContinueConfig(BaseModel):      @validator("temperature", pre=True)      def temperature_validator(cls, v):          return max(0.0, min(1.0, v)) + +    @staticmethod +    def from_filepath(filepath: str) -> "ContinueConfig": +        # Use importlib to load the config file config.py at the given path +        import importlib.util + +        spec = importlib.util.spec_from_file_location("config", filepath) +        config = importlib.util.module_from_spec(spec) +        spec.loader.exec_module(config) + +        return config.config diff --git a/continuedev/src/continuedev/core/lsp.py b/continuedev/src/continuedev/core/lsp.py new file mode 100644 index 00000000..5c1f9989 --- /dev/null +++ b/continuedev/src/continuedev/core/lsp.py @@ -0,0 +1,310 @@ +import os +import socket +import subprocess +import threading +from typing import List, Optional + +from pydantic import BaseModel + +from ..libs.lspclient.json_rpc_endpoint import JsonRpcEndpoint +from ..libs.lspclient.lsp_client import LspClient +from ..libs.lspclient.lsp_endpoint import LspEndpoint +from ..libs.lspclient.lsp_structs import Position as LspPosition +from ..libs.lspclient.lsp_structs import SymbolInformation, TextDocumentIdentifier +from ..models.filesystem import RangeInFile +from ..models.main import Position, Range + + +class ReadPipe(threading.Thread): +    def __init__(self, pipe): +        threading.Thread.__init__(self) +        self.pipe = pipe + +    def run(self): +        line = self.pipe.readline().decode("utf-8") +        while line: +            print(line) +            line = self.pipe.readline().decode("utf-8") + + +class SocketFileWrapper: +    def __init__(self, sockfile): +        self.sockfile = sockfile + +    def write(self, data): +        if isinstance(data, bytes): +            data = data.decode("utf-8").replace("\r\n", "\n") +        return self.sockfile.write(data) + +    def read(self, size=-1): +        data = self.sockfile.read(size) +        if isinstance(data, str): +            data = data.replace("\n", "\r\n").encode("utf-8") +        return data + +    def readline(self, size=-1): +        data = self.sockfile.readline(size) +        if isinstance(data, str): +            data = data.replace("\n", "\r\n").encode("utf-8") +        return data + +    def flush(self): +        return self.sockfile.flush() + +    def close(self): +        return self.sockfile.close() + + +def create_json_rpc_endpoint(use_subprocess: Optional[str] = None): +    if use_subprocess is None: +        # Connect to the server +        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +        s.connect(("localhost", 8080)) + +        # Create a file-like object from the socket +        sockfile = s.makefile("rw") +        wrapped_sockfile = SocketFileWrapper(sockfile) +        return JsonRpcEndpoint(wrapped_sockfile, wrapped_sockfile), None + +    else: +        pyls_cmd = use_subprocess.split() +        p = subprocess.Popen( +            pyls_cmd, +            stdin=subprocess.PIPE, +            stdout=subprocess.PIPE, +            stderr=subprocess.PIPE, +        ) +        read_pipe = ReadPipe(p.stderr) +        read_pipe.start() +        return JsonRpcEndpoint(p.stdin, p.stdout), p + + +def filename_to_uri(filename: str) -> str: +    return f"file://{filename}" + + +def uri_to_filename(uri: str) -> str: +    if uri.startswith("file://"): +        return uri.lstrip("file://") +    else: +        return uri + + +def create_lsp_client(workspace_dir: str, use_subprocess: Optional[str] = None): +    json_rpc_endpoint, process = create_json_rpc_endpoint(use_subprocess=use_subprocess) +    lsp_endpoint = LspEndpoint(json_rpc_endpoint) +    lsp_client = LspClient(lsp_endpoint) +    capabilities = { +        "textDocument": { +            "codeAction": {"dynamicRegistration": True}, +            "codeLens": {"dynamicRegistration": True}, +            "colorProvider": {"dynamicRegistration": True}, +            "completion": { +                "completionItem": { +                    "commitCharactersSupport": True, +                    "documentationFormat": ["markdown", "plaintext"], +                    "snippetSupport": True, +                }, +                "completionItemKind": { +                    "valueSet": [ +                        1, +                        2, +                        3, +                        4, +                        5, +                        6, +                        7, +                        8, +                        9, +                        10, +                        11, +                        12, +                        13, +                        14, +                        15, +                        16, +                        17, +                        18, +                        19, +                        20, +                        21, +                        22, +                        23, +                        24, +                        25, +                    ] +                }, +                "contextSupport": True, +                "dynamicRegistration": True, +            }, +            "definition": {"dynamicRegistration": True}, +            "documentHighlight": {"dynamicRegistration": True}, +            "documentLink": {"dynamicRegistration": True}, +            "documentSymbol": { +                "dynamicRegistration": True, +                "symbolKind": { +                    "valueSet": [ +                        1, +                        2, +                        3, +                        4, +                        5, +                        6, +                        7, +                        8, +                        9, +                        10, +                        11, +                        12, +                        13, +                        14, +                        15, +                        16, +                        17, +                        18, +                        19, +                        20, +                        21, +                        22, +                        23, +                        24, +                        25, +                        26, +                    ] +                }, +            }, +            "formatting": {"dynamicRegistration": True}, +            "hover": { +                "contentFormat": ["markdown", "plaintext"], +                "dynamicRegistration": True, +            }, +            "implementation": {"dynamicRegistration": True}, +            "onTypeFormatting": {"dynamicRegistration": True}, +            "publishDiagnostics": {"relatedInformation": True}, +            "rangeFormatting": {"dynamicRegistration": True}, +            "references": {"dynamicRegistration": True}, +            "rename": {"dynamicRegistration": True}, +            "signatureHelp": { +                "dynamicRegistration": True, +                "signatureInformation": { +                    "documentationFormat": ["markdown", "plaintext"] +                }, +            }, +            "synchronization": { +                "didSave": True, +                "dynamicRegistration": True, +                "willSave": True, +                "willSaveWaitUntil": True, +            }, +            "typeDefinition": {"dynamicRegistration": True}, +        }, +        "workspace": { +            "applyEdit": True, +            "configuration": True, +            "didChangeConfiguration": {"dynamicRegistration": True}, +            "didChangeWatchedFiles": {"dynamicRegistration": True}, +            "executeCommand": {"dynamicRegistration": True}, +            "symbol": { +                "dynamicRegistration": True, +                "symbolKind": { +                    "valueSet": [ +                        1, +                        2, +                        3, +                        4, +                        5, +                        6, +                        7, +                        8, +                        9, +                        10, +                        11, +                        12, +                        13, +                        14, +                        15, +                        16, +                        17, +                        18, +                        19, +                        20, +                        21, +                        22, +                        23, +                        24, +                        25, +                        26, +                    ] +                }, +            }, +            "workspaceEdit": {"documentChanges": True}, +            "workspaceFolders": True, +        }, +    } +    root_uri = filename_to_uri(workspace_dir) +    dir_name = os.path.basename(workspace_dir) +    workspace_folders = [{"name": dir_name, "uri": root_uri}] +    lsp_client.initialize( +        None, +        None, +        root_uri, +        None, +        capabilities, +        "off", +        workspace_folders, +    ) +    lsp_client.initialized() +    return lsp_client, process + + +class ContinueLSPClient(BaseModel): +    workspace_dir: str +    lsp_client: LspClient = None +    use_subprocess: Optional[str] = None +    lsp_process: Optional[subprocess.Popen] = None + +    class Config: +        arbitrary_types_allowed = True + +    def dict(self, **kwargs): +        original_dict = super().dict(**kwargs) +        original_dict.pop("lsp_client", None) +        return original_dict + +    async def start(self): +        self.lsp_client, self.lsp_process = create_lsp_client( +            self.workspace_dir, use_subprocess=self.use_subprocess +        ) + +    async def stop(self): +        self.lsp_client.shutdown() +        self.lsp_client.exit() +        if self.lsp_process is not None: +            self.lsp_process.terminate() +            self.lsp_process.wait() +            self.lsp_process = None + +    def goto_definition(self, position: Position, filename: str) -> List[RangeInFile]: +        response = self.lsp_client.definition( +            TextDocumentIdentifier(filename_to_uri(filename)), +            LspPosition(position.line, position.character), +        ) +        return [ +            RangeInFile( +                filepath=uri_to_filename(x.uri), +                range=Range.from_shorthand( +                    x.range.start.line, +                    x.range.start.character, +                    x.range.end.line, +                    x.range.end.character, +                ), +            ) +            for x in response +        ] + +    def get_symbols(self, filepath: str) -> List[SymbolInformation]: +        response = self.lsp_client.documentSymbol( +            TextDocumentIdentifier(filename_to_uri(filepath)) +        ) + +        return response diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 9ff6612c..7dca600d 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -1,8 +1,9 @@  import os  import traceback -from typing import Coroutine, Union +from typing import Coroutine, List, Optional, Union  from ..libs.llm import LLM +from ..libs.util.create_async_task import create_async_task  from ..libs.util.logging import logger  from ..libs.util.paths import getConfigFilePath  from ..libs.util.telemetry import posthog_logger @@ -16,11 +17,18 @@ from ..models.filesystem_edit import (      FileSystemEdit,  )  from ..models.main import Range -from ..plugins.steps.core.core import * -from ..plugins.steps.core.core import DefaultModelEditCodeStep +from ..plugins.steps.core.core import ( +    DefaultModelEditCodeStep, +    FileSystemEditStep, +    MessageStep, +    RangeInFileWithContents, +    ShellCommandsStep, +    WaitForUserConfirmationStep, +)  from ..server.ide_protocol import AbstractIdeProtocolServer  from .abstract_sdk import AbstractContinueSDK  from .config import ContinueConfig +from .lsp import ContinueLSPClient  from .main import (      ChatMessage,      Context, @@ -42,6 +50,7 @@ class ContinueSDK(AbstractContinueSDK):      ide: AbstractIdeProtocolServer      models: Models +    lsp: Optional[ContinueLSPClient] = None      context: Context      config: ContinueConfig      __autopilot: Autopilot @@ -52,13 +61,14 @@ class ContinueSDK(AbstractContinueSDK):          self.context = autopilot.context      @classmethod -    async def create(cls, autopilot: Autopilot) -> "ContinueSDK": +    async def create( +        cls, autopilot: Autopilot, config: Optional[ContinueConfig] = None +    ) -> "ContinueSDK":          sdk = ContinueSDK(autopilot)          autopilot.continue_sdk = sdk          try: -            config = sdk._load_config_dot_py() -            sdk.config = config +            sdk.config = config or sdk._load_config_dot_py()          except Exception as e:              logger.error(f"Failed to load config.py: {traceback.format_exception(e)}") @@ -78,9 +88,26 @@ class ContinueSDK(AbstractContinueSDK):              )              await sdk.ide.setFileOpen(getConfigFilePath()) +        # Start models          sdk.models = sdk.config.models          await sdk.models.start(sdk) +        # Start LSP +        async def start_lsp(): +            try: +                sdk.lsp = ContinueLSPClient( +                    workspace_dir=sdk.ide.workspace_directory, +                    use_subprocess="python3.10 -m pylsp", +                ) +                await sdk.lsp.start() +            except: +                logger.warning("Failed to start LSP client", exc_info=True) +                sdk.lsp = None + +        create_async_task( +            start_lsp(), on_error=lambda e: logger.error("Failed to setup LSP: %s", e) +        ) +          # When the config is loaded, setup posthog logger          posthog_logger.setup(sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry) @@ -207,19 +234,13 @@ class ContinueSDK(AbstractContinueSDK):      _last_valid_config: ContinueConfig = None      def _load_config_dot_py(self) -> ContinueConfig: -        # Use importlib to load the config file config.py at the given path          path = getConfigFilePath() - -        import importlib.util - -        spec = importlib.util.spec_from_file_location("config", path) -        config = importlib.util.module_from_spec(spec) -        spec.loader.exec_module(config) -        self._last_valid_config = config.config +        config = ContinueConfig.from_filepath(path) +        self._last_valid_config = config          logger.debug("Loaded Continue config file from %s", path) -        return config.config +        return config      def get_code_context(          self, only_editing: bool = False diff --git a/continuedev/src/continuedev/headless/__init__.py b/continuedev/src/continuedev/headless/__init__.py new file mode 100644 index 00000000..4e46409a --- /dev/null +++ b/continuedev/src/continuedev/headless/__init__.py @@ -0,0 +1,41 @@ +import asyncio +from typing import Optional, Union + +import typer + +from ..core.config import ContinueConfig +from ..server.session_manager import Session, session_manager +from .headless_ide import LocalIdeProtocol + +app = typer.Typer() + + +async def start_headless_session( +    config: Optional[Union[str, ContinueConfig]] = None +) -> Session: +    if config is not None: +        if isinstance(config, str): +            config: ContinueConfig = ContinueConfig.from_filepath(config) + +    ide = LocalIdeProtocol() +    return await session_manager.new_session(ide, config=config) + + +async def async_main(config: Optional[str] = None): +    await start_headless_session(config=config) + + +@app.command() +def main( +    config: Optional[str] = typer.Option( +        None, help="The path to the configuration file" +    ) +): +    loop = asyncio.get_event_loop() +    loop.run_until_complete(async_main(config)) +    tasks = asyncio.all_tasks(loop) +    loop.run_until_complete(asyncio.gather(*tasks)) + + +if __name__ == "__main__": +    app() diff --git a/continuedev/src/continuedev/headless/headless_ide.py b/continuedev/src/continuedev/headless/headless_ide.py new file mode 100644 index 00000000..088da2c9 --- /dev/null +++ b/continuedev/src/continuedev/headless/headless_ide.py @@ -0,0 +1,181 @@ +import os +import subprocess +import uuid +from typing import Any, Callable, Coroutine, List, Optional + +from dotenv import load_dotenv +from fastapi import WebSocket + +from ..models.filesystem import ( +    FileSystem, +    RangeInFile, +    RangeInFileWithContents, +    RealFileSystem, +) +from ..models.filesystem_edit import EditDiff, FileEdit, FileSystemEdit +from ..server.ide_protocol import AbstractIdeProtocolServer + +load_dotenv() + + +def get_mac_address(): +    mac_num = hex(uuid.getnode()).replace("0x", "").upper() +    mac = "-".join(mac_num[i : i + 2] for i in range(0, 11, 2)) +    return mac + + +class LocalIdeProtocol(AbstractIdeProtocolServer): +    websocket: WebSocket = None +    session_id: Optional[str] +    workspace_directory: str = os.getcwd() +    unique_id: str = get_mac_address() + +    filesystem: FileSystem = RealFileSystem() + +    async def handle_json(self, data: Any): +        """Handle a json message""" +        pass + +    def showSuggestion(self, file_edit: FileEdit): +        """Show a suggestion to the user""" +        pass + +    async def setFileOpen(self, filepath: str, open: bool = True): +        """Set whether a file is open""" +        pass + +    async def showMessage(self, message: str): +        """Show a message to the user""" +        print(message) + +    async def showVirtualFile(self, name: str, contents: str): +        """Show a virtual file""" +        pass + +    async def setSuggestionsLocked(self, filepath: str, locked: bool = True): +        """Set whether suggestions are locked""" +        pass + +    async def getSessionId(self): +        """Get a new session ID""" +        pass + +    async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool: +        """Show suggestions to the user and wait for a response""" +        pass + +    def onAcceptRejectSuggestion(self, accepted: bool): +        """Called when the user accepts or rejects a suggestion""" +        pass + +    def onFileSystemUpdate(self, update: FileSystemEdit): +        """Called when a file system update is received""" +        pass + +    def onCloseGUI(self, session_id: str): +        """Called when a GUI is closed""" +        pass + +    def onOpenGUIRequest(self): +        """Called when a GUI is requested to be opened""" +        pass + +    async def getOpenFiles(self) -> List[str]: +        """Get a list of open files""" +        pass + +    async def getVisibleFiles(self) -> List[str]: +        """Get a list of visible files""" +        pass + +    async def getHighlightedCode(self) -> List[RangeInFile]: +        """Get a list of highlighted code""" +        pass + +    async def readFile(self, filepath: str) -> str: +        """Read a file""" +        return self.filesystem.read(filepath) + +    async def readRangeInFile(self, range_in_file: RangeInFile) -> str: +        """Read a range in a file""" +        return self.filesystem.read_range_in_file(range_in_file) + +    async def editFile(self, edit: FileEdit): +        """Edit a file""" +        self.filesystem.apply_file_edit(edit) + +    async def applyFileSystemEdit(self, edit: FileSystemEdit) -> EditDiff: +        """Apply a file edit""" +        return self.filesystem.apply_edit(edit) + +    async def saveFile(self, filepath: str): +        """Save a file""" +        pass + +    async def getUserSecret(self, key: str): +        """Get a user secret""" +        return os.environ.get(key) + +    async def highlightCode(self, range_in_file: RangeInFile, color: str): +        """Highlight code""" +        pass + +    async def runCommand(self, command: str) -> str: +        """Run a command using subprocess (don't pass, actually implement)""" +        return subprocess.check_output(command, shell=True).decode("utf-8") + +    def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]): +        """Called when highlighted code is updated""" +        pass + +    def onDeleteAtIndex(self, index: int): +        """Called when a step is deleted at a given index""" +        pass + +    async def showDiff(self, filepath: str, replacement: str, step_index: int): +        """Show a diff""" +        pass + +    def subscribeToFilesCreated(self, callback: Callable[[List[str]], None]): +        """Subscribe to files created event""" +        pass + +    def subscribeToFilesDeleted(self, callback: Callable[[List[str]], None]): +        """Subscribe to files deleted event""" +        pass + +    def subscribeToFilesRenamed(self, callback: Callable[[List[str], List[str]], None]): +        """Subscribe to files renamed event""" +        pass + +    def subscribeToFileSaved(self, callback: Callable[[str, str], None]): +        """Subscribe to file saved event""" +        pass + +    def onFilesCreated(self, filepaths: List[str]): +        """Called when files are created""" +        pass + +    def onFilesDeleted(self, filepaths: List[str]): +        """Called when files are deleted""" +        pass + +    def onFilesRenamed(self, old_filepaths: List[str], new_filepaths: List[str]): +        """Called when files are renamed""" +        pass + +    def onFileSaved(self, filepath: str, contents: str): +        """Called when a file is saved""" +        pass + +    async def fileExists(self, filepath: str) -> Coroutine[Any, Any, str]: +        """Check if a file exists""" +        return self.filesystem.exists(filepath) + +    async def getTerminalContents(self) -> Coroutine[Any, Any, str]: +        return "" + +    async def listDirectoryContents( +        self, directory: str, recursive: bool = False +    ) -> List[str]: +        return self.filesystem.list_directory_contents(directory, recursive=recursive) diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py index 03c9cce4..1b91ec43 100644 --- a/continuedev/src/continuedev/libs/llm/together.py +++ b/continuedev/src/continuedev/libs/llm/together.py @@ -4,6 +4,7 @@ from typing import Callable, Optional  import aiohttp  from ..llm import LLM +from ..util.logging import logger  from .prompts.chat import llama2_template_messages  from .prompts.edit import simplified_edit_prompt @@ -59,7 +60,13 @@ class TogetherLLM(LLM):                          if chunk.strip() != "":                              if chunk.startswith("data: "):                                  chunk = chunk[6:] -                            json_chunk = json.loads(chunk) +                            if chunk == "[DONE]": +                                break +                            try: +                                json_chunk = json.loads(chunk) +                            except Exception as e: +                                logger.warning(f"Invalid JSON chunk: {chunk}\n\n{e}") +                                continue                              if "choices" in json_chunk:                                  yield json_chunk["choices"][0]["text"] diff --git a/continuedev/src/continuedev/libs/lspclient/json_rpc_endpoint.py b/continuedev/src/continuedev/libs/lspclient/json_rpc_endpoint.py new file mode 100644 index 00000000..80c51000 --- /dev/null +++ b/continuedev/src/continuedev/libs/lspclient/json_rpc_endpoint.py @@ -0,0 +1,82 @@ +from __future__ import print_function + +import json +import re +import threading + +JSON_RPC_REQ_FORMAT = "Content-Length: {json_string_len}\r\n\r\n{json_string}" +JSON_RPC_RES_REGEX = "Content-Length: ([0-9]*)\r\n" +# TODO: add content-type + + +class MyEncoder(json.JSONEncoder): +    """ +    Encodes an object in JSON +    """ + +    def default(self, o): +        return o.__dict__ + + +class JsonRpcEndpoint(object): +    """ +    Thread safe JSON RPC endpoint implementation. Responsible to recieve and send JSON RPC messages, as described in the +    protocol. More information can be found: https://www.jsonrpc.org/ +    """ + +    def __init__(self, stdin, stdout): +        self.stdin = stdin +        self.stdout = stdout +        self.read_lock = threading.Lock() +        self.write_lock = threading.Lock() + +    @staticmethod +    def __add_header(json_string): +        """ +        Adds a header for the given json string + +        :param str json_string: The string +        :return: the string with the header +        """ +        return JSON_RPC_REQ_FORMAT.format( +            json_string_len=len(json_string), json_string=json_string +        ) + +    def send_request(self, message): +        """ +        Sends the given message. + +        :param dict message: The message to send. +        """ +        json_string = json.dumps(message, cls=MyEncoder) +        # print("sending:", json_string) +        jsonrpc_req = self.__add_header(json_string) +        with self.write_lock: +            self.stdin.write(jsonrpc_req.encode()) +            self.stdin.flush() + +    def recv_response(self): +        """ +        Recives a message. + +         :return: a message +        """ +        with self.read_lock: +            line = self.stdout.readline() +            if not line: +                return None +            # print(line) +            line = line.decode() +            # TODO: handle content type as well. +            match = re.match(JSON_RPC_RES_REGEX, line) +            if match is None or not match.groups(): +                raise RuntimeError("Bad header: " + line) +            size = int(match.groups()[0]) +            line = self.stdout.readline() +            if not line: +                return None +            line = line.decode() +            # if line != "\r\n": +            #     raise RuntimeError("Bad header: missing newline") +            jsonrpc_res = self.stdout.read(size + 2) +            return json.loads(jsonrpc_res) diff --git a/continuedev/src/continuedev/libs/lspclient/lsp_client.py b/continuedev/src/continuedev/libs/lspclient/lsp_client.py new file mode 100644 index 00000000..fe2db6ad --- /dev/null +++ b/continuedev/src/continuedev/libs/lspclient/lsp_client.py @@ -0,0 +1,150 @@ +from .lsp_structs import Location, SignatureHelp, SymbolInformation + + +class LspClient(object): +    def __init__(self, lsp_endpoint): +        """ +        Constructs a new LspClient instance. + +        :param lsp_endpoint: TODO +        """ +        self.lsp_endpoint = lsp_endpoint + +    def initialize( +        self, +        processId, +        rootPath, +        rootUri, +        initializationOptions, +        capabilities, +        trace, +        workspaceFolders, +    ): +        """ +        The initialize request is sent as the first request from the client to the server. If the server receives a request or notification +        before the initialize request it should act as follows: + +        1. For a request the response should be an error with code: -32002. The message can be picked by the server. +        2. Notifications should be dropped, except for the exit notification. This will allow the exit of a server without an initialize request. + +        Until the server has responded to the initialize request with an InitializeResult, the client must not send any additional requests or +        notifications to the server. In addition the server is not allowed to send any requests or notifications to the client until it has responded +        with an InitializeResult, with the exception that during the initialize request the server is allowed to send the notifications window/showMessage, +        window/logMessage and telemetry/event as well as the window/showMessageRequest request to the client. + +        The initialize request may only be sent once. + +        :param int processId: The process Id of the parent process that started the server. Is null if the process has not been started by another process. +                                If the parent process is not alive then the server should exit (see exit notification) its process. +        :param str rootPath: The rootPath of the workspace. Is null if no folder is open. Deprecated in favour of rootUri. +        :param DocumentUri rootUri: The rootUri of the workspace. Is null if no folder is open. If both `rootPath` and `rootUri` are set +                                    `rootUri` wins. +        :param any initializationOptions: User provided initialization options. +        :param ClientCapabilities capabilities: The capabilities provided by the client (editor or tool). +        :param Trace trace: The initial trace setting. If omitted trace is disabled ('off'). +        :param list workspaceFolders: The workspace folders configured in the client when the server starts. This property is only available if the client supports workspace folders. +                                        It can be `null` if the client supports workspace folders but none are configured. +        """ +        self.lsp_endpoint.start() +        return self.lsp_endpoint.call_method( +            "initialize", +            processId=processId, +            rootPath=rootPath, +            rootUri=rootUri, +            initializationOptions=initializationOptions, +            capabilities=capabilities, +            trace=trace, +            workspaceFolders=workspaceFolders, +        ) + +    def initialized(self): +        """ +        The initialized notification is sent from the client to the server after the client received the result of the initialize request +        but before the client is sending any other request or notification to the server. The server can use the initialized notification +        for example to dynamically register capabilities. The initialized notification may only be sent once. +        """ +        self.lsp_endpoint.send_notification("initialized") + +    def shutdown(self): +        """ +        The initialized notification is sent from the client to the server after the client received the result of the initialize request +        but before the client is sending any other request or notification to the server. The server can use the initialized notification +        for example to dynamically register capabilities. The initialized notification may only be sent once. +        """ +        self.lsp_endpoint.stop() +        return self.lsp_endpoint.call_method("shutdown") + +    def exit(self): +        """ +        The initialized notification is sent from the client to the server after the client received the result of the initialize request +        but before the client is sending any other request or notification to the server. The server can use the initialized notification +        for example to dynamically register capabilities. The initialized notification may only be sent once. +        """ +        self.lsp_endpoint.send_notification("exit") + +    def didOpen(self, textDocument): +        """ +        The document open notification is sent from the client to the server to signal newly opened text documents. The document's truth is +        now managed by the client and the server must not try to read the document's truth using the document's uri. Open in this sense +        means it is managed by the client. It doesn't necessarily mean that its content is presented in an editor. An open notification must +        not be sent more than once without a corresponding close notification send before. This means open and close notification must be +        balanced and the max open count for a particular textDocument is one. Note that a server's ability to fulfill requests is independent +        of whether a text document is open or closed. + +        The DidOpenTextDocumentParams contain the language id the document is associated with. If the language Id of a document changes, the +        client needs to send a textDocument/didClose to the server followed by a textDocument/didOpen with the new language id if the server +        handles the new language id as well. + +        :param TextDocumentItem textDocument: The initial trace setting. If omitted trace is disabled ('off'). +        """ +        return self.lsp_endpoint.send_notification( +            "textDocument/didOpen", textDocument=textDocument +        ) + +    def documentSymbol(self, textDocument): +        """ +        The document symbol request is sent from the client to the server to return a flat list of all symbols found in a given text document. +        Neither the symbol's location range nor the symbol's container name should be used to infer a hierarchy. + +        :param TextDocumentItem textDocument: The text document. +        """ +        result_dict = self.lsp_endpoint.call_method( +            "textDocument/documentSymbol", textDocument=textDocument +        ) +        return [SymbolInformation(**sym) for sym in result_dict] + +    def definition(self, textDocument, position): +        """ +        The goto definition request is sent from the client to the server to resolve the definition location of a symbol at a given text document position. + +        :param TextDocumentItem textDocument: The text document. +        :param Position position: The position inside the text document.. +        """ +        result_dict = self.lsp_endpoint.call_method( +            "textDocument/definition", textDocument=textDocument, position=position +        ) +        return [Location(**l) for l in result_dict] + +    def typeDefinition(self, textDocument, position): +        """ +        The goto type definition request is sent from the client to the server to resolve the type definition location of a symbol at a given text document position. + +        :param TextDocumentItem textDocument: The text document. +        :param Position position: The position inside the text document.. +        """ +        result_dict = self.lsp_endpoint.call_method( +            "textDocument/definition", textDocument=textDocument, position=position +        ) +        return [Location(**l) for l in result_dict] + +    def signatureHelp(self, textDocument, position): +        """ +        The signature help request is sent from the client to the server to request signature information at a given cursor position. + +        :param TextDocumentItem textDocument: The text document. +        :param Position position: The position inside the text document.. +        """ +        result_dict = self.lsp_endpoint.call_method( +            "textDocument/signatureHelp", textDocument=textDocument, position=position +        ) +        return SignatureHelp(**result_dict) diff --git a/continuedev/src/continuedev/libs/lspclient/lsp_endpoint.py b/continuedev/src/continuedev/libs/lspclient/lsp_endpoint.py new file mode 100644 index 00000000..14d2ca07 --- /dev/null +++ b/continuedev/src/continuedev/libs/lspclient/lsp_endpoint.py @@ -0,0 +1,71 @@ +from __future__ import print_function + +import threading + + +class LspEndpoint(threading.Thread): +    def __init__(self, json_rpc_endpoint, default_callback=print, callbacks={}): +        threading.Thread.__init__(self) +        self.json_rpc_endpoint = json_rpc_endpoint +        self.callbacks = callbacks +        self.default_callback = default_callback +        self.event_dict = {} +        self.response_dict = {} +        self.next_id = 0 +        # self.daemon = True +        self.shutdown_flag = False + +    def handle_result(self, jsonrpc_res): +        self.response_dict[jsonrpc_res["id"]] = jsonrpc_res +        cond = self.event_dict[jsonrpc_res["id"]] +        cond.acquire() +        cond.notify() +        cond.release() + +    def stop(self): +        self.shutdown_flag = True + +    def run(self): +        while not self.shutdown_flag: +            jsonrpc_message = self.json_rpc_endpoint.recv_response() + +            if jsonrpc_message is None: +                print("server quit") +                break + +            # print("recieved message:", jsonrpc_message) +            if "result" in jsonrpc_message or "error" in jsonrpc_message: +                self.handle_result(jsonrpc_message) +            elif "method" in jsonrpc_message: +                if jsonrpc_message["method"] in self.callbacks: +                    self.callbacks[jsonrpc_message["method"]](jsonrpc_message) +                else: +                    self.default_callback(jsonrpc_message) +            else: +                print("unknown jsonrpc message") +            # print(jsonrpc_message) + +    def send_message(self, method_name, params, id=None): +        message_dict = {} +        message_dict["jsonrpc"] = "2.0" +        if id is not None: +            message_dict["id"] = id +        message_dict["method"] = method_name +        message_dict["params"] = params +        self.json_rpc_endpoint.send_request(message_dict) + +    def call_method(self, method_name, **kwargs): +        current_id = self.next_id +        self.next_id += 1 +        cond = threading.Condition() +        self.event_dict[current_id] = cond +        cond.acquire() +        self.send_message(method_name, kwargs, current_id) +        cond.wait() +        cond.release() +        # TODO: check if error, and throw an exception +        response = self.response_dict[current_id] +        return response["result"] + +    def send_notification(self, method_name, **kwargs): +        self.send_message(method_name, kwargs) diff --git a/continuedev/src/continuedev/libs/lspclient/lsp_structs.py b/continuedev/src/continuedev/libs/lspclient/lsp_structs.py new file mode 100644 index 00000000..2f0940d4 --- /dev/null +++ b/continuedev/src/continuedev/libs/lspclient/lsp_structs.py @@ -0,0 +1,316 @@ +def to_type(o, new_type): +    ''' +    Helper funciton that receives an object or a dict and convert it to a new given type. + +    :param object|dict o: The object to convert +    :param Type new_type: The type to convert to. +    ''' +    if new_type == type(o): +        return o +    else: +        return new_type(**o) + + +class Position(object): +    def __init__(self, line, character): +        """ +        Constructs a new Position instance. + +        :param int line: Line position in a document (zero-based). +        :param int character: Character offset on a line in a document (zero-based). +        """ +        self.line = line +        self.character = character + + +class Range(object): +    def __init__(self, start, end): +        """ +        Constructs a new Range instance. + +        :param Position start: The range's start position. +        :param Position end: The range's end position. +        """ +        self.start = to_type(start, Position) +        self.end = to_type(end, Position) + + +class Location(object): +    """ +    Represents a location inside a resource, such as a line inside a text file. +    """ +    def __init__(self, uri, range): +        """ +        Constructs a new Range instance. + +        :param str uri: Resource file. +        :param Range range: The range inside the file +        """ +        self.uri = uri +        self.range = to_type(range, Range) + +  +class Diagnostic(object): +     def __init__(self, range, severity, code, source, message, relatedInformation): +        """ +        Constructs a new Diagnostic instance. +        :param Range range: The range at which the message applies.Resource file. +        :param int severity: The diagnostic's severity. Can be omitted. If omitted it is up to the +                                client to interpret diagnostics as error, warning, info or hint. +        :param str code: The diagnostic's code, which might appear in the user interface. +        :param str source: A human-readable string describing the source of this +                            diagnostic, e.g. 'typescript' or 'super lint'. +        :param str message: The diagnostic's message. +        :param list relatedInformation: An array of related diagnostic information, e.g. when symbol-names within    +                                        a scope collide all definitions can be marked via this property. +        """ +        self.range = range +        self.severity = severity +        self.code = code +        self.source = source +        self.message = message +        self.relatedInformation = relatedInformation + + +class DiagnosticSeverity(object): +    Error = 1 +    Warning = 2 # TODO: warning is known in python +    Information = 3 +    Hint = 4 + + +class DiagnosticRelatedInformation(object): +    def __init__(self, location, message): +        """ +        Constructs a new Diagnostic instance. +        :param Location location: The location of this related diagnostic information. +        :param str message: The message of this related diagnostic information. +        """ +        self.location = location +        self.message = message + +  +class Command(object): +     def __init__(self, title, command, arguments): +        """ +        Constructs a new Diagnostic instance. +        :param str title: Title of the command, like `save`. +        :param str command: The identifier of the actual command handler. +        :param list argusments: Arguments that the command handler should be invoked with. +        """ +        self.title = title +        self.command = command +        self.arguments = arguments + + +class TextDocumentItem(object): +    """ +    An item to transfer a text document from the client to the server. +    """ +    def __init__(self, uri, languageId, version, text): +        """ +        Constructs a new Diagnostic instance. +         +        :param DocumentUri uri: Title of the command, like `save`. +        :param str languageId: The identifier of the actual command handler. +        :param int version: Arguments that the command handler should be invoked with. +        :param str text: Arguments that the command handler should be invoked with. +        """ +        self.uri = uri +        self.languageId = languageId +        self.version = version +        self.text = text + + +class TextDocumentIdentifier(object): +    """ +    Text documents are identified using a URI. On the protocol level, URIs are passed as strings.  +    """ +    def __init__(self, uri): +        """ +        Constructs a new TextDocumentIdentifier instance. + +        :param DocumentUri uri: The text document's URI.        +        """ +        self.uri = uri + +class TextDocumentPositionParams(object): +    """ +    A parameter literal used in requests to pass a text document and a position inside that document. +    """ +    def __init__(self, textDocument, position): +        """ +        Constructs a new TextDocumentPositionParams instance. +         +        :param TextDocumentIdentifier textDocument: The text document. +        :param Position position: The position inside the text document. +        """ +        self.textDocument = textDocument +        self.position = position + + +class LANGUAGE_IDENTIFIER: +    BAT="bat" +    BIBTEX="bibtex" +    CLOJURE="clojure" +    COFFESCRIPT="coffeescript" +    C="c" +    CPP="cpp" +    CSHARP="csharp" +    CSS="css" +    DIFF="diff" +    DOCKERFILE="dockerfile" +    FSHARP="fsharp" +    GIT_COMMIT="git-commit" +    GIT_REBASE="git-rebase" +    GO="go" +    GROOVY="groovy" +    HANDLEBARS="handlebars" +    HTML="html" +    INI="ini" +    JAVA="java" +    JAVASCRIPT="javascript" +    JSON="json" +    LATEX="latex" +    LESS="less" +    LUA="lua" +    MAKEFILE="makefile" +    MARKDOWN="markdown" +    OBJECTIVE_C="objective-c" +    OBJECTIVE_CPP="objective-cpp" +    Perl="perl" +    PHP="php" +    POWERSHELL="powershell" +    PUG="jade" +    PYTHON="python" +    R="r" +    RAZOR="razor" +    RUBY="ruby" +    RUST="rust" +    SASS="sass" +    SCSS="scss" +    ShaderLab="shaderlab" +    SHELL_SCRIPT="shellscript" +    SQL="sql" +    SWIFT="swift" +    TYPE_SCRIPT="typescript" +    TEX="tex" +    VB="vb" +    XML="xml" +    XSL="xsl" +    YAML="yaml" + + +class SymbolKind(object): +    File = 1 +    Module = 2 +    Namespace = 3 +    Package = 4 +    Class = 5 +    Method = 6 +    Property = 7 +    Field = 8 +    Constructor = 9 +    Enum = 10 +    Interface = 11 +    Function = 12 +    Variable = 13 +    Constant = 14 +    String = 15 +    Number = 16 +    Boolean = 17 +    Array = 18 +    Object = 19 +    Key = 20 +    Null = 21 +    EnumMember = 22 +    Struct = 23 +    Event = 24 +    Operator = 25 +    TypeParameter = 26 + + +class SymbolInformation(object): +    """ +    Represents information about programming constructs like variables, classes, interfaces etc. +    """ +    def __init__(self, name, kind, location, containerName, deprecated=False): +        """ +        Constructs a new SymbolInformation instance. + +        :param str name: The name of this symbol. +        :param int kind: The kind of this symbol. +        :param bool Location: The location of this symbol. The location's range is used by a tool +                                to reveal the location in the editor. If the symbol is selected in the +                                tool the range's start information is used to position the cursor. So +                                the range usually spans more then the actual symbol's name and does +                                normally include things like visibility modifiers. + +                                The range doesn't have to denote a node range in the sense of a abstract +                                syntax tree. It can therefore not be used to re-construct a hierarchy of +                                the symbols. +        :param str containerName: The name of the symbol containing this symbol. This information is for +                                    user interface purposes (e.g. to render a qualifier in the user interface +                                    if necessary). It can't be used to re-infer a hierarchy for the document +                                    symbols. +        :param bool deprecated: Indicates if this symbol is deprecated. +        """ +        self.name = name +        self.kind = kind +        self.deprecated = deprecated +        self.location = to_type(location, Location) +        self.containerName = containerName + + +class ParameterInformation(object): +    """ +    Represents a parameter of a callable-signature. A parameter can +    have a label and a doc-comment. +    """ +    def __init__(self, label, documentation=""): +        """ +        Constructs a new ParameterInformation instance. + +        :param str label: The label of this parameter. Will be shown in the UI. +        :param str documentation: The human-readable doc-comment of this parameter. Will be shown in the UI but can be omitted. +        """ +        self.label = label +        self.documentation = documentation + + +class SignatureInformation(object): +    """ +    Represents the signature of something callable. A signature +    can have a label, like a function-name, a doc-comment, and +    a set of parameters. +    """ +    def __init__(self, label, documentation="", parameters=[]): +        """ +        Constructs a new SignatureInformation instance. + +        :param str label: The label of this signature. Will be shown in the UI. +        :param str documentation: The human-readable doc-comment of this signature. Will be shown in the UI but can be omitted. +        :param ParameterInformation[] parameters: The parameters of this signature. +        """ +        self.label = label +        self.documentation = documentation +        self.parameters = [to_type(parameter, ParameterInformation) for parameter in parameters] + + +class SignatureHelp(object): +    """ +    Signature help represents the signature of something +    callable. There can be multiple signature but only one +    active and only one active parameter. +    """ +    def __init__(self, signatures, activeSignature=0, activeParameter=0): +        """ +        Constructs a new SignatureHelp instance. + +        :param SignatureInformation[] signatures: One or more signatures. +        :param int activeSignature: +        :param int activeParameter: +        """ +        self.signatures = [to_type(signature, SignatureInformation) for signature in signatures] +        self.activeSignature = activeSignature +        self.activeParameter = activeParameter
\ No newline at end of file diff --git a/continuedev/src/continuedev/libs/util/filter_files.py b/continuedev/src/continuedev/libs/util/filter_files.py new file mode 100644 index 00000000..6ebaa274 --- /dev/null +++ b/continuedev/src/continuedev/libs/util/filter_files.py @@ -0,0 +1,33 @@ +import fnmatch +from typing import List + +DEFAULT_IGNORE_DIRS = [ +    ".git", +    ".vscode", +    ".idea", +    ".vs", +    ".venv", +    "env", +    ".env", +    "node_modules", +    "dist", +    "build", +    "target", +    "out", +    "bin", +    ".pytest_cache", +    ".vscode-test", +    ".continue", +    "__pycache__", +] + +DEFAULT_IGNORE_PATTERNS = DEFAULT_IGNORE_DIRS + list( +    filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS) +) + + +def should_filter_path( +    path: str, ignore_patterns: List[str] = DEFAULT_IGNORE_PATTERNS +) -> bool: +    """Returns whether a file should be filtered""" +    return any(fnmatch.fnmatch(path, pattern) for pattern in ignore_patterns) diff --git a/continuedev/src/continuedev/libs/util/traceback/traceback_parsers.py b/continuedev/src/continuedev/libs/util/traceback/traceback_parsers.py new file mode 100644 index 00000000..58a4f728 --- /dev/null +++ b/continuedev/src/continuedev/libs/util/traceback/traceback_parsers.py @@ -0,0 +1,56 @@ +from boltons import tbutils + +from ....models.main import Traceback + +PYTHON_TRACEBACK_PREFIX = "Traceback (most recent call last):" + + +def get_python_traceback(output: str) -> str: +    if PYTHON_TRACEBACK_PREFIX in output: +        tb_string = output.split(PYTHON_TRACEBACK_PREFIX)[-1] + +        # Then need to remove any lines below the traceback. Do this by noticing that +        # the last line of the traceback is the first (other than they prefix) that doesn't begin with whitespace +        lines = list(filter(lambda x: x.strip() != "", tb_string.splitlines())) +        for i in range(len(lines) - 1): +            if not lines[i].startswith(" "): +                tb_string = "\n".join(lines[: i + 1]) +                break + +        return PYTHON_TRACEBACK_PREFIX + "\n" + tb_string +    elif "SyntaxError" in output: +        return "SyntaxError" + output.split("SyntaxError")[-1] +    else: +        return None + + +def get_javascript_traceback(output: str) -> str: +    lines = output.splitlines() +    first_line = None +    for i in range(len(lines) - 1): +        segs = lines[i].split(":") +        if ( +            len(segs) > 1 +            and segs[0] != "" +            and segs[1].startswith(" ") +            and lines[i + 1].strip().startswith("at") +        ): +            first_line = lines[i] +            break + +    if first_line is not None: +        return "\n".join(lines[lines.index(first_line) :]) +    else: +        return None + + +def parse_python_traceback(tb_string: str) -> Traceback: +    # Remove anchor lines - tbutils doesn't always get them right +    tb_string = "\n".join( +        filter( +            lambda x: x.strip().replace("~", "").replace("^", "") != "", +            tb_string.splitlines(), +        ) +    ) +    exc = tbutils.ParsedException.from_string(tb_string) +    return Traceback.from_tbutil_parsed_exc(exc) diff --git a/continuedev/src/continuedev/libs/util/traceback_parsers.py b/continuedev/src/continuedev/libs/util/traceback_parsers.py deleted file mode 100644 index 2b164c0f..00000000 --- a/continuedev/src/continuedev/libs/util/traceback_parsers.py +++ /dev/null @@ -1,30 +0,0 @@ -PYTHON_TRACEBACK_PREFIX = "Traceback (most recent call last):" - - -def get_python_traceback(output: str) -> str: -    if PYTHON_TRACEBACK_PREFIX in output: -        return PYTHON_TRACEBACK_PREFIX + output.split(PYTHON_TRACEBACK_PREFIX)[-1] -    elif "SyntaxError" in output: -        return "SyntaxError" + output.split("SyntaxError")[-1] -    else: -        return None - - -def get_javascript_traceback(output: str) -> str: -    lines = output.splitlines() -    first_line = None -    for i in range(len(lines) - 1): -        segs = lines[i].split(":") -        if ( -            len(segs) > 1 -            and segs[0] != "" -            and segs[1].startswith(" ") -            and lines[i + 1].strip().startswith("at") -        ): -            first_line = lines[i] -            break - -    if first_line is not None: -        return "\n".join(lines[lines.index(first_line) :]) -    else: -        return None diff --git a/continuedev/src/continuedev/models/filesystem.py b/continuedev/src/continuedev/models/filesystem.py index de426282..514337bf 100644 --- a/continuedev/src/continuedev/models/filesystem.py +++ b/continuedev/src/continuedev/models/filesystem.py @@ -136,6 +136,11 @@ class FileSystem(AbstractModel):          """Apply edit to filesystem, calculate the reverse edit, and return and EditDiff"""          raise NotImplementedError +    @abstractmethod +    def list_directory_contents(self, path: str, recursive: bool = False) -> List[str]: +        """List the contents of a directory""" +        raise NotImplementedError +      @classmethod      def read_range_in_str(self, s: str, r: Range) -> str:          lines = s.split("\n")[r.start.line : r.end.line + 1] @@ -312,6 +317,18 @@ class RealFileSystem(FileSystem):          self.write(edit.filepath, new_content)          return diff +    def list_directory_contents(self, path: str, recursive: bool = False) -> List[str]: +        """List the contents of a directory""" +        if recursive: +            # Walk +            paths = [] +            for root, dirs, files in os.walk(path): +                for f in files: +                    paths.append(os.path.join(root, f)) + +            return paths +        return list(map(lambda x: os.path.join(path, x), os.listdir(path))) +  class VirtualFileSystem(FileSystem):      """A simulated filesystem from a mapping of filepath to file contents.""" @@ -363,5 +380,16 @@ class VirtualFileSystem(FileSystem):          self.write(edit.filepath, new_content)          return EditDiff(edit=edit, original=original) +    def list_directory_contents(self, path: str, recursive: bool = False) -> List[str]: +        """List the contents of a directory""" +        if recursive: +            for filepath in self.files: +                if filepath.startswith(path): +                    yield filepath + +        for filepath in self.files: +            if filepath.startswith(path) and "/" not in filepath[len(path) :]: +                yield filepath +  # TODO: Uniform errors thrown by any FileSystem subclass. diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py index d442a415..a8b4aab0 100644 --- a/continuedev/src/continuedev/models/main.py +++ b/continuedev/src/continuedev/models/main.py @@ -50,6 +50,11 @@ class Position(BaseModel):          return sum(map(len, lines[: self.line])) + self.character +class PositionInFile(BaseModel): +    position: Position +    filepath: str + +  class Range(BaseModel):      """A range in a file. 0-indexed.""" diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index 9ac28aa2..06d22348 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -5,6 +5,7 @@ from typing import List  from ...core.context import ContextProvider  from ...core.main import ContextItem, ContextItemDescription, ContextItemId  from ...core.sdk import ContinueSDK +from ...libs.util.filter_files import DEFAULT_IGNORE_PATTERNS  from .util import remove_meilisearch_disallowed_chars  MAX_SIZE_IN_CHARS = 25_000 @@ -18,36 +19,13 @@ async def get_file_contents(filepath: str, sdk: ContinueSDK) -> str:          return None -DEFAULT_IGNORE_DIRS = [ -    ".git", -    ".vscode", -    ".idea", -    ".vs", -    ".venv", -    "env", -    ".env", -    "node_modules", -    "dist", -    "build", -    "target", -    "out", -    "bin", -    ".pytest_cache", -    ".vscode-test", -    ".continue", -    "__pycache__", -] - -  class FileContextProvider(ContextProvider):      """      The FileContextProvider is a ContextProvider that allows you to search files in the open workspace.      """      title = "file" -    ignore_patterns: List[str] = DEFAULT_IGNORE_DIRS + list( -        filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS) -    ) +    ignore_patterns: List[str] = DEFAULT_IGNORE_PATTERNS      async def start(self, *args):          await super().start(*args) diff --git a/continuedev/src/continuedev/plugins/policies/commit.py b/continuedev/src/continuedev/plugins/policies/commit.py new file mode 100644 index 00000000..2fa43676 --- /dev/null +++ b/continuedev/src/continuedev/plugins/policies/commit.py @@ -0,0 +1,77 @@ +# An agent that makes a full commit in the background +# Plans +# Write code +# Reviews code +# Cleans up + +# It's important that agents are configurable, because people need to be able to specify +# which hooks they want to run. Specific linter, run tests, etc. +# And all of this can be easily specified in the Policy. + + +from textwrap import dedent +from typing import Literal + +from ...core.config import ContinueConfig +from ...core.main import History, Policy, Step +from ...core.observation import TextObservation +from ...core.sdk import ContinueSDK + + +class PlanStep(Step): +    user_input: str + +    _prompt = dedent( +        """\ +        You were given the following instructions: "{user_input}". +             +        Create a plan for how you will complete the task. +             +        Here are relevant files: + +        {relevant_files} +             +        Your plan will include: +        1. A high-level description of how you are going to accomplish the task +        2. A list of which files you will edit +        3. A description of what you will change in each file +        """ +    ) + +    async def run(self, sdk: ContinueSDK): +        plan = await sdk.models.default.complete( +            self._prompt.format( +                {"user_input": self.user_input, "relevant_files": "TODO"} +            ) +        ) +        return TextObservation(text=plan) + + +class WriteCommitStep(Step): +    async def run(self, sdk: ContinueSDK): +        pass + + +class ReviewCodeStep(Step): +    async def run(self, sdk: ContinueSDK): +        pass + + +class CleanupStep(Step): +    async def run(self, sdk: ContinueSDK): +        pass + + +class CommitPolicy(Policy): +    user_input: str + +    current_step: Literal["plan", "write", "review", "cleanup"] = "plan" + +    def next(self, config: ContinueConfig, history: History) -> Step: +        if history.get_current() is None: +            return ( +                PlanStep(user_input=self.user_input) +                >> WriteCommitStep() +                >> ReviewCodeStep() +                >> CleanupStep() +            ) diff --git a/continuedev/src/continuedev/plugins/policies/headless.py b/continuedev/src/continuedev/plugins/policies/headless.py new file mode 100644 index 00000000..56ebe31f --- /dev/null +++ b/continuedev/src/continuedev/plugins/policies/headless.py @@ -0,0 +1,18 @@ +from ...core.config import ContinueConfig +from ...core.main import History, Policy, Step +from ...core.observation import TextObservation +from ...plugins.steps.core.core import ShellCommandsStep +from ...plugins.steps.on_traceback import DefaultOnTracebackStep + + +class HeadlessPolicy(Policy): +    command: str + +    def next(self, config: ContinueConfig, history: History) -> Step: +        if history.get_current() is None: +            return ShellCommandsStep(cmds=[self.command]) +        observation = history.get_current().observation +        if isinstance(observation, TextObservation): +            return DefaultOnTracebackStep(output=observation.text) + +        return None diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index 5325a918..740c75bc 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -1,5 +1,6 @@  # These steps are depended upon by ContinueSDK  import difflib +import subprocess  import traceback  from textwrap import dedent  from typing import Any, Coroutine, List, Optional, Union @@ -112,52 +113,22 @@ class ShellCommandsStep(Step):          )      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: -        await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd - -        for cmd in self.cmds: -            output = await sdk.ide.runCommand(cmd) -            if ( -                self.handle_error -                and output is not None -                and output_contains_error(output) -            ): -                suggestion = await sdk.models.medium.complete( -                    dedent( -                        f"""\ -                    While running the command `{cmd}`, the following error occurred: - -                    ```ascii -                    {output} -                    ``` - -                    This is a brief summary of the error followed by a suggestion on how it can be fixed:""" -                    ), -                    with_history=await sdk.get_chat_context(), -                ) - -                sdk.raise_exception( -                    title="Error while running query", -                    message=output, -                    with_step=MessageStep( -                        name=f"Suggestion to solve error {AI_ASSISTED_STRING}", -                        message=f"{suggestion}\n\nYou can click the retry button on the failed step to try again.", -                    ), -                ) - -        return TextObservation(text=output) - -        # process = subprocess.Popen( -        #     '/bin/bash', stdin=subprocess.PIPE, stdout=subprocess.PIPE, cwd=cwd) +        process = subprocess.Popen( +            "/bin/bash", +            stdin=subprocess.PIPE, +            stdout=subprocess.PIPE, +            cwd=self.cwd or sdk.ide.workspace_directory, +        ) -        # stdin_input = "\n".join(self.cmds) -        # out, err = process.communicate(stdin_input.encode()) +        stdin_input = "\n".join(self.cmds) +        out, err = process.communicate(stdin_input.encode()) -        # # If it fails, return the error -        # if err is not None and err != "": -        #     self._err_text = err -        #     return TextObservation(text=err) +        # If it fails, return the error +        if err is not None and err != "": +            self._err_text = err +            return TextObservation(text=err) -        # return None +        return None  class DefaultModelEditCodeStep(Step): diff --git a/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py b/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py index 1d135b3e..7ceefe9b 100644 --- a/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py +++ b/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py @@ -8,6 +8,10 @@ class ImplementAbstractMethodStep(Step):      class_name: str      async def run(self, sdk: ContinueSDK): +        if sdk.lsp is None: +            self.description = "Language Server Protocol is not enabled" +            return +          implementations = await sdk.lsp.go_to_implementations(self.class_name)          for implementation in implementations: diff --git a/continuedev/src/continuedev/plugins/steps/on_traceback.py b/continuedev/src/continuedev/plugins/steps/on_traceback.py index 63bae805..6b75d726 100644 --- a/continuedev/src/continuedev/plugins/steps/on_traceback.py +++ b/continuedev/src/continuedev/plugins/steps/on_traceback.py @@ -1,15 +1,28 @@  import os +from textwrap import dedent +from typing import Dict, List, Optional, Tuple  from ...core.main import ChatMessage, Step  from ...core.sdk import ContinueSDK -from ...libs.util.traceback_parsers import ( +from ...libs.util.filter_files import should_filter_path +from ...libs.util.traceback.traceback_parsers import (      get_javascript_traceback,      get_python_traceback, +    parse_python_traceback,  ) +from ...models.filesystem import RangeInFile +from ...models.main import Range, Traceback, TracebackFrame  from .chat import SimpleChatStep  from .core.core import UserInputStep +def extract_traceback_str(output: str) -> str: +    tb = output.strip() +    for tb_parser in [get_python_traceback, get_javascript_traceback]: +        if parsed_tb := tb_parser(tb): +            return parsed_tb + +  class DefaultOnTracebackStep(Step):      output: str      name: str = "Help With Traceback" @@ -38,11 +51,11 @@ class DefaultOnTracebackStep(Step):          # And this function is where you can get arbitrarily fancy about adding context      async def run(self, sdk: ContinueSDK): -        tb = self.output.strip() -        for tb_parser in [get_python_traceback, get_javascript_traceback]: -            if parsed_tb := tb_parser(tb): -                tb = parsed_tb -                break +        if get_python_traceback(self.output) is not None: +            await sdk.run_step(SolvePythonTracebackStep(output=self.output)) +            return + +        tb = extract_traceback_str(self.output)          tb_first_last_lines = (              ("\n".join(tb.split("\n")[:3]) + "\n...\n" + "\n".join(tb.split("\n")[-3:])) @@ -57,3 +70,139 @@ class DefaultOnTracebackStep(Step):              )          )          await sdk.run_step(SimpleChatStep(name="Help With Traceback")) + + +def filter_frames(frames: List[TracebackFrame]) -> List[TracebackFrame]: +    """Filter out frames that are not relevant to the user's code.""" +    return list(filter(lambda x: should_filter_path(x.filepath), frames)) + + +def find_external_call( +    frames: List[TracebackFrame], +) -> Optional[Tuple[TracebackFrame, TracebackFrame]]: +    """Moving up from the bottom of the stack, if the frames are not user code, then find the last frame before it becomes user code.""" +    if not should_filter_path(frames[-1].filepath): +        # No external call, error comes directly from user code +        return None + +    for i in range(len(frames) - 2, -1, -1): +        if not should_filter_path(frames[i].filepath): +            return frames[i], frames[i + 1] + + +def get_func_source_for_frame(frame: Dict) -> str: +    """Get the source for the function called in the frame.""" +    pass + + +async def fetch_docs_for_external_call(external_call: Dict, next_frame: Dict) -> str: +    """Fetch docs for the external call.""" +    pass + + +class SolvePythonTracebackStep(Step): +    output: str +    name: str = "Solve Traceback" +    hide: bool = True + +    async def external_call_prompt( +        self, sdk: ContinueSDK, external_call: Tuple[Dict, Dict], tb_string: str +    ) -> str: +        external_call, next_frame = external_call +        source_line = external_call["source_line"] +        external_func_source = get_func_source_for_frame(next_frame) +        docs = await fetch_docs_for_external_call(external_call, next_frame) + +        prompt = dedent( +            f"""\ +                    I got the following error: +                 +                    {tb_string} + +                    I tried to call an external library like this: + +                    ```python +                    {source_line} +                    ``` + +                    This is the definition of the function I tried to call: + +                    ```python +                    {external_func_source} +                    ``` +                 +                    Here's the documentation for the external library I tried to call: +                 +                    {docs} + +                    Explain how to fix the error. +                    """ +        ) + +        return prompt + +    async def normal_traceback_prompt( +        self, sdk: ContinueSDK, tb: Traceback, tb_string: str +    ) -> str: +        function_bodies = await get_functions_from_traceback(tb, sdk) + +        prompt = ( +            "Here are the functions from the traceback (most recent call last):\n\n" +        ) +        for i, function_body in enumerate(function_bodies): +            prompt += f'File "{tb.frames[i].filepath}", line {tb.frames[i].lineno}, in {tb.frames[i].function}\n\n```python\n{function_body or tb.frames[i].code}\n```\n\n' + +        prompt += ( +            "Here is the traceback:\n\n```\n" +            + tb_string +            + "\n```\n\nExplain how to fix the error." +        ) + +        return prompt + +    async def run(self, sdk: ContinueSDK): +        tb_string = get_python_traceback(self.output) +        tb = parse_python_traceback(tb_string) + +        if external_call := find_external_call(tb.frames): +            prompt = await self.external_call_prompt(sdk, external_call, tb_string) +        else: +            prompt = await self.normal_traceback_prompt(sdk, tb, tb_string) + +        await sdk.run_step( +            UserInputStep( +                description="Solving stack trace", +                user_input=prompt, +            ) +        ) +        await sdk.run_step(SimpleChatStep(name="Help With Traceback")) + + +async def get_function_body(frame: TracebackFrame, sdk: ContinueSDK) -> Optional[str]: +    """Get the function body from the traceback frame.""" +    if sdk.lsp is None: +        return None + +    document_symbols = sdk.lsp.get_symbols(frame.filepath) +    for symbol in document_symbols: +        if symbol.name == frame.function: +            r = symbol.location.range +            return await sdk.ide.readRangeInFile( +                RangeInFile( +                    filepath=frame.filepath, +                    range=Range.from_shorthand( +                        r.start.line, r.start.character, r.end.line, r.end.character +                    ), +                ) +            ) +    return None + + +async def get_functions_from_traceback(tb: Traceback, sdk: ContinueSDK) -> List[str]: +    """Get the function bodies from the traceback.""" +    function_bodies = [] +    for frame in tb.frames: +        if frame.function: +            function_bodies.append(await get_function_body(frame, sdk)) + +    return function_bodies diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index 11099494..6aae8cc5 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -117,10 +117,15 @@ async def poll_meilisearch_running(frequency: int = 0.1) -> bool:          await asyncio.sleep(frequency) +meilisearch_process = None + +  async def start_meilisearch():      """      Starts the MeiliSearch server, wait for it.      """ +    global meilisearch_process +      serverPath = getServerFolderPath()      # Check if MeiliSearch is installed, if not download @@ -130,7 +135,7 @@ async def start_meilisearch():      if not await check_meilisearch_running() or not was_already_installed:          logger.debug("Starting MeiliSearch...")          binary_name = "meilisearch" if os.name == "nt" else "./meilisearch" -        subprocess.Popen( +        meilisearch_process = subprocess.Popen(              [binary_name, "--no-analytics"],              cwd=serverPath,              stdout=subprocess.DEVNULL, @@ -139,3 +144,14 @@ async def start_meilisearch():              start_new_session=True,              shell=True,          ) + + +def stop_meilisearch(): +    """ +    Stops the MeiliSearch server. +    """ +    global meilisearch_process +    if meilisearch_process is not None: +        meilisearch_process.terminate() +        meilisearch_process.wait() +        meilisearch_process = None diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index f0c33929..f0080104 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -8,6 +8,7 @@ from fastapi import APIRouter, WebSocket  from fastapi.websockets import WebSocketState  from ..core.autopilot import Autopilot +from ..core.config import ContinueConfig  from ..core.main import FullState  from ..libs.util.create_async_task import create_async_task  from ..libs.util.logging import logger @@ -57,7 +58,10 @@ class SessionManager:          return self.sessions[session_id]      async def new_session( -        self, ide: AbstractIdeProtocolServer, session_id: Optional[str] = None +        self, +        ide: AbstractIdeProtocolServer, +        session_id: Optional[str] = None, +        config: Optional[ContinueConfig] = None,      ) -> Session:          logger.debug(f"New session: {session_id}") @@ -85,7 +89,7 @@ class SessionManager:          # Start the autopilot (must be after session is added to sessions) and the policy          try: -            await autopilot.start(full_state=full_state) +            await autopilot.start(full_state=full_state, config=config)          except Exception as e:              await ide.on_error(e) diff --git a/continuedev/src/continuedev/tests/llm_test.py b/continuedev/src/continuedev/tests/llm_test.py index 91ddd33f..f4aea1fb 100644 --- a/continuedev/src/continuedev/tests/llm_test.py +++ b/continuedev/src/continuedev/tests/llm_test.py @@ -13,10 +13,11 @@ from continuedev.libs.llm.openai import OpenAI  from continuedev.libs.llm.together import TogetherLLM  from continuedev.libs.util.count_tokens import DEFAULT_ARGS  from continuedev.tests.util.openai_mock import start_openai +from continuedev.tests.util.prompts import tokyo_test_pair  load_dotenv() -TEST_PROMPT = "Output a single word, that being the capital of Japan:" +  SPEND_MONEY = True @@ -65,9 +66,9 @@ class TestBaseLLM:          if self.llm.__class__.__name__ == "LLM":              pytest.skip("Skipping abstract LLM") -        resp = await self.llm.complete(TEST_PROMPT, temperature=0.0) +        resp = await self.llm.complete(tokyo_test_pair[0], temperature=0.0)          assert isinstance(resp, str) -        assert resp.strip().lower() == "tokyo" +        assert resp.strip().lower() == tokyo_test_pair[1]      @pytest.mark.skipif(SPEND_MONEY is False, reason="Costs money")      @async_test @@ -79,7 +80,9 @@ class TestBaseLLM:          role = None          async for chunk in self.llm.stream_chat(              messages=[ -                ChatMessage(role="user", content=TEST_PROMPT, summary=TEST_PROMPT) +                ChatMessage( +                    role="user", content=tokyo_test_pair[0], summary=tokyo_test_pair[0] +                )              ],              temperature=0.0,          ): @@ -90,7 +93,7 @@ class TestBaseLLM:                  role = chunk["role"]          assert role == "assistant" -        assert completion.strip().lower() == "tokyo" +        assert completion.strip().lower() == tokyo_test_pair[1]      @pytest.mark.skipif(SPEND_MONEY is False, reason="Costs money")      @async_test @@ -99,11 +102,13 @@ class TestBaseLLM:              pytest.skip("Skipping abstract LLM")          completion = "" -        async for chunk in self.llm.stream_complete(TEST_PROMPT, temperature=0.0): +        async for chunk in self.llm.stream_complete( +            tokyo_test_pair[0], temperature=0.0 +        ):              assert isinstance(chunk, str)              completion += chunk -        assert completion.strip().lower() == "tokyo" +        assert completion.strip().lower() == tokyo_test_pair[1]  class TestOpenAI(TestBaseLLM): @@ -129,7 +134,7 @@ class TestOpenAI(TestBaseLLM):              "Output a single word, that being the capital of Japan:"          )          assert isinstance(resp, str) -        assert resp.strip().lower() == "tokyo" +        assert resp.strip().lower() == tokyo_test_pair[1]  class TestGGML(TestBaseLLM): diff --git a/continuedev/src/continuedev/tests/step_test.py b/continuedev/src/continuedev/tests/step_test.py new file mode 100644 index 00000000..a4131e61 --- /dev/null +++ b/continuedev/src/continuedev/tests/step_test.py @@ -0,0 +1,43 @@ +import pytest + +from continuedev.core.config import ContinueConfig +from continuedev.headless import start_headless_session +from continuedev.plugins.steps.core.core import UserInputStep +from continuedev.tests.util.prompts import tokyo_test_pair + +TEST_CONFIG = ContinueConfig() + + +@pytest.mark.asyncio +async def test_step(): +    session = await start_headless_session(config=TEST_CONFIG) + +    await session.autopilot.run_from_step(UserInputStep(user_input=tokyo_test_pair[0])) + +    full_state = await session.autopilot.get_full_state() +    assert ( +        full_state.history.timeline[-1].step.description.strip().lower() +        == tokyo_test_pair[1] +    ) + +    await session.autopilot.cleanup() + + +# TODO: Test other properties of full_state after the UserInputStep. Also test with other config properties and models, etc... +# so we are sure that UserInputStep works in many cases. One example of a thing to check is that the step following UserInputStep +# is a SimpleChatStep, and that it is not hidden, and that the properties of the node (full_state.history.timeline[-1]) all look good +# (basically, run it, see what you get, then codify this in assertions) + +# TODO: Write tests for other steps: +# - DefaultOnTracebackStep +# - EditHighlightedCodeStep (note that this won't test the rendering in IDE) + +# NOTE: Avoid expensive prompts - not too big a deal, but don't have it generate an entire 100 line file +# If you want to not have llm_test.py spend money, change SPEND_MONEY to False at the +# top of that file, but make sure to put it back to True before committing + +# NOTE: Headless mode uses continuedev.src.continuedev.headless.headless_ide instead of +# VS Code, so many of the methods just pass, or might not act exactly how you expect. +# See the file for reference + +# NOTE: If this is too short or pointless a task, let me know and I'll set up testing of ContextProviders diff --git a/continuedev/src/continuedev/tests/util/config.py b/continuedev/src/continuedev/tests/util/config.py new file mode 100644 index 00000000..73d3aeff --- /dev/null +++ b/continuedev/src/continuedev/tests/util/config.py @@ -0,0 +1,19 @@ +from continuedev.src.continuedev.core.config import ContinueConfig +from continuedev.src.continuedev.core.models import Models +from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI + +config = ContinueConfig( +    allow_anonymous_telemetry=False, +    models=Models( +        default=MaybeProxyOpenAI(api_key="", model="gpt-4"), +        medium=MaybeProxyOpenAI( +            api_key="", +            model="gpt-3.5-turbo", +        ), +    ), +    system_message=None, +    temperature=0.5, +    custom_commands=[], +    slash_commands=[], +    context_providers=[], +) diff --git a/continuedev/src/continuedev/tests/util/prompts.py b/continuedev/src/continuedev/tests/util/prompts.py new file mode 100644 index 00000000..f60c926c --- /dev/null +++ b/continuedev/src/continuedev/tests/util/prompts.py @@ -0,0 +1 @@ +tokyo_test_pair = ("Output a single word, that being the capital of Japan:", "tokyo")
\ No newline at end of file @@ -12,8 +12,8 @@ a = Analysis(      datas=[          ('continuedev', 'continuedev'),          (certifi.where(), 'ca_bundle') -        ] + copy_metadata('replicate'), -    hiddenimports=['anthropic', 'github', 'ripgrepy', 'bs4', 'redbaron', 'replicate'], +        ], +    hiddenimports=['anthropic', 'github', 'ripgrepy', 'bs4', 'redbaron', 'python-lsp-server', 'replicate'] + copy_metadata('replicate'),      hookspath=[],      hooksconfig={},      runtime_hooks=[], diff --git a/test.py b/test.py new file mode 100644 index 00000000..5bd57e0e --- /dev/null +++ b/test.py @@ -0,0 +1,21 @@ +import unittest + + +def sort_numbers(numbers): +    for i in range(len(numbers)): +        for j in range(i + 1, len(numbers)): +            if numbers[i] > numbers[j]: +                numbers[i], numbers[j] = numbers[j], numbers[i] +    return numbers[:-1]  # Error here: We're not returning the last number + + +class TestSortNumbers(unittest.TestCase): +    def test_sort_numbers(self): +        self.assertEqual(sort_numbers([3, 2, 1]), [1, 2, 3])  # This test will fail +        self.assertEqual( +            sort_numbers([4, 2, 5, 1, 3]), [1, 2, 3, 4, 5] +        )  # This test will fail + + +if __name__ == "__main__": +    unittest.main() | 
