diff options
-rw-r--r-- | continuedev/src/continuedev/libs/llm/llamacpp.py | 4 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 5 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/util/paths.py | 6 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/context_providers/file.py | 2 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/core/core.py | 11 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/main.py | 8 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide.py | 12 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide_protocol.py | 4 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/meilisearch_server.py | 67 | ||||
-rw-r--r-- | extension/src/activation/environmentSetup.ts | 35 | ||||
-rw-r--r-- | extension/src/continueIdeClient.ts | 89 |
11 files changed, 153 insertions, 90 deletions
diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py index 6625065e..e6f38cd0 100644 --- a/continuedev/src/continuedev/libs/llm/llamacpp.py +++ b/continuedev/src/continuedev/libs/llm/llamacpp.py @@ -7,7 +7,7 @@ import aiohttp from ...core.main import ChatMessage from ..llm import LLM from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens -from .prompts.chat import code_llama_template_messages +from .prompts.chat import llama2_template_messages class LlamaCpp(LLM): @@ -15,7 +15,7 @@ class LlamaCpp(LLM): server_url: str = "http://localhost:8080" verify_ssl: Optional[bool] = None - template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages + template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages llama_cpp_args: Dict[str, Any] = {"stop": ["[INST]"], "grammar": "root ::= "} use_command: Optional[str] = None diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 464c6420..a61103b9 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -148,10 +148,7 @@ class OpenAI(LLM): args = self.default_args.copy() args.update(kwargs) args["stream"] = True - # TODO what to do here? why should we change to gpt-3.5-turbo-0613 if the user didn't ask for it? - args["model"] = ( - self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613" - ) + if not args["model"].endswith("0613") and "functions" in args: del args["functions"] diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py index b3e9ecc1..9f3117d0 100644 --- a/continuedev/src/continuedev/libs/util/paths.py +++ b/continuedev/src/continuedev/libs/util/paths.py @@ -31,6 +31,12 @@ def getServerFolderPath(): return path +def getMeilisearchExePath(): + binary_name = "meilisearch.exe" if os.name == "nt" else "meilisearch" + path = os.path.join(getServerFolderPath(), binary_name) + return path + + def getSessionFilePath(session_id: str): path = os.path.join(getSessionsFolderPath(), f"{session_id}.json") os.makedirs(os.path.dirname(path), exist_ok=True) diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index 9846dd3e..859088b8 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -123,7 +123,7 @@ class FileContextProvider(ContextProvider): ) async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: - contents = await self.sdk.ide.listDirectoryContents(workspace_dir) + contents = await self.sdk.ide.listDirectoryContents(workspace_dir, True) if contents is None: return [] diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index 17b325ab..1529fe1b 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -2,12 +2,13 @@ import difflib import traceback from textwrap import dedent -from typing import Any, Coroutine, List, Union +from typing import Any, Coroutine, List, Optional, Union from pydantic import validator from ....core.main import ChatMessage, ContinueCustomException, Step from ....core.observation import Observation, TextObservation, UserInputObservation +from ....libs.llm import LLM from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS from ....libs.util.strings import ( @@ -161,6 +162,7 @@ class ShellCommandsStep(Step): class DefaultModelEditCodeStep(Step): user_input: str + model: Optional[LLM] = None range_in_files: List[RangeInFile] name: str = "Editing Code" hide = False @@ -241,7 +243,10 @@ class DefaultModelEditCodeStep(Step): # We don't know here all of the functions being passed in. # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion. # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need. - model_to_use = sdk.models.edit + if self.model is not None: + await sdk.start_model(self.model) + + model_to_use = self.model or sdk.models.edit max_tokens = int(model_to_use.context_length / 2) TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200 @@ -836,6 +841,7 @@ class EditFileStep(Step): filepath: str prompt: str hide: bool = True + model: Optional[LLM] = None async def describe(self, models: Models) -> Coroutine[str, None, None]: return "Editing file: " + self.filepath @@ -848,6 +854,7 @@ class EditFileStep(Step): RangeInFile.from_entire_file(self.filepath, file_contents) ], user_input=self.prompt, + model=self.model, ) ) diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py index ab5775c6..2ceb82c5 100644 --- a/continuedev/src/continuedev/plugins/steps/main.py +++ b/continuedev/src/continuedev/plugins/steps/main.py @@ -1,12 +1,13 @@ import os from textwrap import dedent -from typing import Coroutine, List, Union +from typing import Coroutine, List, Optional, Union from pydantic import BaseModel, Field from ...core.main import ContinueCustomException, Step from ...core.observation import Observation from ...core.sdk import ContinueSDK, Models +from ...libs.llm import LLM from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder from ...libs.util.calculate_diff import calculate_diff2 from ...libs.util.logging import logger @@ -240,6 +241,7 @@ class EditHighlightedCodeStep(Step): title="User Input", description="The natural language request describing how to edit the code", ) + model: Optional[LLM] = None hide = True description: str = "Change the contents of the currently highlighted code or open file. You should call this function if the user asks seems to be asking for a code change." @@ -293,7 +295,9 @@ class EditHighlightedCodeStep(Step): await sdk.run_step( DefaultModelEditCodeStep( - user_input=self.user_input, range_in_files=range_in_files + user_input=self.user_input, + range_in_files=range_in_files, + model=self.model, ) ) diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 8a62c39e..89fcd0d1 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -494,10 +494,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer): ) return resp.fileEdit - async def listDirectoryContents(self, directory: str) -> List[str]: + async def listDirectoryContents( + self, directory: str, recursive: bool = False + ) -> List[str]: """List the contents of a directory""" resp = await self._send_and_receive_json( - {"directory": directory}, + {"directory": directory, "recursive": recursive}, ListDirectoryContentsResponse, "listDirectoryContents", ) @@ -574,7 +576,11 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None): # Start meilisearch try: - await start_meilisearch() + + async def on_err(e): + logger.debug(f"Failed to start MeiliSearch: {e}") + + create_async_task(start_meilisearch(), on_err) except Exception as e: logger.debug("Failed to start MeiliSearch") logger.debug(e) diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index 4ef4bde7..2a07ae2a 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -148,7 +148,9 @@ class AbstractIdeProtocolServer(ABC): """Called when a file is saved""" @abstractmethod - async def listDirectoryContents(self, directory: str) -> List[str]: + async def listDirectoryContents( + self, directory: str, recursive: bool = False + ) -> List[str]: """List directory contents""" @abstractmethod diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index 037ce8fa..390eeb50 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -3,20 +3,59 @@ import os import shutil import subprocess +import aiofiles +import aiohttp from meilisearch_python_async import Client from ..libs.util.logging import logger -from ..libs.util.paths import getServerFolderPath +from ..libs.util.paths import getMeilisearchExePath, getServerFolderPath -def ensure_meilisearch_installed() -> bool: +async def download_file(url: str, filename: str): + async with aiohttp.ClientSession() as session: + async with session.get(url) as resp: + if resp.status == 200: + f = await aiofiles.open(filename, mode="wb") + await f.write(await resp.read()) + await f.close() + + +async def download_meilisearch(): + """ + Downloads MeiliSearch. + """ + + serverPath = getServerFolderPath() + logger.debug("Downloading MeiliSearch...") + + if os.name == "nt": + download_url = "https://github.com/meilisearch/meilisearch/releases/download/v1.3.2/meilisearch-windows-amd64.exe" + download_path = getMeilisearchExePath() + if not os.path.exists(download_path): + await download_file(download_url, download_path) + # subprocess.run( + # f"curl -L {download_url} -o {download_path}", + # shell=True, + # check=True, + # cwd=serverPath, + # ) + else: + subprocess.run( + "curl -L https://install.meilisearch.com | sh", + shell=True, + check=True, + cwd=serverPath, + ) + + +async def ensure_meilisearch_installed() -> bool: """ Checks if MeiliSearch is installed. Returns a bool indicating whether it was installed to begin with. """ serverPath = getServerFolderPath() - meilisearchPath = os.path.join(serverPath, "meilisearch") + meilisearchPath = getMeilisearchExePath() dumpsPath = os.path.join(serverPath, "dumps") dataMsPath = os.path.join(serverPath, "data.ms") @@ -40,14 +79,7 @@ def ensure_meilisearch_installed() -> bool: for p in existing_paths: shutil.rmtree(p, ignore_errors=True) - # Download MeiliSearch - logger.debug("Downloading MeiliSearch...") - subprocess.run( - "curl -L https://install.meilisearch.com | sh", - shell=True, - check=True, - cwd=serverPath, - ) + await download_meilisearch() return False @@ -66,7 +98,7 @@ async def check_meilisearch_running() -> bool: if resp.status != "available": return False return True - except Exception as e: + except Exception: return False except Exception: return False @@ -86,24 +118,21 @@ async def start_meilisearch(): """ Starts the MeiliSearch server, wait for it. """ - - # Doesn't work on windows for now - if not os.name == "posix": - return - serverPath = getServerFolderPath() # Check if MeiliSearch is installed, if not download - was_already_installed = ensure_meilisearch_installed() + was_already_installed = await ensure_meilisearch_installed() # Check if MeiliSearch is running if not await check_meilisearch_running() or not was_already_installed: logger.debug("Starting MeiliSearch...") + binary_name = "meilisearch" if os.name == "nt" else "./meilisearch" subprocess.Popen( - ["./meilisearch", "--no-analytics"], + [binary_name, "--no-analytics"], cwd=serverPath, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT, close_fds=True, start_new_session=True, + shell=True ) diff --git a/extension/src/activation/environmentSetup.ts b/extension/src/activation/environmentSetup.ts index 7ca87768..3aa536d0 100644 --- a/extension/src/activation/environmentSetup.ts +++ b/extension/src/activation/environmentSetup.ts @@ -74,6 +74,14 @@ function serverVersionPath(): string { return path.join(serverPath(), "server_version.txt"); } +function serverBinaryPath(): string { + return path.join( + serverPath(), + "exe", + `run${os.platform() === "win32" ? ".exe" : ""}` + ); +} + export function getExtensionVersion() { const extension = vscode.extensions.getExtension("continue.continue"); return extension?.packageJSON.version || ""; @@ -105,14 +113,7 @@ async function checkOrKillRunningServer(serverUrl: string): Promise<boolean> { // Try again, on Windows. This time with taskkill if (os.platform() === "win32") { try { - const exePath = path.join( - getExtensionUri().fsPath, - "server", - "exe", - "run.exe" - ); - - await runCommand(`taskkill /F /IM ${exePath}`); + await runCommand(`taskkill /F /IM run.exe`); } catch (e: any) { console.log( "Failed to kill old server second time on windows with taskkill:", @@ -126,14 +127,9 @@ async function checkOrKillRunningServer(serverUrl: string): Promise<boolean> { fs.unlinkSync(serverVersionPath()); } // Also delete the server binary - const serverBinaryPath = path.join( - getExtensionUri().fsPath, - "server", - "exe", - `run${os.platform() === "win32" ? ".exe" : ""}` - ); - if (fs.existsSync(serverBinaryPath)) { - fs.unlinkSync(serverBinaryPath); + const serverBinary = serverBinaryPath(); + if (fs.existsSync(serverBinary)) { + fs.unlinkSync(serverBinary); } } @@ -213,12 +209,7 @@ export async function startContinuePythonServer(redownload: boolean = true) { : "mac/run" : "linux/run"; - const destination = path.join( - getExtensionUri().fsPath, - "server", - "exe", - `run${os.platform() === "win32" ? ".exe" : ""}` - ); + const destination = serverBinaryPath(); // First, check if the server is already downloaded let shouldDownload = true; diff --git a/extension/src/continueIdeClient.ts b/extension/src/continueIdeClient.ts index 94997d76..353584e9 100644 --- a/extension/src/continueIdeClient.ts +++ b/extension/src/continueIdeClient.ts @@ -272,40 +272,10 @@ class IdeProtocolClient { break; case "listDirectoryContents": messenger.send("listDirectoryContents", { - contents: ( - await vscode.workspace.fs.readDirectory( - uriFromFilePath(data.directory) - ) - ) - .map(([name, type]) => name) - .filter((name) => { - const DEFAULT_IGNORE_DIRS = [ - ".git", - ".vscode", - ".idea", - ".vs", - ".venv", - "env", - ".env", - "node_modules", - "dist", - "build", - "target", - "out", - "bin", - ".pytest_cache", - ".vscode-test", - ".continue", - "__pycache__", - ]; - if ( - !DEFAULT_IGNORE_DIRS.some((dir) => - name.split(path.sep).includes(dir) - ) - ) { - return name; - } - }), + contents: await this.getDirectoryContents( + data.directory, + data.recursive || false + ), }); break; case "editFile": @@ -562,6 +532,57 @@ class IdeProtocolClient { }); } + async getDirectoryContents( + directory: string, + recursive: boolean + ): Promise<string[]> { + let nameAndType = ( + await vscode.workspace.fs.readDirectory(uriFromFilePath(directory)) + ).filter(([name, type]) => { + const DEFAULT_IGNORE_DIRS = [ + ".git", + ".vscode", + ".idea", + ".vs", + ".venv", + "env", + ".env", + "node_modules", + "dist", + "build", + "target", + "out", + "bin", + ".pytest_cache", + ".vscode-test", + ".continue", + "__pycache__", + ]; + if ( + !DEFAULT_IGNORE_DIRS.some((dir) => name.split(path.sep).includes(dir)) + ) { + return name; + } + }); + + let absolutePaths = nameAndType + .filter(([name, type]) => type === vscode.FileType.File) + .map(([name, type]) => path.join(directory, name)); + if (recursive) { + for (const [name, type] of nameAndType) { + if (type === vscode.FileType.Directory) { + const subdirectory = path.join(directory, name); + const subdirectoryContents = await this.getDirectoryContents( + subdirectory, + recursive + ); + absolutePaths = absolutePaths.concat(subdirectoryContents); + } + } + } + return absolutePaths; + } + async readFile(filepath: string): Promise<string> { let contents: string | undefined; if (typeof contents === "undefined") { |