author    Nate Sesti <sestinj@gmail.com>  2023-07-09 13:09:34 -0700
committer Nate Sesti <sestinj@gmail.com>  2023-07-09 13:09:34 -0700
commit    b2d621bb075ccfb73c4662406df2974818744436 (patch)
tree      4e1b23a961bee94a2fad4fe318c1ad31fc3d6f83
parent    924fc27e53bd2503a3f6d3b49a4d5c02b8ace66d (diff)
expand max_tokens for large highlighted ranges
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py         |  6
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py   |  6
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py  |  4
-rw-r--r--  continuedev/src/continuedev/steps/core/core.py         | 23
-rw-r--r--  extension/package-lock.json                            |  4
-rw-r--r--  extension/package.json                                 |  2
-rw-r--r--  extension/src/activation/environmentSetup.ts           |  9
7 files changed, 34 insertions, 20 deletions
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index c4e4139f..f0877d90 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -37,7 +37,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
async for chunk in await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], with_history, prompt, functions=None),
+ args["model"], with_history, args["max_tokens"], prompt, functions=None),
**args,
):
if "content" in chunk.choices[0].delta:
@@ -58,7 +58,7 @@ class OpenAI(LLM):
async for chunk in await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], messages, functions=args.get("functions", None)),
+ args["model"], messages, args["max_tokens"], functions=args.get("functions", None)),
**args,
):
yield chunk.choices[0].delta
@@ -69,7 +69,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
resp = (await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], with_history, prompt, functions=None),
+ args["model"], with_history, args["max_tokens"], prompt, functions=None),
**args,
)).choices[0].message.content
else:
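
For orientation, here is a minimal, self-contained sketch of the new calling convention used above. The stand-in compile_chat_messages and the DEFAULT_MAX_TOKENS value are simplifications for illustration, not the project's real implementation; only the argument order mirrors the diff.

from typing import Dict, List, Optional

DEFAULT_MAX_TOKENS = 2048  # assumed default; the real constant lives in count_tokens.py

def compile_chat_messages(model: str, msgs: List[Dict], max_tokens: int,
                          prompt: Optional[str] = None) -> List[Dict]:
    # Simplified stand-in: the real function prunes msgs so that
    # prompt + history + max_tokens fits the model's context window.
    compiled = list(msgs)
    if prompt is not None:
        compiled.append({"role": "user", "content": prompt})
    return compiled

# Caller side, mirroring the diff: the completion budget now comes from the
# request args instead of being assumed inside the helper.
args = {"model": "gpt-4", "max_tokens": 1024}
print(compile_chat_messages(args["model"], [], args["max_tokens"], "Refactor this function"))

The point of the change is that history pruning can now account for whatever completion budget the caller actually requested.
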
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 05ece394..eab6e441 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -38,7 +38,7 @@ class ProxyServer(LLM):
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/complete", json={
- "messages": compile_chat_messages(args["model"], with_history, prompt, functions=None),
+ "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
"unique_id": self.unique_id,
**args
}) as resp:
@@ -50,7 +50,7 @@ class ProxyServer(LLM):
async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.default_model, messages, None, functions=args.get("functions", None))
+ self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None))
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/stream_chat", json={
@@ -74,7 +74,7 @@ class ProxyServer(LLM):
async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.default_model, with_history, prompt, functions=args.get("functions", None))
+ self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/stream_complete", json={
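
One subtlety in these proxy_server.py call sites: the value after args["max_tokens"] is the positional prompt argument, so the chat-style call passes None there explicitly. A hedged sketch of the ordering, with the body elided and example values that are purely illustrative:

from typing import Dict, List, Union

def compile_chat_messages(model: str, msgs: List[Dict], max_tokens: int,
                          prompt: Union[str, None] = None,
                          functions: Union[List, None] = None,
                          system_message: Union[str, None] = None) -> List[Dict]:
    ...  # pruning and message assembly elided in this sketch

# Chat streaming (stream_chat): history only, so the prompt slot is explicitly None.
compile_chat_messages("gpt-4", [{"role": "user", "content": "hi"}], 1024, None,
                      functions=None)

# Completion-style (stream_complete): history plus a prompt string.
compile_chat_messages("gpt-4", [], 1024, "Summarize the selected code",
                      functions=None)
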
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 8b06fef9..73be0717 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -76,14 +76,14 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
return chat_history
-def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
+def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
prompt_tokens = count_tokens(model, prompt)
if functions is not None:
for function in functions:
prompt_tokens += count_tokens(model, json.dumps(function))
msgs = prune_chat_history(model,
- msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + DEFAULT_MAX_TOKENS + count_tokens(model, system_message))
+ msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, system_message))
history = []
if system_message:
history.append({
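
The count_tokens.py change only affects the pruning budget: history is now pruned against the caller-supplied max_tokens rather than DEFAULT_MAX_TOKENS. A rough sketch of the arithmetic; the helper name tokens_reserved and all numbers are illustrative, not part of the codebase:

MAX_TOKENS_FOR_MODEL = {"gpt-4": 8192}  # illustrative subset of the real table
DEFAULT_MAX_TOKENS = 2048               # assumed value of the module's default

def tokens_reserved(prompt_tokens: int, system_tokens: int, max_tokens: int) -> int:
    # Mirrors the updated formula: the reservation scales with the completion
    # budget the caller asked for, not with a fixed default.
    return prompt_tokens + max_tokens + system_tokens

context_limit = MAX_TOKENS_FOR_MODEL["gpt-4"]
reserved = tokens_reserved(prompt_tokens=1500, system_tokens=50, max_tokens=3000)
print(context_limit - reserved)  # 3642 tokens left for chat history in this example
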
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 10853828..4b35a758 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -181,15 +181,22 @@ class DefaultModelEditCodeStep(Step):
# We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
# Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
model_to_use = sdk.models.gpt4
+ max_tokens = DEFAULT_MAX_TOKENS
- BUFFER_FOR_FUNCTIONS = 400
- total_tokens = model_to_use.count_tokens(
- full_file_contents + self._prompt + self.user_input) + BUFFER_FOR_FUNCTIONS + DEFAULT_MAX_TOKENS
-
- TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1000
+ TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE:
self.description += "\n\n**It looks like you've selected a large range to edit, which may take a while to complete. If you'd like to cancel, click the 'X' button above. If you highlight a more specific range, Continue will only edit within it.**"
+ # At this point, we also increase the max_tokens parameter so it doesn't stop in the middle of generation
+ # Increase max_tokens to roughly 2.5x the size of the range
+ # But don't exceed 2.5x the default max tokens
+ max_tokens = int(min(model_to_use.count_tokens(
+ rif.contents), DEFAULT_MAX_TOKENS) * 2.5)
+
+ BUFFER_FOR_FUNCTIONS = 400
+ total_tokens = model_to_use.count_tokens(
+ full_file_contents + self._prompt + self.user_input) + BUFFER_FOR_FUNCTIONS + max_tokens
+
# If using 3.5 and overflows, upgrade to 3.5.16k
if model_to_use.name == "gpt-3.5-turbo":
if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
@@ -252,7 +259,7 @@ class DefaultModelEditCodeStep(Step):
file_suffix = "\n" + file_suffix
rif.contents = rif.contents[:-1]
- return file_prefix, rif.contents, file_suffix, model_to_use
+ return file_prefix, rif.contents, file_suffix, model_to_use, max_tokens
def compile_prompt(self, file_prefix: str, contents: str, file_suffix: str, sdk: ContinueSDK) -> str:
prompt = self._prompt
@@ -289,7 +296,7 @@ class DefaultModelEditCodeStep(Step):
await sdk.ide.saveFile(rif.filepath)
full_file_contents = await sdk.ide.readFile(rif.filepath)
- file_prefix, contents, file_suffix, model_to_use = await self.get_prompt_parts(
+ file_prefix, contents, file_suffix, model_to_use, max_tokens = await self.get_prompt_parts(
rif, sdk, full_file_contents)
contents, common_whitespace = dedent_and_get_common_whitespace(
contents)
@@ -435,7 +442,7 @@ class DefaultModelEditCodeStep(Step):
completion_lines_covered = 0
repeating_file_suffix = False
line_below_highlighted_range = file_suffix.lstrip().split("\n")[0]
- async for chunk in model_to_use.stream_chat(messages, temperature=0):
+ async for chunk in model_to_use.stream_chat(messages, temperature=0, max_tokens=max_tokens):
# Stop early if it is repeating the file_suffix or the step was deleted
if repeating_file_suffix:
break
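
Taken together, the core.py changes scale the completion budget with the highlighted range: 2.5x the range's token count, with the base capped at DEFAULT_MAX_TOKENS so very large selections top out at 2.5x the default. A self-contained sketch under assumed constants (count_tokens_in_range stands in for model_to_use.count_tokens(rif.contents); the DEFAULT_MAX_TOKENS value is an assumption):

DEFAULT_MAX_TOKENS = 2048                   # assumed default
TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200  # threshold raised from 1000 in this commit

def max_tokens_for_edit(count_tokens_in_range: int) -> int:
    max_tokens = DEFAULT_MAX_TOKENS
    if count_tokens_in_range > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE:
        # Scale the budget with the selection size, but cap the scaled base at
        # DEFAULT_MAX_TOKENS so the result never exceeds 2.5x the default.
        max_tokens = int(min(count_tokens_in_range, DEFAULT_MAX_TOKENS) * 2.5)
    return max_tokens

print(max_tokens_for_edit(800))   # small range: stays at the default (2048)
print(max_tokens_for_edit(1600))  # large range: int(1600 * 2.5) = 4000
print(max_tokens_for_edit(5000))  # very large range: capped at int(2048 * 2.5) = 5120

The computed max_tokens is returned from get_prompt_parts and forwarded to stream_chat, so generation on large ranges is less likely to stop in the middle of an edit.
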
diff --git a/extension/package-lock.json b/extension/package-lock.json
index 3c0d6e3e..22f8b492 100644
--- a/extension/package-lock.json
+++ b/extension/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "continue",
- "version": "0.0.139",
+ "version": "0.0.141",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "continue",
- "version": "0.0.139",
+ "version": "0.0.141",
"license": "Apache-2.0",
"dependencies": {
"@electron/rebuild": "^3.2.10",
diff --git a/extension/package.json b/extension/package.json
index e1de3b73..beb675b3 100644
--- a/extension/package.json
+++ b/extension/package.json
@@ -14,7 +14,7 @@
"displayName": "Continue",
"pricing": "Free",
"description": "The open-source coding autopilot",
- "version": "0.0.139",
+ "version": "0.0.141",
"publisher": "Continue",
"engines": {
"vscode": "^1.67.0"
diff --git a/extension/src/activation/environmentSetup.ts b/extension/src/activation/environmentSetup.ts
index 90ec9259..714080e3 100644
--- a/extension/src/activation/environmentSetup.ts
+++ b/extension/src/activation/environmentSetup.ts
@@ -324,7 +324,14 @@ export async function startContinuePythonServer() {
}
}
console.log("Killing old server...");
- await fkill(":65432");
+ try {
+ await fkill(":65432");
+ } catch (e) {
+ console.log(
+ "Failed to kill old server, likely because it didn't exist:",
+ e
+ );
+ }
}
// Do this after above check so we don't have to waste time setting up the env