summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.vscode/launch.json11
-rw-r--r--continuedev/requirements.txt3
-rw-r--r--continuedev/src/continuedev/core/autopilot.py20
-rw-r--r--continuedev/src/continuedev/core/config.py11
-rw-r--r--continuedev/src/continuedev/core/lsp.py310
-rw-r--r--continuedev/src/continuedev/core/sdk.py51
-rw-r--r--continuedev/src/continuedev/headless/__init__.py41
-rw-r--r--continuedev/src/continuedev/headless/headless_ide.py181
-rw-r--r--continuedev/src/continuedev/libs/llm/together.py9
-rw-r--r--continuedev/src/continuedev/libs/lspclient/json_rpc_endpoint.py82
-rw-r--r--continuedev/src/continuedev/libs/lspclient/lsp_client.py150
-rw-r--r--continuedev/src/continuedev/libs/lspclient/lsp_endpoint.py71
-rw-r--r--continuedev/src/continuedev/libs/lspclient/lsp_structs.py316
-rw-r--r--continuedev/src/continuedev/libs/util/filter_files.py33
-rw-r--r--continuedev/src/continuedev/libs/util/traceback/traceback_parsers.py56
-rw-r--r--continuedev/src/continuedev/libs/util/traceback_parsers.py30
-rw-r--r--continuedev/src/continuedev/models/filesystem.py28
-rw-r--r--continuedev/src/continuedev/models/main.py5
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/file.py26
-rw-r--r--continuedev/src/continuedev/plugins/policies/commit.py77
-rw-r--r--continuedev/src/continuedev/plugins/policies/headless.py18
-rw-r--r--continuedev/src/continuedev/plugins/steps/core/core.py57
-rw-r--r--continuedev/src/continuedev/plugins/steps/draft/abstract_method.py4
-rw-r--r--continuedev/src/continuedev/plugins/steps/on_traceback.py161
-rw-r--r--continuedev/src/continuedev/server/meilisearch_server.py18
-rw-r--r--continuedev/src/continuedev/server/session_manager.py8
-rw-r--r--continuedev/src/continuedev/tests/llm_test.py21
-rw-r--r--continuedev/src/continuedev/tests/step_test.py43
-rw-r--r--continuedev/src/continuedev/tests/util/config.py19
-rw-r--r--continuedev/src/continuedev/tests/util/prompts.py1
-rw-r--r--run.spec4
-rw-r--r--test.py21
32 files changed, 1750 insertions, 136 deletions
diff --git a/.vscode/launch.json b/.vscode/launch.json
index 08061d13..9ccf4ce7 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -37,6 +37,17 @@
// What about a watch task? - type errors?
},
{
+ "name": "Headless",
+ "type": "python",
+ "request": "launch",
+ "module": "continuedev.headless",
+ "args": ["--config", "continuedev/config.py"],
+ "justMyCode": false,
+ "subProcess": false
+ // Does it need a build task?
+ // What about a watch task? - type errors?
+ },
+ {
"name": "Extension (VSCode)",
"type": "extensionHost",
"request": "launch",
diff --git a/continuedev/requirements.txt b/continuedev/requirements.txt
index 0cf909d5..91120c59 100644
--- a/continuedev/requirements.txt
+++ b/continuedev/requirements.txt
@@ -22,4 +22,5 @@ socksio==1.0.0
ripgrepy==2.0.0
replicate==0.11.0
bs4==0.0.1
-redbaron==0.9.2 \ No newline at end of file
+redbaron==0.9.2
+python-lsp-server==1.2.0 \ No newline at end of file
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index bae82739..8ac7241d 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -17,7 +17,10 @@ from ..libs.util.paths import getSavedContextGroupsPath
from ..libs.util.queue import AsyncSubscriptionQueue
from ..libs.util.strings import remove_quotes_and_escapes
from ..libs.util.telemetry import posthog_logger
-from ..libs.util.traceback_parsers import get_javascript_traceback, get_python_traceback
+from ..libs.util.traceback.traceback_parsers import (
+ get_javascript_traceback,
+ get_python_traceback,
+)
from ..models.filesystem import RangeInFileWithContents
from ..models.filesystem_edit import FileEditWithFullContents
from ..models.main import ContinueBaseModel
@@ -32,6 +35,8 @@ from ..plugins.steps.core.core import (
)
from ..plugins.steps.on_traceback import DefaultOnTracebackStep
from ..server.ide_protocol import AbstractIdeProtocolServer
+from ..server.meilisearch_server import stop_meilisearch
+from .config import ContinueConfig
from .context import ContextManager
from .main import (
Context,
@@ -97,8 +102,12 @@ class Autopilot(ContinueBaseModel):
started: bool = False
- async def start(self, full_state: Optional[FullState] = None):
- self.continue_sdk = await ContinueSDK.create(self)
+ async def start(
+ self,
+ full_state: Optional[FullState] = None,
+ config: Optional[ContinueConfig] = None,
+ ):
+ self.continue_sdk = await ContinueSDK.create(self, config=config)
if override_policy := self.continue_sdk.config.policy_override:
self.policy = override_policy
@@ -134,6 +143,11 @@ class Autopilot(ContinueBaseModel):
self.started = True
+ async def cleanup(self):
+ if self.continue_sdk.lsp is not None:
+ await self.continue_sdk.lsp.stop()
+ stop_meilisearch()
+
class Config:
arbitrary_types_allowed = True
keep_untouched = (cached_property,)
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 62e9c690..68b2b17d 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -59,3 +59,14 @@ class ContinueConfig(BaseModel):
@validator("temperature", pre=True)
def temperature_validator(cls, v):
return max(0.0, min(1.0, v))
+
+ @staticmethod
+ def from_filepath(filepath: str) -> "ContinueConfig":
+ # Use importlib to load the config file config.py at the given path
+ import importlib.util
+
+ spec = importlib.util.spec_from_file_location("config", filepath)
+ config = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(config)
+
+ return config.config
diff --git a/continuedev/src/continuedev/core/lsp.py b/continuedev/src/continuedev/core/lsp.py
new file mode 100644
index 00000000..5c1f9989
--- /dev/null
+++ b/continuedev/src/continuedev/core/lsp.py
@@ -0,0 +1,310 @@
+import os
+import socket
+import subprocess
+import threading
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+from ..libs.lspclient.json_rpc_endpoint import JsonRpcEndpoint
+from ..libs.lspclient.lsp_client import LspClient
+from ..libs.lspclient.lsp_endpoint import LspEndpoint
+from ..libs.lspclient.lsp_structs import Position as LspPosition
+from ..libs.lspclient.lsp_structs import SymbolInformation, TextDocumentIdentifier
+from ..models.filesystem import RangeInFile
+from ..models.main import Position, Range
+
+
+class ReadPipe(threading.Thread):
+ def __init__(self, pipe):
+ threading.Thread.__init__(self)
+ self.pipe = pipe
+
+ def run(self):
+ line = self.pipe.readline().decode("utf-8")
+ while line:
+ print(line)
+ line = self.pipe.readline().decode("utf-8")
+
+
+class SocketFileWrapper:
+ def __init__(self, sockfile):
+ self.sockfile = sockfile
+
+ def write(self, data):
+ if isinstance(data, bytes):
+ data = data.decode("utf-8").replace("\r\n", "\n")
+ return self.sockfile.write(data)
+
+ def read(self, size=-1):
+ data = self.sockfile.read(size)
+ if isinstance(data, str):
+ data = data.replace("\n", "\r\n").encode("utf-8")
+ return data
+
+ def readline(self, size=-1):
+ data = self.sockfile.readline(size)
+ if isinstance(data, str):
+ data = data.replace("\n", "\r\n").encode("utf-8")
+ return data
+
+ def flush(self):
+ return self.sockfile.flush()
+
+ def close(self):
+ return self.sockfile.close()
+
+
+def create_json_rpc_endpoint(use_subprocess: Optional[str] = None):
+ if use_subprocess is None:
+ # Connect to the server
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ s.connect(("localhost", 8080))
+
+ # Create a file-like object from the socket
+ sockfile = s.makefile("rw")
+ wrapped_sockfile = SocketFileWrapper(sockfile)
+ return JsonRpcEndpoint(wrapped_sockfile, wrapped_sockfile), None
+
+ else:
+ pyls_cmd = use_subprocess.split()
+ p = subprocess.Popen(
+ pyls_cmd,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ read_pipe = ReadPipe(p.stderr)
+ read_pipe.start()
+ return JsonRpcEndpoint(p.stdin, p.stdout), p
+
+
+def filename_to_uri(filename: str) -> str:
+ return f"file://{filename}"
+
+
+def uri_to_filename(uri: str) -> str:
+ if uri.startswith("file://"):
+ return uri.lstrip("file://")
+ else:
+ return uri
+
+
+def create_lsp_client(workspace_dir: str, use_subprocess: Optional[str] = None):
+ json_rpc_endpoint, process = create_json_rpc_endpoint(use_subprocess=use_subprocess)
+ lsp_endpoint = LspEndpoint(json_rpc_endpoint)
+ lsp_client = LspClient(lsp_endpoint)
+ capabilities = {
+ "textDocument": {
+ "codeAction": {"dynamicRegistration": True},
+ "codeLens": {"dynamicRegistration": True},
+ "colorProvider": {"dynamicRegistration": True},
+ "completion": {
+ "completionItem": {
+ "commitCharactersSupport": True,
+ "documentationFormat": ["markdown", "plaintext"],
+ "snippetSupport": True,
+ },
+ "completionItemKind": {
+ "valueSet": [
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ 21,
+ 22,
+ 23,
+ 24,
+ 25,
+ ]
+ },
+ "contextSupport": True,
+ "dynamicRegistration": True,
+ },
+ "definition": {"dynamicRegistration": True},
+ "documentHighlight": {"dynamicRegistration": True},
+ "documentLink": {"dynamicRegistration": True},
+ "documentSymbol": {
+ "dynamicRegistration": True,
+ "symbolKind": {
+ "valueSet": [
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ 21,
+ 22,
+ 23,
+ 24,
+ 25,
+ 26,
+ ]
+ },
+ },
+ "formatting": {"dynamicRegistration": True},
+ "hover": {
+ "contentFormat": ["markdown", "plaintext"],
+ "dynamicRegistration": True,
+ },
+ "implementation": {"dynamicRegistration": True},
+ "onTypeFormatting": {"dynamicRegistration": True},
+ "publishDiagnostics": {"relatedInformation": True},
+ "rangeFormatting": {"dynamicRegistration": True},
+ "references": {"dynamicRegistration": True},
+ "rename": {"dynamicRegistration": True},
+ "signatureHelp": {
+ "dynamicRegistration": True,
+ "signatureInformation": {
+ "documentationFormat": ["markdown", "plaintext"]
+ },
+ },
+ "synchronization": {
+ "didSave": True,
+ "dynamicRegistration": True,
+ "willSave": True,
+ "willSaveWaitUntil": True,
+ },
+ "typeDefinition": {"dynamicRegistration": True},
+ },
+ "workspace": {
+ "applyEdit": True,
+ "configuration": True,
+ "didChangeConfiguration": {"dynamicRegistration": True},
+ "didChangeWatchedFiles": {"dynamicRegistration": True},
+ "executeCommand": {"dynamicRegistration": True},
+ "symbol": {
+ "dynamicRegistration": True,
+ "symbolKind": {
+ "valueSet": [
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ 21,
+ 22,
+ 23,
+ 24,
+ 25,
+ 26,
+ ]
+ },
+ },
+ "workspaceEdit": {"documentChanges": True},
+ "workspaceFolders": True,
+ },
+ }
+ root_uri = filename_to_uri(workspace_dir)
+ dir_name = os.path.basename(workspace_dir)
+ workspace_folders = [{"name": dir_name, "uri": root_uri}]
+ lsp_client.initialize(
+ None,
+ None,
+ root_uri,
+ None,
+ capabilities,
+ "off",
+ workspace_folders,
+ )
+ lsp_client.initialized()
+ return lsp_client, process
+
+
+class ContinueLSPClient(BaseModel):
+ workspace_dir: str
+ lsp_client: LspClient = None
+ use_subprocess: Optional[str] = None
+ lsp_process: Optional[subprocess.Popen] = None
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ def dict(self, **kwargs):
+ original_dict = super().dict(**kwargs)
+ original_dict.pop("lsp_client", None)
+ return original_dict
+
+ async def start(self):
+ self.lsp_client, self.lsp_process = create_lsp_client(
+ self.workspace_dir, use_subprocess=self.use_subprocess
+ )
+
+ async def stop(self):
+ self.lsp_client.shutdown()
+ self.lsp_client.exit()
+ if self.lsp_process is not None:
+ self.lsp_process.terminate()
+ self.lsp_process.wait()
+ self.lsp_process = None
+
+ def goto_definition(self, position: Position, filename: str) -> List[RangeInFile]:
+ response = self.lsp_client.definition(
+ TextDocumentIdentifier(filename_to_uri(filename)),
+ LspPosition(position.line, position.character),
+ )
+ return [
+ RangeInFile(
+ filepath=uri_to_filename(x.uri),
+ range=Range.from_shorthand(
+ x.range.start.line,
+ x.range.start.character,
+ x.range.end.line,
+ x.range.end.character,
+ ),
+ )
+ for x in response
+ ]
+
+ def get_symbols(self, filepath: str) -> List[SymbolInformation]:
+ response = self.lsp_client.documentSymbol(
+ TextDocumentIdentifier(filename_to_uri(filepath))
+ )
+
+ return response
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 9ff6612c..7dca600d 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -1,8 +1,9 @@
import os
import traceback
-from typing import Coroutine, Union
+from typing import Coroutine, List, Optional, Union
from ..libs.llm import LLM
+from ..libs.util.create_async_task import create_async_task
from ..libs.util.logging import logger
from ..libs.util.paths import getConfigFilePath
from ..libs.util.telemetry import posthog_logger
@@ -16,11 +17,18 @@ from ..models.filesystem_edit import (
FileSystemEdit,
)
from ..models.main import Range
-from ..plugins.steps.core.core import *
-from ..plugins.steps.core.core import DefaultModelEditCodeStep
+from ..plugins.steps.core.core import (
+ DefaultModelEditCodeStep,
+ FileSystemEditStep,
+ MessageStep,
+ RangeInFileWithContents,
+ ShellCommandsStep,
+ WaitForUserConfirmationStep,
+)
from ..server.ide_protocol import AbstractIdeProtocolServer
from .abstract_sdk import AbstractContinueSDK
from .config import ContinueConfig
+from .lsp import ContinueLSPClient
from .main import (
ChatMessage,
Context,
@@ -42,6 +50,7 @@ class ContinueSDK(AbstractContinueSDK):
ide: AbstractIdeProtocolServer
models: Models
+ lsp: Optional[ContinueLSPClient] = None
context: Context
config: ContinueConfig
__autopilot: Autopilot
@@ -52,13 +61,14 @@ class ContinueSDK(AbstractContinueSDK):
self.context = autopilot.context
@classmethod
- async def create(cls, autopilot: Autopilot) -> "ContinueSDK":
+ async def create(
+ cls, autopilot: Autopilot, config: Optional[ContinueConfig] = None
+ ) -> "ContinueSDK":
sdk = ContinueSDK(autopilot)
autopilot.continue_sdk = sdk
try:
- config = sdk._load_config_dot_py()
- sdk.config = config
+ sdk.config = config or sdk._load_config_dot_py()
except Exception as e:
logger.error(f"Failed to load config.py: {traceback.format_exception(e)}")
@@ -78,9 +88,26 @@ class ContinueSDK(AbstractContinueSDK):
)
await sdk.ide.setFileOpen(getConfigFilePath())
+ # Start models
sdk.models = sdk.config.models
await sdk.models.start(sdk)
+ # Start LSP
+ async def start_lsp():
+ try:
+ sdk.lsp = ContinueLSPClient(
+ workspace_dir=sdk.ide.workspace_directory,
+ use_subprocess="python3.10 -m pylsp",
+ )
+ await sdk.lsp.start()
+ except:
+ logger.warning("Failed to start LSP client", exc_info=True)
+ sdk.lsp = None
+
+ create_async_task(
+ start_lsp(), on_error=lambda e: logger.error("Failed to setup LSP: %s", e)
+ )
+
# When the config is loaded, setup posthog logger
posthog_logger.setup(sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry)
@@ -207,19 +234,13 @@ class ContinueSDK(AbstractContinueSDK):
_last_valid_config: ContinueConfig = None
def _load_config_dot_py(self) -> ContinueConfig:
- # Use importlib to load the config file config.py at the given path
path = getConfigFilePath()
-
- import importlib.util
-
- spec = importlib.util.spec_from_file_location("config", path)
- config = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(config)
- self._last_valid_config = config.config
+ config = ContinueConfig.from_filepath(path)
+ self._last_valid_config = config
logger.debug("Loaded Continue config file from %s", path)
- return config.config
+ return config
def get_code_context(
self, only_editing: bool = False
diff --git a/continuedev/src/continuedev/headless/__init__.py b/continuedev/src/continuedev/headless/__init__.py
new file mode 100644
index 00000000..4e46409a
--- /dev/null
+++ b/continuedev/src/continuedev/headless/__init__.py
@@ -0,0 +1,41 @@
+import asyncio
+from typing import Optional, Union
+
+import typer
+
+from ..core.config import ContinueConfig
+from ..server.session_manager import Session, session_manager
+from .headless_ide import LocalIdeProtocol
+
+app = typer.Typer()
+
+
+async def start_headless_session(
+ config: Optional[Union[str, ContinueConfig]] = None
+) -> Session:
+ if config is not None:
+ if isinstance(config, str):
+ config: ContinueConfig = ContinueConfig.from_filepath(config)
+
+ ide = LocalIdeProtocol()
+ return await session_manager.new_session(ide, config=config)
+
+
+async def async_main(config: Optional[str] = None):
+ await start_headless_session(config=config)
+
+
+@app.command()
+def main(
+ config: Optional[str] = typer.Option(
+ None, help="The path to the configuration file"
+ )
+):
+ loop = asyncio.get_event_loop()
+ loop.run_until_complete(async_main(config))
+ tasks = asyncio.all_tasks(loop)
+ loop.run_until_complete(asyncio.gather(*tasks))
+
+
+if __name__ == "__main__":
+ app()
diff --git a/continuedev/src/continuedev/headless/headless_ide.py b/continuedev/src/continuedev/headless/headless_ide.py
new file mode 100644
index 00000000..088da2c9
--- /dev/null
+++ b/continuedev/src/continuedev/headless/headless_ide.py
@@ -0,0 +1,181 @@
+import os
+import subprocess
+import uuid
+from typing import Any, Callable, Coroutine, List, Optional
+
+from dotenv import load_dotenv
+from fastapi import WebSocket
+
+from ..models.filesystem import (
+ FileSystem,
+ RangeInFile,
+ RangeInFileWithContents,
+ RealFileSystem,
+)
+from ..models.filesystem_edit import EditDiff, FileEdit, FileSystemEdit
+from ..server.ide_protocol import AbstractIdeProtocolServer
+
+load_dotenv()
+
+
+def get_mac_address():
+ mac_num = hex(uuid.getnode()).replace("0x", "").upper()
+ mac = "-".join(mac_num[i : i + 2] for i in range(0, 11, 2))
+ return mac
+
+
+class LocalIdeProtocol(AbstractIdeProtocolServer):
+ websocket: WebSocket = None
+ session_id: Optional[str]
+ workspace_directory: str = os.getcwd()
+ unique_id: str = get_mac_address()
+
+ filesystem: FileSystem = RealFileSystem()
+
+ async def handle_json(self, data: Any):
+ """Handle a json message"""
+ pass
+
+ def showSuggestion(self, file_edit: FileEdit):
+ """Show a suggestion to the user"""
+ pass
+
+ async def setFileOpen(self, filepath: str, open: bool = True):
+ """Set whether a file is open"""
+ pass
+
+ async def showMessage(self, message: str):
+ """Show a message to the user"""
+ print(message)
+
+ async def showVirtualFile(self, name: str, contents: str):
+ """Show a virtual file"""
+ pass
+
+ async def setSuggestionsLocked(self, filepath: str, locked: bool = True):
+ """Set whether suggestions are locked"""
+ pass
+
+ async def getSessionId(self):
+ """Get a new session ID"""
+ pass
+
+ async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool:
+ """Show suggestions to the user and wait for a response"""
+ pass
+
+ def onAcceptRejectSuggestion(self, accepted: bool):
+ """Called when the user accepts or rejects a suggestion"""
+ pass
+
+ def onFileSystemUpdate(self, update: FileSystemEdit):
+ """Called when a file system update is received"""
+ pass
+
+ def onCloseGUI(self, session_id: str):
+ """Called when a GUI is closed"""
+ pass
+
+ def onOpenGUIRequest(self):
+ """Called when a GUI is requested to be opened"""
+ pass
+
+ async def getOpenFiles(self) -> List[str]:
+ """Get a list of open files"""
+ pass
+
+ async def getVisibleFiles(self) -> List[str]:
+ """Get a list of visible files"""
+ pass
+
+ async def getHighlightedCode(self) -> List[RangeInFile]:
+ """Get a list of highlighted code"""
+ pass
+
+ async def readFile(self, filepath: str) -> str:
+ """Read a file"""
+ return self.filesystem.read(filepath)
+
+ async def readRangeInFile(self, range_in_file: RangeInFile) -> str:
+ """Read a range in a file"""
+ return self.filesystem.read_range_in_file(range_in_file)
+
+ async def editFile(self, edit: FileEdit):
+ """Edit a file"""
+ self.filesystem.apply_file_edit(edit)
+
+ async def applyFileSystemEdit(self, edit: FileSystemEdit) -> EditDiff:
+ """Apply a file edit"""
+ return self.filesystem.apply_edit(edit)
+
+ async def saveFile(self, filepath: str):
+ """Save a file"""
+ pass
+
+ async def getUserSecret(self, key: str):
+ """Get a user secret"""
+ return os.environ.get(key)
+
+ async def highlightCode(self, range_in_file: RangeInFile, color: str):
+ """Highlight code"""
+ pass
+
+ async def runCommand(self, command: str) -> str:
+ """Run a command using subprocess (don't pass, actually implement)"""
+ return subprocess.check_output(command, shell=True).decode("utf-8")
+
+ def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]):
+ """Called when highlighted code is updated"""
+ pass
+
+ def onDeleteAtIndex(self, index: int):
+ """Called when a step is deleted at a given index"""
+ pass
+
+ async def showDiff(self, filepath: str, replacement: str, step_index: int):
+ """Show a diff"""
+ pass
+
+ def subscribeToFilesCreated(self, callback: Callable[[List[str]], None]):
+ """Subscribe to files created event"""
+ pass
+
+ def subscribeToFilesDeleted(self, callback: Callable[[List[str]], None]):
+ """Subscribe to files deleted event"""
+ pass
+
+ def subscribeToFilesRenamed(self, callback: Callable[[List[str], List[str]], None]):
+ """Subscribe to files renamed event"""
+ pass
+
+ def subscribeToFileSaved(self, callback: Callable[[str, str], None]):
+ """Subscribe to file saved event"""
+ pass
+
+ def onFilesCreated(self, filepaths: List[str]):
+ """Called when files are created"""
+ pass
+
+ def onFilesDeleted(self, filepaths: List[str]):
+ """Called when files are deleted"""
+ pass
+
+ def onFilesRenamed(self, old_filepaths: List[str], new_filepaths: List[str]):
+ """Called when files are renamed"""
+ pass
+
+ def onFileSaved(self, filepath: str, contents: str):
+ """Called when a file is saved"""
+ pass
+
+ async def fileExists(self, filepath: str) -> Coroutine[Any, Any, str]:
+ """Check if a file exists"""
+ return self.filesystem.exists(filepath)
+
+ async def getTerminalContents(self) -> Coroutine[Any, Any, str]:
+ return ""
+
+ async def listDirectoryContents(
+ self, directory: str, recursive: bool = False
+ ) -> List[str]:
+ return self.filesystem.list_directory_contents(directory, recursive=recursive)
diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py
index 03c9cce4..1b91ec43 100644
--- a/continuedev/src/continuedev/libs/llm/together.py
+++ b/continuedev/src/continuedev/libs/llm/together.py
@@ -4,6 +4,7 @@ from typing import Callable, Optional
import aiohttp
from ..llm import LLM
+from ..util.logging import logger
from .prompts.chat import llama2_template_messages
from .prompts.edit import simplified_edit_prompt
@@ -59,7 +60,13 @@ class TogetherLLM(LLM):
if chunk.strip() != "":
if chunk.startswith("data: "):
chunk = chunk[6:]
- json_chunk = json.loads(chunk)
+ if chunk == "[DONE]":
+ break
+ try:
+ json_chunk = json.loads(chunk)
+ except Exception as e:
+ logger.warning(f"Invalid JSON chunk: {chunk}\n\n{e}")
+ continue
if "choices" in json_chunk:
yield json_chunk["choices"][0]["text"]
diff --git a/continuedev/src/continuedev/libs/lspclient/json_rpc_endpoint.py b/continuedev/src/continuedev/libs/lspclient/json_rpc_endpoint.py
new file mode 100644
index 00000000..80c51000
--- /dev/null
+++ b/continuedev/src/continuedev/libs/lspclient/json_rpc_endpoint.py
@@ -0,0 +1,82 @@
+from __future__ import print_function
+
+import json
+import re
+import threading
+
+JSON_RPC_REQ_FORMAT = "Content-Length: {json_string_len}\r\n\r\n{json_string}"
+JSON_RPC_RES_REGEX = "Content-Length: ([0-9]*)\r\n"
+# TODO: add content-type
+
+
+class MyEncoder(json.JSONEncoder):
+ """
+ Encodes an object in JSON
+ """
+
+ def default(self, o):
+ return o.__dict__
+
+
+class JsonRpcEndpoint(object):
+ """
+ Thread safe JSON RPC endpoint implementation. Responsible to recieve and send JSON RPC messages, as described in the
+ protocol. More information can be found: https://www.jsonrpc.org/
+ """
+
+ def __init__(self, stdin, stdout):
+ self.stdin = stdin
+ self.stdout = stdout
+ self.read_lock = threading.Lock()
+ self.write_lock = threading.Lock()
+
+ @staticmethod
+ def __add_header(json_string):
+ """
+ Adds a header for the given json string
+
+ :param str json_string: The string
+ :return: the string with the header
+ """
+ return JSON_RPC_REQ_FORMAT.format(
+ json_string_len=len(json_string), json_string=json_string
+ )
+
+ def send_request(self, message):
+ """
+ Sends the given message.
+
+ :param dict message: The message to send.
+ """
+ json_string = json.dumps(message, cls=MyEncoder)
+ # print("sending:", json_string)
+ jsonrpc_req = self.__add_header(json_string)
+ with self.write_lock:
+ self.stdin.write(jsonrpc_req.encode())
+ self.stdin.flush()
+
+ def recv_response(self):
+ """
+ Recives a message.
+
+ :return: a message
+ """
+ with self.read_lock:
+ line = self.stdout.readline()
+ if not line:
+ return None
+ # print(line)
+ line = line.decode()
+ # TODO: handle content type as well.
+ match = re.match(JSON_RPC_RES_REGEX, line)
+ if match is None or not match.groups():
+ raise RuntimeError("Bad header: " + line)
+ size = int(match.groups()[0])
+ line = self.stdout.readline()
+ if not line:
+ return None
+ line = line.decode()
+ # if line != "\r\n":
+ # raise RuntimeError("Bad header: missing newline")
+ jsonrpc_res = self.stdout.read(size + 2)
+ return json.loads(jsonrpc_res)
diff --git a/continuedev/src/continuedev/libs/lspclient/lsp_client.py b/continuedev/src/continuedev/libs/lspclient/lsp_client.py
new file mode 100644
index 00000000..fe2db6ad
--- /dev/null
+++ b/continuedev/src/continuedev/libs/lspclient/lsp_client.py
@@ -0,0 +1,150 @@
+from .lsp_structs import Location, SignatureHelp, SymbolInformation
+
+
+class LspClient(object):
+ def __init__(self, lsp_endpoint):
+ """
+ Constructs a new LspClient instance.
+
+ :param lsp_endpoint: TODO
+ """
+ self.lsp_endpoint = lsp_endpoint
+
+ def initialize(
+ self,
+ processId,
+ rootPath,
+ rootUri,
+ initializationOptions,
+ capabilities,
+ trace,
+ workspaceFolders,
+ ):
+ """
+ The initialize request is sent as the first request from the client to the server. If the server receives a request or notification
+ before the initialize request it should act as follows:
+
+ 1. For a request the response should be an error with code: -32002. The message can be picked by the server.
+ 2. Notifications should be dropped, except for the exit notification. This will allow the exit of a server without an initialize request.
+
+ Until the server has responded to the initialize request with an InitializeResult, the client must not send any additional requests or
+ notifications to the server. In addition the server is not allowed to send any requests or notifications to the client until it has responded
+ with an InitializeResult, with the exception that during the initialize request the server is allowed to send the notifications window/showMessage,
+ window/logMessage and telemetry/event as well as the window/showMessageRequest request to the client.
+
+ The initialize request may only be sent once.
+
+ :param int processId: The process Id of the parent process that started the server. Is null if the process has not been started by another process.
+ If the parent process is not alive then the server should exit (see exit notification) its process.
+ :param str rootPath: The rootPath of the workspace. Is null if no folder is open. Deprecated in favour of rootUri.
+ :param DocumentUri rootUri: The rootUri of the workspace. Is null if no folder is open. If both `rootPath` and `rootUri` are set
+ `rootUri` wins.
+ :param any initializationOptions: User provided initialization options.
+ :param ClientCapabilities capabilities: The capabilities provided by the client (editor or tool).
+ :param Trace trace: The initial trace setting. If omitted trace is disabled ('off').
+ :param list workspaceFolders: The workspace folders configured in the client when the server starts. This property is only available if the client supports workspace folders.
+ It can be `null` if the client supports workspace folders but none are configured.
+ """
+ self.lsp_endpoint.start()
+ return self.lsp_endpoint.call_method(
+ "initialize",
+ processId=processId,
+ rootPath=rootPath,
+ rootUri=rootUri,
+ initializationOptions=initializationOptions,
+ capabilities=capabilities,
+ trace=trace,
+ workspaceFolders=workspaceFolders,
+ )
+
+ def initialized(self):
+ """
+ The initialized notification is sent from the client to the server after the client received the result of the initialize request
+ but before the client is sending any other request or notification to the server. The server can use the initialized notification
+ for example to dynamically register capabilities. The initialized notification may only be sent once.
+ """
+ self.lsp_endpoint.send_notification("initialized")
+
+ def shutdown(self):
+ """
+ The initialized notification is sent from the client to the server after the client received the result of the initialize request
+ but before the client is sending any other request or notification to the server. The server can use the initialized notification
+ for example to dynamically register capabilities. The initialized notification may only be sent once.
+ """
+ self.lsp_endpoint.stop()
+ return self.lsp_endpoint.call_method("shutdown")
+
+ def exit(self):
+ """
+ The initialized notification is sent from the client to the server after the client received the result of the initialize request
+ but before the client is sending any other request or notification to the server. The server can use the initialized notification
+ for example to dynamically register capabilities. The initialized notification may only be sent once.
+ """
+ self.lsp_endpoint.send_notification("exit")
+
+ def didOpen(self, textDocument):
+ """
+ The document open notification is sent from the client to the server to signal newly opened text documents. The document's truth is
+ now managed by the client and the server must not try to read the document's truth using the document's uri. Open in this sense
+ means it is managed by the client. It doesn't necessarily mean that its content is presented in an editor. An open notification must
+ not be sent more than once without a corresponding close notification send before. This means open and close notification must be
+ balanced and the max open count for a particular textDocument is one. Note that a server's ability to fulfill requests is independent
+ of whether a text document is open or closed.
+
+ The DidOpenTextDocumentParams contain the language id the document is associated with. If the language Id of a document changes, the
+ client needs to send a textDocument/didClose to the server followed by a textDocument/didOpen with the new language id if the server
+ handles the new language id as well.
+
+ :param TextDocumentItem textDocument: The initial trace setting. If omitted trace is disabled ('off').
+ """
+ return self.lsp_endpoint.send_notification(
+ "textDocument/didOpen", textDocument=textDocument
+ )
+
+ def documentSymbol(self, textDocument):
+ """
+ The document symbol request is sent from the client to the server to return a flat list of all symbols found in a given text document.
+ Neither the symbol's location range nor the symbol's container name should be used to infer a hierarchy.
+
+ :param TextDocumentItem textDocument: The text document.
+ """
+ result_dict = self.lsp_endpoint.call_method(
+ "textDocument/documentSymbol", textDocument=textDocument
+ )
+ return [SymbolInformation(**sym) for sym in result_dict]
+
+ def definition(self, textDocument, position):
+ """
+ The goto definition request is sent from the client to the server to resolve the definition location of a symbol at a given text document position.
+
+ :param TextDocumentItem textDocument: The text document.
+ :param Position position: The position inside the text document..
+ """
+ result_dict = self.lsp_endpoint.call_method(
+ "textDocument/definition", textDocument=textDocument, position=position
+ )
+ return [Location(**l) for l in result_dict]
+
+ def typeDefinition(self, textDocument, position):
+ """
+ The goto type definition request is sent from the client to the server to resolve the type definition location of a symbol at a given text document position.
+
+ :param TextDocumentItem textDocument: The text document.
+ :param Position position: The position inside the text document..
+ """
+ result_dict = self.lsp_endpoint.call_method(
+ "textDocument/definition", textDocument=textDocument, position=position
+ )
+ return [Location(**l) for l in result_dict]
+
+ def signatureHelp(self, textDocument, position):
+ """
+ The signature help request is sent from the client to the server to request signature information at a given cursor position.
+
+ :param TextDocumentItem textDocument: The text document.
+ :param Position position: The position inside the text document..
+ """
+ result_dict = self.lsp_endpoint.call_method(
+ "textDocument/signatureHelp", textDocument=textDocument, position=position
+ )
+ return SignatureHelp(**result_dict)
diff --git a/continuedev/src/continuedev/libs/lspclient/lsp_endpoint.py b/continuedev/src/continuedev/libs/lspclient/lsp_endpoint.py
new file mode 100644
index 00000000..14d2ca07
--- /dev/null
+++ b/continuedev/src/continuedev/libs/lspclient/lsp_endpoint.py
@@ -0,0 +1,71 @@
+from __future__ import print_function
+
+import threading
+
+
+class LspEndpoint(threading.Thread):
+ def __init__(self, json_rpc_endpoint, default_callback=print, callbacks={}):
+ threading.Thread.__init__(self)
+ self.json_rpc_endpoint = json_rpc_endpoint
+ self.callbacks = callbacks
+ self.default_callback = default_callback
+ self.event_dict = {}
+ self.response_dict = {}
+ self.next_id = 0
+ # self.daemon = True
+ self.shutdown_flag = False
+
+ def handle_result(self, jsonrpc_res):
+ self.response_dict[jsonrpc_res["id"]] = jsonrpc_res
+ cond = self.event_dict[jsonrpc_res["id"]]
+ cond.acquire()
+ cond.notify()
+ cond.release()
+
+ def stop(self):
+ self.shutdown_flag = True
+
+ def run(self):
+ while not self.shutdown_flag:
+ jsonrpc_message = self.json_rpc_endpoint.recv_response()
+
+ if jsonrpc_message is None:
+ print("server quit")
+ break
+
+ # print("recieved message:", jsonrpc_message)
+ if "result" in jsonrpc_message or "error" in jsonrpc_message:
+ self.handle_result(jsonrpc_message)
+ elif "method" in jsonrpc_message:
+ if jsonrpc_message["method"] in self.callbacks:
+ self.callbacks[jsonrpc_message["method"]](jsonrpc_message)
+ else:
+ self.default_callback(jsonrpc_message)
+ else:
+ print("unknown jsonrpc message")
+ # print(jsonrpc_message)
+
+ def send_message(self, method_name, params, id=None):
+ message_dict = {}
+ message_dict["jsonrpc"] = "2.0"
+ if id is not None:
+ message_dict["id"] = id
+ message_dict["method"] = method_name
+ message_dict["params"] = params
+ self.json_rpc_endpoint.send_request(message_dict)
+
+ def call_method(self, method_name, **kwargs):
+ current_id = self.next_id
+ self.next_id += 1
+ cond = threading.Condition()
+ self.event_dict[current_id] = cond
+ cond.acquire()
+ self.send_message(method_name, kwargs, current_id)
+ cond.wait()
+ cond.release()
+ # TODO: check if error, and throw an exception
+ response = self.response_dict[current_id]
+ return response["result"]
+
+ def send_notification(self, method_name, **kwargs):
+ self.send_message(method_name, kwargs)
diff --git a/continuedev/src/continuedev/libs/lspclient/lsp_structs.py b/continuedev/src/continuedev/libs/lspclient/lsp_structs.py
new file mode 100644
index 00000000..2f0940d4
--- /dev/null
+++ b/continuedev/src/continuedev/libs/lspclient/lsp_structs.py
@@ -0,0 +1,316 @@
+def to_type(o, new_type):
+ '''
+ Helper funciton that receives an object or a dict and convert it to a new given type.
+
+ :param object|dict o: The object to convert
+ :param Type new_type: The type to convert to.
+ '''
+ if new_type == type(o):
+ return o
+ else:
+ return new_type(**o)
+
+
+class Position(object):
+ def __init__(self, line, character):
+ """
+ Constructs a new Position instance.
+
+ :param int line: Line position in a document (zero-based).
+ :param int character: Character offset on a line in a document (zero-based).
+ """
+ self.line = line
+ self.character = character
+
+
+class Range(object):
+ def __init__(self, start, end):
+ """
+ Constructs a new Range instance.
+
+ :param Position start: The range's start position.
+ :param Position end: The range's end position.
+ """
+ self.start = to_type(start, Position)
+ self.end = to_type(end, Position)
+
+
+class Location(object):
+ """
+ Represents a location inside a resource, such as a line inside a text file.
+ """
+ def __init__(self, uri, range):
+ """
+ Constructs a new Range instance.
+
+ :param str uri: Resource file.
+ :param Range range: The range inside the file
+ """
+ self.uri = uri
+ self.range = to_type(range, Range)
+
+
+class Diagnostic(object):
+ def __init__(self, range, severity, code, source, message, relatedInformation):
+ """
+ Constructs a new Diagnostic instance.
+ :param Range range: The range at which the message applies.Resource file.
+ :param int severity: The diagnostic's severity. Can be omitted. If omitted it is up to the
+ client to interpret diagnostics as error, warning, info or hint.
+ :param str code: The diagnostic's code, which might appear in the user interface.
+ :param str source: A human-readable string describing the source of this
+ diagnostic, e.g. 'typescript' or 'super lint'.
+ :param str message: The diagnostic's message.
+ :param list relatedInformation: An array of related diagnostic information, e.g. when symbol-names within
+ a scope collide all definitions can be marked via this property.
+ """
+ self.range = range
+ self.severity = severity
+ self.code = code
+ self.source = source
+ self.message = message
+ self.relatedInformation = relatedInformation
+
+
+class DiagnosticSeverity(object):
+ Error = 1
+ Warning = 2 # TODO: warning is known in python
+ Information = 3
+ Hint = 4
+
+
+class DiagnosticRelatedInformation(object):
+ def __init__(self, location, message):
+ """
+ Constructs a new Diagnostic instance.
+ :param Location location: The location of this related diagnostic information.
+ :param str message: The message of this related diagnostic information.
+ """
+ self.location = location
+ self.message = message
+
+
+class Command(object):
+ def __init__(self, title, command, arguments):
+ """
+ Constructs a new Diagnostic instance.
+ :param str title: Title of the command, like `save`.
+ :param str command: The identifier of the actual command handler.
+ :param list argusments: Arguments that the command handler should be invoked with.
+ """
+ self.title = title
+ self.command = command
+ self.arguments = arguments
+
+
+class TextDocumentItem(object):
+ """
+ An item to transfer a text document from the client to the server.
+ """
+ def __init__(self, uri, languageId, version, text):
+ """
+ Constructs a new Diagnostic instance.
+
+ :param DocumentUri uri: Title of the command, like `save`.
+ :param str languageId: The identifier of the actual command handler.
+ :param int version: Arguments that the command handler should be invoked with.
+ :param str text: Arguments that the command handler should be invoked with.
+ """
+ self.uri = uri
+ self.languageId = languageId
+ self.version = version
+ self.text = text
+
+
+class TextDocumentIdentifier(object):
+ """
+ Text documents are identified using a URI. On the protocol level, URIs are passed as strings.
+ """
+ def __init__(self, uri):
+ """
+ Constructs a new TextDocumentIdentifier instance.
+
+ :param DocumentUri uri: The text document's URI.
+ """
+ self.uri = uri
+
+class TextDocumentPositionParams(object):
+ """
+ A parameter literal used in requests to pass a text document and a position inside that document.
+ """
+ def __init__(self, textDocument, position):
+ """
+ Constructs a new TextDocumentPositionParams instance.
+
+ :param TextDocumentIdentifier textDocument: The text document.
+ :param Position position: The position inside the text document.
+ """
+ self.textDocument = textDocument
+ self.position = position
+
+
+class LANGUAGE_IDENTIFIER:
+ BAT="bat"
+ BIBTEX="bibtex"
+ CLOJURE="clojure"
+ COFFESCRIPT="coffeescript"
+ C="c"
+ CPP="cpp"
+ CSHARP="csharp"
+ CSS="css"
+ DIFF="diff"
+ DOCKERFILE="dockerfile"
+ FSHARP="fsharp"
+ GIT_COMMIT="git-commit"
+ GIT_REBASE="git-rebase"
+ GO="go"
+ GROOVY="groovy"
+ HANDLEBARS="handlebars"
+ HTML="html"
+ INI="ini"
+ JAVA="java"
+ JAVASCRIPT="javascript"
+ JSON="json"
+ LATEX="latex"
+ LESS="less"
+ LUA="lua"
+ MAKEFILE="makefile"
+ MARKDOWN="markdown"
+ OBJECTIVE_C="objective-c"
+ OBJECTIVE_CPP="objective-cpp"
+ Perl="perl"
+ PHP="php"
+ POWERSHELL="powershell"
+ PUG="jade"
+ PYTHON="python"
+ R="r"
+ RAZOR="razor"
+ RUBY="ruby"
+ RUST="rust"
+ SASS="sass"
+ SCSS="scss"
+ ShaderLab="shaderlab"
+ SHELL_SCRIPT="shellscript"
+ SQL="sql"
+ SWIFT="swift"
+ TYPE_SCRIPT="typescript"
+ TEX="tex"
+ VB="vb"
+ XML="xml"
+ XSL="xsl"
+ YAML="yaml"
+
+
+class SymbolKind(object):
+ File = 1
+ Module = 2
+ Namespace = 3
+ Package = 4
+ Class = 5
+ Method = 6
+ Property = 7
+ Field = 8
+ Constructor = 9
+ Enum = 10
+ Interface = 11
+ Function = 12
+ Variable = 13
+ Constant = 14
+ String = 15
+ Number = 16
+ Boolean = 17
+ Array = 18
+ Object = 19
+ Key = 20
+ Null = 21
+ EnumMember = 22
+ Struct = 23
+ Event = 24
+ Operator = 25
+ TypeParameter = 26
+
+
+class SymbolInformation(object):
+ """
+ Represents information about programming constructs like variables, classes, interfaces etc.
+ """
+ def __init__(self, name, kind, location, containerName, deprecated=False):
+ """
+ Constructs a new SymbolInformation instance.
+
+ :param str name: The name of this symbol.
+ :param int kind: The kind of this symbol.
+ :param bool Location: The location of this symbol. The location's range is used by a tool
+ to reveal the location in the editor. If the symbol is selected in the
+ tool the range's start information is used to position the cursor. So
+ the range usually spans more then the actual symbol's name and does
+ normally include things like visibility modifiers.
+
+ The range doesn't have to denote a node range in the sense of a abstract
+ syntax tree. It can therefore not be used to re-construct a hierarchy of
+ the symbols.
+ :param str containerName: The name of the symbol containing this symbol. This information is for
+ user interface purposes (e.g. to render a qualifier in the user interface
+ if necessary). It can't be used to re-infer a hierarchy for the document
+ symbols.
+ :param bool deprecated: Indicates if this symbol is deprecated.
+ """
+ self.name = name
+ self.kind = kind
+ self.deprecated = deprecated
+ self.location = to_type(location, Location)
+ self.containerName = containerName
+
+
+class ParameterInformation(object):
+ """
+ Represents a parameter of a callable-signature. A parameter can
+ have a label and a doc-comment.
+ """
+ def __init__(self, label, documentation=""):
+ """
+ Constructs a new ParameterInformation instance.
+
+ :param str label: The label of this parameter. Will be shown in the UI.
+ :param str documentation: The human-readable doc-comment of this parameter. Will be shown in the UI but can be omitted.
+ """
+ self.label = label
+ self.documentation = documentation
+
+
+class SignatureInformation(object):
+ """
+ Represents the signature of something callable. A signature
+ can have a label, like a function-name, a doc-comment, and
+ a set of parameters.
+ """
+ def __init__(self, label, documentation="", parameters=[]):
+ """
+ Constructs a new SignatureInformation instance.
+
+ :param str label: The label of this signature. Will be shown in the UI.
+ :param str documentation: The human-readable doc-comment of this signature. Will be shown in the UI but can be omitted.
+ :param ParameterInformation[] parameters: The parameters of this signature.
+ """
+ self.label = label
+ self.documentation = documentation
+ self.parameters = [to_type(parameter, ParameterInformation) for parameter in parameters]
+
+
+class SignatureHelp(object):
+ """
+ Signature help represents the signature of something
+ callable. There can be multiple signature but only one
+ active and only one active parameter.
+ """
+ def __init__(self, signatures, activeSignature=0, activeParameter=0):
+ """
+ Constructs a new SignatureHelp instance.
+
+ :param SignatureInformation[] signatures: One or more signatures.
+ :param int activeSignature:
+ :param int activeParameter:
+ """
+ self.signatures = [to_type(signature, SignatureInformation) for signature in signatures]
+ self.activeSignature = activeSignature
+ self.activeParameter = activeParameter \ No newline at end of file
diff --git a/continuedev/src/continuedev/libs/util/filter_files.py b/continuedev/src/continuedev/libs/util/filter_files.py
new file mode 100644
index 00000000..6ebaa274
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/filter_files.py
@@ -0,0 +1,33 @@
+import fnmatch
+from typing import List
+
+DEFAULT_IGNORE_DIRS = [
+ ".git",
+ ".vscode",
+ ".idea",
+ ".vs",
+ ".venv",
+ "env",
+ ".env",
+ "node_modules",
+ "dist",
+ "build",
+ "target",
+ "out",
+ "bin",
+ ".pytest_cache",
+ ".vscode-test",
+ ".continue",
+ "__pycache__",
+]
+
+DEFAULT_IGNORE_PATTERNS = DEFAULT_IGNORE_DIRS + list(
+ filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS)
+)
+
+
+def should_filter_path(
+ path: str, ignore_patterns: List[str] = DEFAULT_IGNORE_PATTERNS
+) -> bool:
+ """Returns whether a file should be filtered"""
+ return any(fnmatch.fnmatch(path, pattern) for pattern in ignore_patterns)
diff --git a/continuedev/src/continuedev/libs/util/traceback/traceback_parsers.py b/continuedev/src/continuedev/libs/util/traceback/traceback_parsers.py
new file mode 100644
index 00000000..58a4f728
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/traceback/traceback_parsers.py
@@ -0,0 +1,56 @@
+from boltons import tbutils
+
+from ....models.main import Traceback
+
+PYTHON_TRACEBACK_PREFIX = "Traceback (most recent call last):"
+
+
+def get_python_traceback(output: str) -> str:
+ if PYTHON_TRACEBACK_PREFIX in output:
+ tb_string = output.split(PYTHON_TRACEBACK_PREFIX)[-1]
+
+ # Then need to remove any lines below the traceback. Do this by noticing that
+ # the last line of the traceback is the first (other than they prefix) that doesn't begin with whitespace
+ lines = list(filter(lambda x: x.strip() != "", tb_string.splitlines()))
+ for i in range(len(lines) - 1):
+ if not lines[i].startswith(" "):
+ tb_string = "\n".join(lines[: i + 1])
+ break
+
+ return PYTHON_TRACEBACK_PREFIX + "\n" + tb_string
+ elif "SyntaxError" in output:
+ return "SyntaxError" + output.split("SyntaxError")[-1]
+ else:
+ return None
+
+
+def get_javascript_traceback(output: str) -> str:
+ lines = output.splitlines()
+ first_line = None
+ for i in range(len(lines) - 1):
+ segs = lines[i].split(":")
+ if (
+ len(segs) > 1
+ and segs[0] != ""
+ and segs[1].startswith(" ")
+ and lines[i + 1].strip().startswith("at")
+ ):
+ first_line = lines[i]
+ break
+
+ if first_line is not None:
+ return "\n".join(lines[lines.index(first_line) :])
+ else:
+ return None
+
+
+def parse_python_traceback(tb_string: str) -> Traceback:
+ # Remove anchor lines - tbutils doesn't always get them right
+ tb_string = "\n".join(
+ filter(
+ lambda x: x.strip().replace("~", "").replace("^", "") != "",
+ tb_string.splitlines(),
+ )
+ )
+ exc = tbutils.ParsedException.from_string(tb_string)
+ return Traceback.from_tbutil_parsed_exc(exc)
diff --git a/continuedev/src/continuedev/libs/util/traceback_parsers.py b/continuedev/src/continuedev/libs/util/traceback_parsers.py
deleted file mode 100644
index 2b164c0f..00000000
--- a/continuedev/src/continuedev/libs/util/traceback_parsers.py
+++ /dev/null
@@ -1,30 +0,0 @@
-PYTHON_TRACEBACK_PREFIX = "Traceback (most recent call last):"
-
-
-def get_python_traceback(output: str) -> str:
- if PYTHON_TRACEBACK_PREFIX in output:
- return PYTHON_TRACEBACK_PREFIX + output.split(PYTHON_TRACEBACK_PREFIX)[-1]
- elif "SyntaxError" in output:
- return "SyntaxError" + output.split("SyntaxError")[-1]
- else:
- return None
-
-
-def get_javascript_traceback(output: str) -> str:
- lines = output.splitlines()
- first_line = None
- for i in range(len(lines) - 1):
- segs = lines[i].split(":")
- if (
- len(segs) > 1
- and segs[0] != ""
- and segs[1].startswith(" ")
- and lines[i + 1].strip().startswith("at")
- ):
- first_line = lines[i]
- break
-
- if first_line is not None:
- return "\n".join(lines[lines.index(first_line) :])
- else:
- return None
diff --git a/continuedev/src/continuedev/models/filesystem.py b/continuedev/src/continuedev/models/filesystem.py
index de426282..514337bf 100644
--- a/continuedev/src/continuedev/models/filesystem.py
+++ b/continuedev/src/continuedev/models/filesystem.py
@@ -136,6 +136,11 @@ class FileSystem(AbstractModel):
"""Apply edit to filesystem, calculate the reverse edit, and return and EditDiff"""
raise NotImplementedError
+ @abstractmethod
+ def list_directory_contents(self, path: str, recursive: bool = False) -> List[str]:
+ """List the contents of a directory"""
+ raise NotImplementedError
+
@classmethod
def read_range_in_str(self, s: str, r: Range) -> str:
lines = s.split("\n")[r.start.line : r.end.line + 1]
@@ -312,6 +317,18 @@ class RealFileSystem(FileSystem):
self.write(edit.filepath, new_content)
return diff
+ def list_directory_contents(self, path: str, recursive: bool = False) -> List[str]:
+ """List the contents of a directory"""
+ if recursive:
+ # Walk
+ paths = []
+ for root, dirs, files in os.walk(path):
+ for f in files:
+ paths.append(os.path.join(root, f))
+
+ return paths
+ return list(map(lambda x: os.path.join(path, x), os.listdir(path)))
+
class VirtualFileSystem(FileSystem):
"""A simulated filesystem from a mapping of filepath to file contents."""
@@ -363,5 +380,16 @@ class VirtualFileSystem(FileSystem):
self.write(edit.filepath, new_content)
return EditDiff(edit=edit, original=original)
+ def list_directory_contents(self, path: str, recursive: bool = False) -> List[str]:
+ """List the contents of a directory"""
+ if recursive:
+ for filepath in self.files:
+ if filepath.startswith(path):
+ yield filepath
+
+ for filepath in self.files:
+ if filepath.startswith(path) and "/" not in filepath[len(path) :]:
+ yield filepath
+
# TODO: Uniform errors thrown by any FileSystem subclass.
diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py
index d442a415..a8b4aab0 100644
--- a/continuedev/src/continuedev/models/main.py
+++ b/continuedev/src/continuedev/models/main.py
@@ -50,6 +50,11 @@ class Position(BaseModel):
return sum(map(len, lines[: self.line])) + self.character
+class PositionInFile(BaseModel):
+ position: Position
+ filepath: str
+
+
class Range(BaseModel):
"""A range in a file. 0-indexed."""
diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py
index 9ac28aa2..06d22348 100644
--- a/continuedev/src/continuedev/plugins/context_providers/file.py
+++ b/continuedev/src/continuedev/plugins/context_providers/file.py
@@ -5,6 +5,7 @@ from typing import List
from ...core.context import ContextProvider
from ...core.main import ContextItem, ContextItemDescription, ContextItemId
from ...core.sdk import ContinueSDK
+from ...libs.util.filter_files import DEFAULT_IGNORE_PATTERNS
from .util import remove_meilisearch_disallowed_chars
MAX_SIZE_IN_CHARS = 25_000
@@ -18,36 +19,13 @@ async def get_file_contents(filepath: str, sdk: ContinueSDK) -> str:
return None
-DEFAULT_IGNORE_DIRS = [
- ".git",
- ".vscode",
- ".idea",
- ".vs",
- ".venv",
- "env",
- ".env",
- "node_modules",
- "dist",
- "build",
- "target",
- "out",
- "bin",
- ".pytest_cache",
- ".vscode-test",
- ".continue",
- "__pycache__",
-]
-
-
class FileContextProvider(ContextProvider):
"""
The FileContextProvider is a ContextProvider that allows you to search files in the open workspace.
"""
title = "file"
- ignore_patterns: List[str] = DEFAULT_IGNORE_DIRS + list(
- filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS)
- )
+ ignore_patterns: List[str] = DEFAULT_IGNORE_PATTERNS
async def start(self, *args):
await super().start(*args)
diff --git a/continuedev/src/continuedev/plugins/policies/commit.py b/continuedev/src/continuedev/plugins/policies/commit.py
new file mode 100644
index 00000000..2fa43676
--- /dev/null
+++ b/continuedev/src/continuedev/plugins/policies/commit.py
@@ -0,0 +1,77 @@
+# An agent that makes a full commit in the background
+# Plans
+# Write code
+# Reviews code
+# Cleans up
+
+# It's important that agents are configurable, because people need to be able to specify
+# which hooks they want to run. Specific linter, run tests, etc.
+# And all of this can be easily specified in the Policy.
+
+
+from textwrap import dedent
+from typing import Literal
+
+from ...core.config import ContinueConfig
+from ...core.main import History, Policy, Step
+from ...core.observation import TextObservation
+from ...core.sdk import ContinueSDK
+
+
+class PlanStep(Step):
+ user_input: str
+
+ _prompt = dedent(
+ """\
+ You were given the following instructions: "{user_input}".
+
+ Create a plan for how you will complete the task.
+
+ Here are relevant files:
+
+ {relevant_files}
+
+ Your plan will include:
+ 1. A high-level description of how you are going to accomplish the task
+ 2. A list of which files you will edit
+ 3. A description of what you will change in each file
+ """
+ )
+
+ async def run(self, sdk: ContinueSDK):
+ plan = await sdk.models.default.complete(
+ self._prompt.format(
+ {"user_input": self.user_input, "relevant_files": "TODO"}
+ )
+ )
+ return TextObservation(text=plan)
+
+
+class WriteCommitStep(Step):
+ async def run(self, sdk: ContinueSDK):
+ pass
+
+
+class ReviewCodeStep(Step):
+ async def run(self, sdk: ContinueSDK):
+ pass
+
+
+class CleanupStep(Step):
+ async def run(self, sdk: ContinueSDK):
+ pass
+
+
+class CommitPolicy(Policy):
+ user_input: str
+
+ current_step: Literal["plan", "write", "review", "cleanup"] = "plan"
+
+ def next(self, config: ContinueConfig, history: History) -> Step:
+ if history.get_current() is None:
+ return (
+ PlanStep(user_input=self.user_input)
+ >> WriteCommitStep()
+ >> ReviewCodeStep()
+ >> CleanupStep()
+ )
diff --git a/continuedev/src/continuedev/plugins/policies/headless.py b/continuedev/src/continuedev/plugins/policies/headless.py
new file mode 100644
index 00000000..56ebe31f
--- /dev/null
+++ b/continuedev/src/continuedev/plugins/policies/headless.py
@@ -0,0 +1,18 @@
+from ...core.config import ContinueConfig
+from ...core.main import History, Policy, Step
+from ...core.observation import TextObservation
+from ...plugins.steps.core.core import ShellCommandsStep
+from ...plugins.steps.on_traceback import DefaultOnTracebackStep
+
+
+class HeadlessPolicy(Policy):
+ command: str
+
+ def next(self, config: ContinueConfig, history: History) -> Step:
+ if history.get_current() is None:
+ return ShellCommandsStep(cmds=[self.command])
+ observation = history.get_current().observation
+ if isinstance(observation, TextObservation):
+ return DefaultOnTracebackStep(output=observation.text)
+
+ return None
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index 5325a918..740c75bc 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -1,5 +1,6 @@
# These steps are depended upon by ContinueSDK
import difflib
+import subprocess
import traceback
from textwrap import dedent
from typing import Any, Coroutine, List, Optional, Union
@@ -112,52 +113,22 @@ class ShellCommandsStep(Step):
)
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
- await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd
-
- for cmd in self.cmds:
- output = await sdk.ide.runCommand(cmd)
- if (
- self.handle_error
- and output is not None
- and output_contains_error(output)
- ):
- suggestion = await sdk.models.medium.complete(
- dedent(
- f"""\
- While running the command `{cmd}`, the following error occurred:
-
- ```ascii
- {output}
- ```
-
- This is a brief summary of the error followed by a suggestion on how it can be fixed:"""
- ),
- with_history=await sdk.get_chat_context(),
- )
-
- sdk.raise_exception(
- title="Error while running query",
- message=output,
- with_step=MessageStep(
- name=f"Suggestion to solve error {AI_ASSISTED_STRING}",
- message=f"{suggestion}\n\nYou can click the retry button on the failed step to try again.",
- ),
- )
-
- return TextObservation(text=output)
-
- # process = subprocess.Popen(
- # '/bin/bash', stdin=subprocess.PIPE, stdout=subprocess.PIPE, cwd=cwd)
+ process = subprocess.Popen(
+ "/bin/bash",
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ cwd=self.cwd or sdk.ide.workspace_directory,
+ )
- # stdin_input = "\n".join(self.cmds)
- # out, err = process.communicate(stdin_input.encode())
+ stdin_input = "\n".join(self.cmds)
+ out, err = process.communicate(stdin_input.encode())
- # # If it fails, return the error
- # if err is not None and err != "":
- # self._err_text = err
- # return TextObservation(text=err)
+ # If it fails, return the error
+ if err is not None and err != "":
+ self._err_text = err
+ return TextObservation(text=err)
- # return None
+ return None
class DefaultModelEditCodeStep(Step):
diff --git a/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py b/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py
index 1d135b3e..7ceefe9b 100644
--- a/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py
+++ b/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py
@@ -8,6 +8,10 @@ class ImplementAbstractMethodStep(Step):
class_name: str
async def run(self, sdk: ContinueSDK):
+ if sdk.lsp is None:
+ self.description = "Language Server Protocol is not enabled"
+ return
+
implementations = await sdk.lsp.go_to_implementations(self.class_name)
for implementation in implementations:
diff --git a/continuedev/src/continuedev/plugins/steps/on_traceback.py b/continuedev/src/continuedev/plugins/steps/on_traceback.py
index 63bae805..6b75d726 100644
--- a/continuedev/src/continuedev/plugins/steps/on_traceback.py
+++ b/continuedev/src/continuedev/plugins/steps/on_traceback.py
@@ -1,15 +1,28 @@
import os
+from textwrap import dedent
+from typing import Dict, List, Optional, Tuple
from ...core.main import ChatMessage, Step
from ...core.sdk import ContinueSDK
-from ...libs.util.traceback_parsers import (
+from ...libs.util.filter_files import should_filter_path
+from ...libs.util.traceback.traceback_parsers import (
get_javascript_traceback,
get_python_traceback,
+ parse_python_traceback,
)
+from ...models.filesystem import RangeInFile
+from ...models.main import Range, Traceback, TracebackFrame
from .chat import SimpleChatStep
from .core.core import UserInputStep
+def extract_traceback_str(output: str) -> str:
+ tb = output.strip()
+ for tb_parser in [get_python_traceback, get_javascript_traceback]:
+ if parsed_tb := tb_parser(tb):
+ return parsed_tb
+
+
class DefaultOnTracebackStep(Step):
output: str
name: str = "Help With Traceback"
@@ -38,11 +51,11 @@ class DefaultOnTracebackStep(Step):
# And this function is where you can get arbitrarily fancy about adding context
async def run(self, sdk: ContinueSDK):
- tb = self.output.strip()
- for tb_parser in [get_python_traceback, get_javascript_traceback]:
- if parsed_tb := tb_parser(tb):
- tb = parsed_tb
- break
+ if get_python_traceback(self.output) is not None:
+ await sdk.run_step(SolvePythonTracebackStep(output=self.output))
+ return
+
+ tb = extract_traceback_str(self.output)
tb_first_last_lines = (
("\n".join(tb.split("\n")[:3]) + "\n...\n" + "\n".join(tb.split("\n")[-3:]))
@@ -57,3 +70,139 @@ class DefaultOnTracebackStep(Step):
)
)
await sdk.run_step(SimpleChatStep(name="Help With Traceback"))
+
+
+def filter_frames(frames: List[TracebackFrame]) -> List[TracebackFrame]:
+ """Filter out frames that are not relevant to the user's code."""
+ return list(filter(lambda x: should_filter_path(x.filepath), frames))
+
+
+def find_external_call(
+ frames: List[TracebackFrame],
+) -> Optional[Tuple[TracebackFrame, TracebackFrame]]:
+ """Moving up from the bottom of the stack, if the frames are not user code, then find the last frame before it becomes user code."""
+ if not should_filter_path(frames[-1].filepath):
+ # No external call, error comes directly from user code
+ return None
+
+ for i in range(len(frames) - 2, -1, -1):
+ if not should_filter_path(frames[i].filepath):
+ return frames[i], frames[i + 1]
+
+
+def get_func_source_for_frame(frame: Dict) -> str:
+ """Get the source for the function called in the frame."""
+ pass
+
+
+async def fetch_docs_for_external_call(external_call: Dict, next_frame: Dict) -> str:
+ """Fetch docs for the external call."""
+ pass
+
+
+class SolvePythonTracebackStep(Step):
+ output: str
+ name: str = "Solve Traceback"
+ hide: bool = True
+
+ async def external_call_prompt(
+ self, sdk: ContinueSDK, external_call: Tuple[Dict, Dict], tb_string: str
+ ) -> str:
+ external_call, next_frame = external_call
+ source_line = external_call["source_line"]
+ external_func_source = get_func_source_for_frame(next_frame)
+ docs = await fetch_docs_for_external_call(external_call, next_frame)
+
+ prompt = dedent(
+ f"""\
+ I got the following error:
+
+ {tb_string}
+
+ I tried to call an external library like this:
+
+ ```python
+ {source_line}
+ ```
+
+ This is the definition of the function I tried to call:
+
+ ```python
+ {external_func_source}
+ ```
+
+ Here's the documentation for the external library I tried to call:
+
+ {docs}
+
+ Explain how to fix the error.
+ """
+ )
+
+ return prompt
+
+ async def normal_traceback_prompt(
+ self, sdk: ContinueSDK, tb: Traceback, tb_string: str
+ ) -> str:
+ function_bodies = await get_functions_from_traceback(tb, sdk)
+
+ prompt = (
+ "Here are the functions from the traceback (most recent call last):\n\n"
+ )
+ for i, function_body in enumerate(function_bodies):
+ prompt += f'File "{tb.frames[i].filepath}", line {tb.frames[i].lineno}, in {tb.frames[i].function}\n\n```python\n{function_body or tb.frames[i].code}\n```\n\n'
+
+ prompt += (
+ "Here is the traceback:\n\n```\n"
+ + tb_string
+ + "\n```\n\nExplain how to fix the error."
+ )
+
+ return prompt
+
+ async def run(self, sdk: ContinueSDK):
+ tb_string = get_python_traceback(self.output)
+ tb = parse_python_traceback(tb_string)
+
+ if external_call := find_external_call(tb.frames):
+ prompt = await self.external_call_prompt(sdk, external_call, tb_string)
+ else:
+ prompt = await self.normal_traceback_prompt(sdk, tb, tb_string)
+
+ await sdk.run_step(
+ UserInputStep(
+ description="Solving stack trace",
+ user_input=prompt,
+ )
+ )
+ await sdk.run_step(SimpleChatStep(name="Help With Traceback"))
+
+
+async def get_function_body(frame: TracebackFrame, sdk: ContinueSDK) -> Optional[str]:
+ """Get the function body from the traceback frame."""
+ if sdk.lsp is None:
+ return None
+
+ document_symbols = sdk.lsp.get_symbols(frame.filepath)
+ for symbol in document_symbols:
+ if symbol.name == frame.function:
+ r = symbol.location.range
+ return await sdk.ide.readRangeInFile(
+ RangeInFile(
+ filepath=frame.filepath,
+ range=Range.from_shorthand(
+ r.start.line, r.start.character, r.end.line, r.end.character
+ ),
+ )
+ )
+ return None
+
+
+async def get_functions_from_traceback(tb: Traceback, sdk: ContinueSDK) -> List[str]:
+ """Get the function bodies from the traceback."""
+ function_bodies = []
+ for frame in tb.frames:
+ if frame.function:
+ function_bodies.append(await get_function_body(frame, sdk))
+
+ return function_bodies
diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py
index 11099494..6aae8cc5 100644
--- a/continuedev/src/continuedev/server/meilisearch_server.py
+++ b/continuedev/src/continuedev/server/meilisearch_server.py
@@ -117,10 +117,15 @@ async def poll_meilisearch_running(frequency: int = 0.1) -> bool:
await asyncio.sleep(frequency)
+meilisearch_process = None
+
+
async def start_meilisearch():
"""
Starts the MeiliSearch server, wait for it.
"""
+ global meilisearch_process
+
serverPath = getServerFolderPath()
# Check if MeiliSearch is installed, if not download
@@ -130,7 +135,7 @@ async def start_meilisearch():
if not await check_meilisearch_running() or not was_already_installed:
logger.debug("Starting MeiliSearch...")
binary_name = "meilisearch" if os.name == "nt" else "./meilisearch"
- subprocess.Popen(
+ meilisearch_process = subprocess.Popen(
[binary_name, "--no-analytics"],
cwd=serverPath,
stdout=subprocess.DEVNULL,
@@ -139,3 +144,14 @@ async def start_meilisearch():
start_new_session=True,
shell=True,
)
+
+
+def stop_meilisearch():
+ """
+ Stops the MeiliSearch server.
+ """
+ global meilisearch_process
+ if meilisearch_process is not None:
+ meilisearch_process.terminate()
+ meilisearch_process.wait()
+ meilisearch_process = None
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index f0c33929..f0080104 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -8,6 +8,7 @@ from fastapi import APIRouter, WebSocket
from fastapi.websockets import WebSocketState
from ..core.autopilot import Autopilot
+from ..core.config import ContinueConfig
from ..core.main import FullState
from ..libs.util.create_async_task import create_async_task
from ..libs.util.logging import logger
@@ -57,7 +58,10 @@ class SessionManager:
return self.sessions[session_id]
async def new_session(
- self, ide: AbstractIdeProtocolServer, session_id: Optional[str] = None
+ self,
+ ide: AbstractIdeProtocolServer,
+ session_id: Optional[str] = None,
+ config: Optional[ContinueConfig] = None,
) -> Session:
logger.debug(f"New session: {session_id}")
@@ -85,7 +89,7 @@ class SessionManager:
# Start the autopilot (must be after session is added to sessions) and the policy
try:
- await autopilot.start(full_state=full_state)
+ await autopilot.start(full_state=full_state, config=config)
except Exception as e:
await ide.on_error(e)
diff --git a/continuedev/src/continuedev/tests/llm_test.py b/continuedev/src/continuedev/tests/llm_test.py
index 91ddd33f..f4aea1fb 100644
--- a/continuedev/src/continuedev/tests/llm_test.py
+++ b/continuedev/src/continuedev/tests/llm_test.py
@@ -13,10 +13,11 @@ from continuedev.libs.llm.openai import OpenAI
from continuedev.libs.llm.together import TogetherLLM
from continuedev.libs.util.count_tokens import DEFAULT_ARGS
from continuedev.tests.util.openai_mock import start_openai
+from continuedev.tests.util.prompts import tokyo_test_pair
load_dotenv()
-TEST_PROMPT = "Output a single word, that being the capital of Japan:"
+
SPEND_MONEY = True
@@ -65,9 +66,9 @@ class TestBaseLLM:
if self.llm.__class__.__name__ == "LLM":
pytest.skip("Skipping abstract LLM")
- resp = await self.llm.complete(TEST_PROMPT, temperature=0.0)
+ resp = await self.llm.complete(tokyo_test_pair[0], temperature=0.0)
assert isinstance(resp, str)
- assert resp.strip().lower() == "tokyo"
+ assert resp.strip().lower() == tokyo_test_pair[1]
@pytest.mark.skipif(SPEND_MONEY is False, reason="Costs money")
@async_test
@@ -79,7 +80,9 @@ class TestBaseLLM:
role = None
async for chunk in self.llm.stream_chat(
messages=[
- ChatMessage(role="user", content=TEST_PROMPT, summary=TEST_PROMPT)
+ ChatMessage(
+ role="user", content=tokyo_test_pair[0], summary=tokyo_test_pair[0]
+ )
],
temperature=0.0,
):
@@ -90,7 +93,7 @@ class TestBaseLLM:
role = chunk["role"]
assert role == "assistant"
- assert completion.strip().lower() == "tokyo"
+ assert completion.strip().lower() == tokyo_test_pair[1]
@pytest.mark.skipif(SPEND_MONEY is False, reason="Costs money")
@async_test
@@ -99,11 +102,13 @@ class TestBaseLLM:
pytest.skip("Skipping abstract LLM")
completion = ""
- async for chunk in self.llm.stream_complete(TEST_PROMPT, temperature=0.0):
+ async for chunk in self.llm.stream_complete(
+ tokyo_test_pair[0], temperature=0.0
+ ):
assert isinstance(chunk, str)
completion += chunk
- assert completion.strip().lower() == "tokyo"
+ assert completion.strip().lower() == tokyo_test_pair[1]
class TestOpenAI(TestBaseLLM):
@@ -129,7 +134,7 @@ class TestOpenAI(TestBaseLLM):
"Output a single word, that being the capital of Japan:"
)
assert isinstance(resp, str)
- assert resp.strip().lower() == "tokyo"
+ assert resp.strip().lower() == tokyo_test_pair[1]
class TestGGML(TestBaseLLM):
diff --git a/continuedev/src/continuedev/tests/step_test.py b/continuedev/src/continuedev/tests/step_test.py
new file mode 100644
index 00000000..a4131e61
--- /dev/null
+++ b/continuedev/src/continuedev/tests/step_test.py
@@ -0,0 +1,43 @@
+import pytest
+
+from continuedev.core.config import ContinueConfig
+from continuedev.headless import start_headless_session
+from continuedev.plugins.steps.core.core import UserInputStep
+from continuedev.tests.util.prompts import tokyo_test_pair
+
+TEST_CONFIG = ContinueConfig()
+
+
+@pytest.mark.asyncio
+async def test_step():
+ session = await start_headless_session(config=TEST_CONFIG)
+
+ await session.autopilot.run_from_step(UserInputStep(user_input=tokyo_test_pair[0]))
+
+ full_state = await session.autopilot.get_full_state()
+ assert (
+ full_state.history.timeline[-1].step.description.strip().lower()
+ == tokyo_test_pair[1]
+ )
+
+ await session.autopilot.cleanup()
+
+
+# TODO: Test other properties of full_state after the UserInputStep. Also test with other config properties and models, etc...
+# so we are sure that UserInputStep works in many cases. One example of a thing to check is that the step following UserInputStep
+# is a SimpleChatStep, and that it is not hidden, and that the properties of the node (full_state.history.timeline[-1]) all look good
+# (basically, run it, see what you get, then codify this in assertions)
+
+# TODO: Write tests for other steps:
+# - DefaultOnTracebackStep
+# - EditHighlightedCodeStep (note that this won't test the rendering in IDE)
+
+# NOTE: Avoid expensive prompts - not too big a deal, but don't have it generate an entire 100 line file
+# If you want to not have llm_test.py spend money, change SPEND_MONEY to False at the
+# top of that file, but make sure to put it back to True before committing
+
+# NOTE: Headless mode uses continuedev.src.continuedev.headless.headless_ide instead of
+# VS Code, so many of the methods just pass, or might not act exactly how you expect.
+# See the file for reference
+
+# NOTE: If this is too short or pointless a task, let me know and I'll set up testing of ContextProviders
diff --git a/continuedev/src/continuedev/tests/util/config.py b/continuedev/src/continuedev/tests/util/config.py
new file mode 100644
index 00000000..73d3aeff
--- /dev/null
+++ b/continuedev/src/continuedev/tests/util/config.py
@@ -0,0 +1,19 @@
+from continuedev.src.continuedev.core.config import ContinueConfig
+from continuedev.src.continuedev.core.models import Models
+from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+
+config = ContinueConfig(
+ allow_anonymous_telemetry=False,
+ models=Models(
+ default=MaybeProxyOpenAI(api_key="", model="gpt-4"),
+ medium=MaybeProxyOpenAI(
+ api_key="",
+ model="gpt-3.5-turbo",
+ ),
+ ),
+ system_message=None,
+ temperature=0.5,
+ custom_commands=[],
+ slash_commands=[],
+ context_providers=[],
+)
diff --git a/continuedev/src/continuedev/tests/util/prompts.py b/continuedev/src/continuedev/tests/util/prompts.py
new file mode 100644
index 00000000..f60c926c
--- /dev/null
+++ b/continuedev/src/continuedev/tests/util/prompts.py
@@ -0,0 +1 @@
+tokyo_test_pair = ("Output a single word, that being the capital of Japan:", "tokyo") \ No newline at end of file
diff --git a/run.spec b/run.spec
index e3c6c07c..19181249 100644
--- a/run.spec
+++ b/run.spec
@@ -12,8 +12,8 @@ a = Analysis(
datas=[
('continuedev', 'continuedev'),
(certifi.where(), 'ca_bundle')
- ] + copy_metadata('replicate'),
- hiddenimports=['anthropic', 'github', 'ripgrepy', 'bs4', 'redbaron', 'replicate'],
+ ],
+ hiddenimports=['anthropic', 'github', 'ripgrepy', 'bs4', 'redbaron', 'python-lsp-server', 'replicate'] + copy_metadata('replicate'),
hookspath=[],
hooksconfig={},
runtime_hooks=[],
diff --git a/test.py b/test.py
new file mode 100644
index 00000000..5bd57e0e
--- /dev/null
+++ b/test.py
@@ -0,0 +1,21 @@
+import unittest
+
+
+def sort_numbers(numbers):
+ for i in range(len(numbers)):
+ for j in range(i + 1, len(numbers)):
+ if numbers[i] > numbers[j]:
+ numbers[i], numbers[j] = numbers[j], numbers[i]
+ return numbers[:-1] # Error here: We're not returning the last number
+
+
+class TestSortNumbers(unittest.TestCase):
+ def test_sort_numbers(self):
+ self.assertEqual(sort_numbers([3, 2, 1]), [1, 2, 3]) # This test will fail
+ self.assertEqual(
+ sort_numbers([4, 2, 5, 1, 3]), [1, 2, 3, 4, 5]
+ ) # This test will fail
+
+
+if __name__ == "__main__":
+ unittest.main()