summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--continuedev/src/continuedev/libs/llm/llamacpp.py4
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py5
-rw-r--r--continuedev/src/continuedev/libs/util/paths.py6
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/file.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/core/core.py11
-rw-r--r--continuedev/src/continuedev/plugins/steps/main.py8
-rw-r--r--continuedev/src/continuedev/server/ide.py12
-rw-r--r--continuedev/src/continuedev/server/ide_protocol.py4
-rw-r--r--continuedev/src/continuedev/server/meilisearch_server.py67
-rw-r--r--extension/src/activation/environmentSetup.ts35
-rw-r--r--extension/src/continueIdeClient.ts89
11 files changed, 153 insertions, 90 deletions
diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py
index 6625065e..e6f38cd0 100644
--- a/continuedev/src/continuedev/libs/llm/llamacpp.py
+++ b/continuedev/src/continuedev/libs/llm/llamacpp.py
@@ -7,7 +7,7 @@ import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
-from .prompts.chat import code_llama_template_messages
+from .prompts.chat import llama2_template_messages
class LlamaCpp(LLM):
@@ -15,7 +15,7 @@ class LlamaCpp(LLM):
server_url: str = "http://localhost:8080"
verify_ssl: Optional[bool] = None
- template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages
+ template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages
llama_cpp_args: Dict[str, Any] = {"stop": ["[INST]"], "grammar": "root ::= "}
use_command: Optional[str] = None
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 464c6420..a61103b9 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -148,10 +148,7 @@ class OpenAI(LLM):
args = self.default_args.copy()
args.update(kwargs)
args["stream"] = True
- # TODO what to do here? why should we change to gpt-3.5-turbo-0613 if the user didn't ask for it?
- args["model"] = (
- self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613"
- )
+
if not args["model"].endswith("0613") and "functions" in args:
del args["functions"]
diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py
index b3e9ecc1..9f3117d0 100644
--- a/continuedev/src/continuedev/libs/util/paths.py
+++ b/continuedev/src/continuedev/libs/util/paths.py
@@ -31,6 +31,12 @@ def getServerFolderPath():
return path
+def getMeilisearchExePath():
+ binary_name = "meilisearch.exe" if os.name == "nt" else "meilisearch"
+ path = os.path.join(getServerFolderPath(), binary_name)
+ return path
+
+
def getSessionFilePath(session_id: str):
path = os.path.join(getSessionsFolderPath(), f"{session_id}.json")
os.makedirs(os.path.dirname(path), exist_ok=True)
diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py
index 9846dd3e..859088b8 100644
--- a/continuedev/src/continuedev/plugins/context_providers/file.py
+++ b/continuedev/src/continuedev/plugins/context_providers/file.py
@@ -123,7 +123,7 @@ class FileContextProvider(ContextProvider):
)
async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
- contents = await self.sdk.ide.listDirectoryContents(workspace_dir)
+ contents = await self.sdk.ide.listDirectoryContents(workspace_dir, True)
if contents is None:
return []
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index 17b325ab..1529fe1b 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -2,12 +2,13 @@
import difflib
import traceback
from textwrap import dedent
-from typing import Any, Coroutine, List, Union
+from typing import Any, Coroutine, List, Optional, Union
from pydantic import validator
from ....core.main import ChatMessage, ContinueCustomException, Step
from ....core.observation import Observation, TextObservation, UserInputObservation
+from ....libs.llm import LLM
from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
from ....libs.util.strings import (
@@ -161,6 +162,7 @@ class ShellCommandsStep(Step):
class DefaultModelEditCodeStep(Step):
user_input: str
+ model: Optional[LLM] = None
range_in_files: List[RangeInFile]
name: str = "Editing Code"
hide = False
@@ -241,7 +243,10 @@ class DefaultModelEditCodeStep(Step):
# We don't know here all of the functions being passed in.
# We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
# Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
- model_to_use = sdk.models.edit
+ if self.model is not None:
+ await sdk.start_model(self.model)
+
+ model_to_use = self.model or sdk.models.edit
max_tokens = int(model_to_use.context_length / 2)
TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
@@ -836,6 +841,7 @@ class EditFileStep(Step):
filepath: str
prompt: str
hide: bool = True
+ model: Optional[LLM] = None
async def describe(self, models: Models) -> Coroutine[str, None, None]:
return "Editing file: " + self.filepath
@@ -848,6 +854,7 @@ class EditFileStep(Step):
RangeInFile.from_entire_file(self.filepath, file_contents)
],
user_input=self.prompt,
+ model=self.model,
)
)
diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py
index ab5775c6..2ceb82c5 100644
--- a/continuedev/src/continuedev/plugins/steps/main.py
+++ b/continuedev/src/continuedev/plugins/steps/main.py
@@ -1,12 +1,13 @@
import os
from textwrap import dedent
-from typing import Coroutine, List, Union
+from typing import Coroutine, List, Optional, Union
from pydantic import BaseModel, Field
from ...core.main import ContinueCustomException, Step
from ...core.observation import Observation
from ...core.sdk import ContinueSDK, Models
+from ...libs.llm import LLM
from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder
from ...libs.util.calculate_diff import calculate_diff2
from ...libs.util.logging import logger
@@ -240,6 +241,7 @@ class EditHighlightedCodeStep(Step):
title="User Input",
description="The natural language request describing how to edit the code",
)
+ model: Optional[LLM] = None
hide = True
description: str = "Change the contents of the currently highlighted code or open file. You should call this function if the user asks seems to be asking for a code change."
@@ -293,7 +295,9 @@ class EditHighlightedCodeStep(Step):
await sdk.run_step(
DefaultModelEditCodeStep(
- user_input=self.user_input, range_in_files=range_in_files
+ user_input=self.user_input,
+ range_in_files=range_in_files,
+ model=self.model,
)
)
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index 8a62c39e..89fcd0d1 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -494,10 +494,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
)
return resp.fileEdit
- async def listDirectoryContents(self, directory: str) -> List[str]:
+ async def listDirectoryContents(
+ self, directory: str, recursive: bool = False
+ ) -> List[str]:
"""List the contents of a directory"""
resp = await self._send_and_receive_json(
- {"directory": directory},
+ {"directory": directory, "recursive": recursive},
ListDirectoryContentsResponse,
"listDirectoryContents",
)
@@ -574,7 +576,11 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
# Start meilisearch
try:
- await start_meilisearch()
+
+ async def on_err(e):
+ logger.debug(f"Failed to start MeiliSearch: {e}")
+
+ create_async_task(start_meilisearch(), on_err)
except Exception as e:
logger.debug("Failed to start MeiliSearch")
logger.debug(e)
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
index 4ef4bde7..2a07ae2a 100644
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ b/continuedev/src/continuedev/server/ide_protocol.py
@@ -148,7 +148,9 @@ class AbstractIdeProtocolServer(ABC):
"""Called when a file is saved"""
@abstractmethod
- async def listDirectoryContents(self, directory: str) -> List[str]:
+ async def listDirectoryContents(
+ self, directory: str, recursive: bool = False
+ ) -> List[str]:
"""List directory contents"""
@abstractmethod
diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py
index 037ce8fa..390eeb50 100644
--- a/continuedev/src/continuedev/server/meilisearch_server.py
+++ b/continuedev/src/continuedev/server/meilisearch_server.py
@@ -3,20 +3,59 @@ import os
import shutil
import subprocess
+import aiofiles
+import aiohttp
from meilisearch_python_async import Client
from ..libs.util.logging import logger
-from ..libs.util.paths import getServerFolderPath
+from ..libs.util.paths import getMeilisearchExePath, getServerFolderPath
-def ensure_meilisearch_installed() -> bool:
+async def download_file(url: str, filename: str):
+ async with aiohttp.ClientSession() as session:
+ async with session.get(url) as resp:
+ if resp.status == 200:
+ f = await aiofiles.open(filename, mode="wb")
+ await f.write(await resp.read())
+ await f.close()
+
+
+async def download_meilisearch():
+ """
+ Downloads MeiliSearch.
+ """
+
+ serverPath = getServerFolderPath()
+ logger.debug("Downloading MeiliSearch...")
+
+ if os.name == "nt":
+ download_url = "https://github.com/meilisearch/meilisearch/releases/download/v1.3.2/meilisearch-windows-amd64.exe"
+ download_path = getMeilisearchExePath()
+ if not os.path.exists(download_path):
+ await download_file(download_url, download_path)
+ # subprocess.run(
+ # f"curl -L {download_url} -o {download_path}",
+ # shell=True,
+ # check=True,
+ # cwd=serverPath,
+ # )
+ else:
+ subprocess.run(
+ "curl -L https://install.meilisearch.com | sh",
+ shell=True,
+ check=True,
+ cwd=serverPath,
+ )
+
+
+async def ensure_meilisearch_installed() -> bool:
"""
Checks if MeiliSearch is installed.
Returns a bool indicating whether it was installed to begin with.
"""
serverPath = getServerFolderPath()
- meilisearchPath = os.path.join(serverPath, "meilisearch")
+ meilisearchPath = getMeilisearchExePath()
dumpsPath = os.path.join(serverPath, "dumps")
dataMsPath = os.path.join(serverPath, "data.ms")
@@ -40,14 +79,7 @@ def ensure_meilisearch_installed() -> bool:
for p in existing_paths:
shutil.rmtree(p, ignore_errors=True)
- # Download MeiliSearch
- logger.debug("Downloading MeiliSearch...")
- subprocess.run(
- "curl -L https://install.meilisearch.com | sh",
- shell=True,
- check=True,
- cwd=serverPath,
- )
+ await download_meilisearch()
return False
@@ -66,7 +98,7 @@ async def check_meilisearch_running() -> bool:
if resp.status != "available":
return False
return True
- except Exception as e:
+ except Exception:
return False
except Exception:
return False
@@ -86,24 +118,21 @@ async def start_meilisearch():
"""
Starts the MeiliSearch server, wait for it.
"""
-
- # Doesn't work on windows for now
- if not os.name == "posix":
- return
-
serverPath = getServerFolderPath()
# Check if MeiliSearch is installed, if not download
- was_already_installed = ensure_meilisearch_installed()
+ was_already_installed = await ensure_meilisearch_installed()
# Check if MeiliSearch is running
if not await check_meilisearch_running() or not was_already_installed:
logger.debug("Starting MeiliSearch...")
+ binary_name = "meilisearch" if os.name == "nt" else "./meilisearch"
subprocess.Popen(
- ["./meilisearch", "--no-analytics"],
+ [binary_name, "--no-analytics"],
cwd=serverPath,
stdout=subprocess.DEVNULL,
stderr=subprocess.STDOUT,
close_fds=True,
start_new_session=True,
+ shell=True
)
diff --git a/extension/src/activation/environmentSetup.ts b/extension/src/activation/environmentSetup.ts
index 7ca87768..3aa536d0 100644
--- a/extension/src/activation/environmentSetup.ts
+++ b/extension/src/activation/environmentSetup.ts
@@ -74,6 +74,14 @@ function serverVersionPath(): string {
return path.join(serverPath(), "server_version.txt");
}
+function serverBinaryPath(): string {
+ return path.join(
+ serverPath(),
+ "exe",
+ `run${os.platform() === "win32" ? ".exe" : ""}`
+ );
+}
+
export function getExtensionVersion() {
const extension = vscode.extensions.getExtension("continue.continue");
return extension?.packageJSON.version || "";
@@ -105,14 +113,7 @@ async function checkOrKillRunningServer(serverUrl: string): Promise<boolean> {
// Try again, on Windows. This time with taskkill
if (os.platform() === "win32") {
try {
- const exePath = path.join(
- getExtensionUri().fsPath,
- "server",
- "exe",
- "run.exe"
- );
-
- await runCommand(`taskkill /F /IM ${exePath}`);
+ await runCommand(`taskkill /F /IM run.exe`);
} catch (e: any) {
console.log(
"Failed to kill old server second time on windows with taskkill:",
@@ -126,14 +127,9 @@ async function checkOrKillRunningServer(serverUrl: string): Promise<boolean> {
fs.unlinkSync(serverVersionPath());
}
// Also delete the server binary
- const serverBinaryPath = path.join(
- getExtensionUri().fsPath,
- "server",
- "exe",
- `run${os.platform() === "win32" ? ".exe" : ""}`
- );
- if (fs.existsSync(serverBinaryPath)) {
- fs.unlinkSync(serverBinaryPath);
+ const serverBinary = serverBinaryPath();
+ if (fs.existsSync(serverBinary)) {
+ fs.unlinkSync(serverBinary);
}
}
@@ -213,12 +209,7 @@ export async function startContinuePythonServer(redownload: boolean = true) {
: "mac/run"
: "linux/run";
- const destination = path.join(
- getExtensionUri().fsPath,
- "server",
- "exe",
- `run${os.platform() === "win32" ? ".exe" : ""}`
- );
+ const destination = serverBinaryPath();
// First, check if the server is already downloaded
let shouldDownload = true;
diff --git a/extension/src/continueIdeClient.ts b/extension/src/continueIdeClient.ts
index 94997d76..353584e9 100644
--- a/extension/src/continueIdeClient.ts
+++ b/extension/src/continueIdeClient.ts
@@ -272,40 +272,10 @@ class IdeProtocolClient {
break;
case "listDirectoryContents":
messenger.send("listDirectoryContents", {
- contents: (
- await vscode.workspace.fs.readDirectory(
- uriFromFilePath(data.directory)
- )
- )
- .map(([name, type]) => name)
- .filter((name) => {
- const DEFAULT_IGNORE_DIRS = [
- ".git",
- ".vscode",
- ".idea",
- ".vs",
- ".venv",
- "env",
- ".env",
- "node_modules",
- "dist",
- "build",
- "target",
- "out",
- "bin",
- ".pytest_cache",
- ".vscode-test",
- ".continue",
- "__pycache__",
- ];
- if (
- !DEFAULT_IGNORE_DIRS.some((dir) =>
- name.split(path.sep).includes(dir)
- )
- ) {
- return name;
- }
- }),
+ contents: await this.getDirectoryContents(
+ data.directory,
+ data.recursive || false
+ ),
});
break;
case "editFile":
@@ -562,6 +532,57 @@ class IdeProtocolClient {
});
}
+ async getDirectoryContents(
+ directory: string,
+ recursive: boolean
+ ): Promise<string[]> {
+ let nameAndType = (
+ await vscode.workspace.fs.readDirectory(uriFromFilePath(directory))
+ ).filter(([name, type]) => {
+ const DEFAULT_IGNORE_DIRS = [
+ ".git",
+ ".vscode",
+ ".idea",
+ ".vs",
+ ".venv",
+ "env",
+ ".env",
+ "node_modules",
+ "dist",
+ "build",
+ "target",
+ "out",
+ "bin",
+ ".pytest_cache",
+ ".vscode-test",
+ ".continue",
+ "__pycache__",
+ ];
+ if (
+ !DEFAULT_IGNORE_DIRS.some((dir) => name.split(path.sep).includes(dir))
+ ) {
+ return name;
+ }
+ });
+
+ let absolutePaths = nameAndType
+ .filter(([name, type]) => type === vscode.FileType.File)
+ .map(([name, type]) => path.join(directory, name));
+ if (recursive) {
+ for (const [name, type] of nameAndType) {
+ if (type === vscode.FileType.Directory) {
+ const subdirectory = path.join(directory, name);
+ const subdirectoryContents = await this.getDirectoryContents(
+ subdirectory,
+ recursive
+ );
+ absolutePaths = absolutePaths.concat(subdirectoryContents);
+ }
+ }
+ }
+ return absolutePaths;
+ }
+
async readFile(filepath: string): Promise<string> {
let contents: string | undefined;
if (typeof contents === "undefined") {