author     Ty Dunn <ty@tydunn.com>          2023-06-16 15:56:14 -0700
committer  GitHub <noreply@github.com>      2023-06-16 15:56:14 -0700
commit     726c1da91d0d08689043da8e0b854e39121883f9
tree       a91625915ea31f00b0991ff3f333925e17bab5a4
parent     55611ef0b6ca014ff091a1cd18fb749ab210b3ec
parent     c196431ff1de014827066e3a04c39438c34ebe3d
Merge pull request #100 from continuedev/too-large
Better handling of files that are too large for context windows
 continuedev/src/continuedev/core/config.py         |   4
 continuedev/src/continuedev/core/sdk.py            |  12
 continuedev/src/continuedev/libs/llm/openai.py     |  17
 continuedev/src/continuedev/libs/llm/prompters.py  | 112
 continuedev/src/continuedev/steps/core/core.py     |  57
 5 files changed, 77 insertions(+), 125 deletions(-)
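
In short, the change gives each model a prompt budget equal to its context window minus a fixed completion reserve, then trims lines that fall outside the edited range (from the end of the file first, then from the start) until the prompt fits; files that still exceed the gpt-3.5-turbo budget fall back to gpt-3.5-turbo-16k. A minimal standalone sketch of that trimming idea follows (count_tokens here is a hypothetical stand-in for the SDK's per-model token counter, not the code merged below):

import tiktoken

DEFAULT_MAX_TOKENS = 2048  # tokens reserved for the model's completion
MAX_TOKENS_FOR_MODEL = {
    "gpt-3.5-turbo": 4096 - DEFAULT_MAX_TOKENS,
    "gpt-3.5-turbo-16k": 16384 - DEFAULT_MAX_TOKENS,
    "gpt-4": 8192 - DEFAULT_MAX_TOKENS,
}

def count_tokens(model: str, text: str) -> int:
    # Hypothetical helper; in the SDK this lives on the model object.
    return len(tiktoken.encoding_for_model(model).encode(text))

def cut_context(model, lines, total_tokens, cur_start, cur_end, max_start, min_end):
    # Drop whole lines after the edited range, then before it, until the
    # running token count fits the model's prompt budget.
    budget = MAX_TOKENS_FOR_MODEL[model]
    while total_tokens > budget and cur_end > min_end:
        total_tokens -= count_tokens(model, lines[cur_end])
        cur_end -= 1
    while total_tokens > budget and cur_start < max_start:
        total_tokens -= count_tokens(model, lines[cur_start])
        cur_start += 1
    return cur_start, cur_end
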
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 8f703758..859c6188 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -25,8 +25,8 @@ class ContinueConfig(BaseModel):
disallowed_steps: Optional[List[str]] = []
server_url: Optional[str] = None
allow_anonymous_telemetry: Optional[bool] = True
- default_model: Literal["gpt-3.5-turbo",
- "gpt-4", "starcoder"] = 'gpt-4'
+ default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
+ "gpt-4"] = 'gpt-4'
slash_commands: Optional[List[SlashCommand]] = [
# SlashCommand(
# name="pytest",
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 7639d010..7159beaa 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -45,6 +45,16 @@ class Models:
return asyncio.get_event_loop().run_until_complete(load_gpt35())
@cached_property
+ def gpt3516k(self):
+ async def load_gpt3516k():
+ api_key = await self.sdk.get_user_secret(
+ 'OPENAI_API_KEY', 'Enter your OpenAI API key, OR press enter to try for free')
+ if api_key == "":
+ return ProxyServer(self.sdk.ide.unique_id, "gpt-3.5-turbo-16k")
+ return OpenAI(api_key=api_key, default_model="gpt-3.5-turbo-16k")
+ return asyncio.get_event_loop().run_until_complete(load_gpt3516k())
+
+ @cached_property
def gpt4(self):
async def load_gpt4():
api_key = await self.sdk.get_user_secret(
@@ -59,6 +69,8 @@ class Models:
return self.starcoder
elif model_name == "gpt-3.5-turbo":
return self.gpt35
+ elif model_name == "gpt-3.5-turbo-16k":
+ return self.gpt3516k
elif model_name == "gpt-4":
return self.gpt4
else:
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 345af7b4..22c28b20 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -9,13 +9,14 @@ from ..llm import LLM
from pydantic import BaseModel, validator
import tiktoken
+DEFAULT_MAX_TOKENS = 2048
MAX_TOKENS_FOR_MODEL = {
- "gpt-3.5-turbo": 4097,
- "gpt-4": 4097,
+ "gpt-3.5-turbo": 4096 - DEFAULT_MAX_TOKENS,
+ "gpt-3.5-turbo-16k": 16384 - DEFAULT_MAX_TOKENS,
+ "gpt-4": 8192 - DEFAULT_MAX_TOKENS
}
-DEFAULT_MAX_TOKENS = 2048
CHAT_MODELS = {
- "gpt-3.5-turbo", "gpt-4"
+ "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4"
}
@@ -24,7 +25,7 @@ class OpenAI(LLM):
completion_count: int = 0
default_model: str
- def __init__(self, api_key: str, default_model: str = "gpt-3.5-turbo", system_message: str = None):
+ def __init__(self, api_key: str, default_model: str, system_message: str = None):
self.api_key = api_key
self.default_model = default_model
self.system_message = system_message
@@ -32,6 +33,10 @@ class OpenAI(LLM):
openai.api_key = api_key
@cached_property
+ def name(self):
+ return self.default_model
+
+ @cached_property
def __encoding_for_model(self):
aliases = {
"gpt-3.5-turbo": "gpt3"
@@ -76,7 +81,7 @@ class OpenAI(LLM):
return chat_history
def with_system_message(self, system_message: Union[str, None]):
- return OpenAI(api_key=self.api_key, system_message=system_message)
+ return OpenAI(api_key=self.api_key, default_model=self.default_model, system_message=system_message)
def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
self.completion_count += 1
diff --git a/continuedev/src/continuedev/libs/llm/prompters.py b/continuedev/src/continuedev/libs/llm/prompters.py
deleted file mode 100644
index 04e9885a..00000000
--- a/continuedev/src/continuedev/libs/llm/prompters.py
+++ /dev/null
@@ -1,112 +0,0 @@
-from typing import Any, Callable, List, Tuple, Union
-from ..llm import LLM
-from .openai import OpenAI
-
-
-def cls_method_to_str(cls_name: str, init: str, method: str) -> str:
- """Convert class and method info to formatted code"""
- return f"""class {cls_name}:
-{init}
-{method}"""
-
-
-# Prompter classes
-class Prompter:
- def __init__(self, llm: LLM = None):
- if llm is None:
- self.llm = OpenAI()
- else:
- self.llm = llm
-
- def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]:
- "Takes input and returns prompt, prefix, suffix"
- raise NotImplementedError
-
- def complete(self, inp: Any, **kwargs) -> str:
- prompt, prefix, suffix = self._compile_prompt(inp)
- resp = self.llm.complete(prompt + prefix, suffix=suffix, **kwargs)
- return prefix + resp + (suffix or "")
-
- def __call__(self, inp: Any, **kwargs) -> str:
- return self.complete(inp, **kwargs)
-
- def parallel_complete(self, inps: List[Any]) -> List[str]:
- prompts = []
- prefixes = []
- suffixes = []
- for inp in inps:
- prompt, prefix, suffix = self._compile_prompt(inp)
- prompts.append(prompt)
- prefixes.append(prefix)
- suffixes.append(suffix)
-
- resps = self.llm.parallel_complete(
- [prompt + prefix for prompt, prefix in zip(prompts, prefixes)], suffixes=suffixes)
- return [prefix + resp + (suffix or "") for prefix, resp, suffix in zip(prefixes, resps, suffixes)]
-
-
-class MixedPrompter(Prompter):
- def __init__(self, prompters: List[Prompter], router: Callable[[Any], int], llm: LLM = None):
- super().__init__(llm=llm)
- self.prompters = prompters
- self.router = router
-
- def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]:
- prompter = self.prompters[self.router(inp)]
- return prompter._compile_prompt(inp)
-
- def complete(self, inp: Any, **kwargs) -> str:
- prompter = self.prompters[self.router(inp)]
- return prompter.complete(inp, **kwargs)
-
-
-class SimplePrompter(Prompter):
- def __init__(self, prompt_fn: Callable[[Any], str], llm: LLM = None):
- super().__init__(llm=llm)
- self.prompt_fn = prompt_fn
-
- def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]:
- return self.prompt_fn(inp), "", None
-
-
-class FormatStringPrompter(SimplePrompter):
- """Pass a formatted string, and the input should be a dict with the keys to format"""
-
- def __init__(self, prompt: str, llm: LLM = None):
- super().__init__(lambda inp: prompt.format(**inp), llm=llm)
-
-
-class BasicCommentPrompter(SimplePrompter):
- def __init__(self, comment: str, llm: LLM = None):
- super().__init__(lambda inp: f"""{inp}
-
-# {comment}""", llm=llm)
-
-
-class EditPrompter(Prompter):
- def __init__(self, prompt_fn: Callable[[Any], Tuple[str, str]], llm: LLM = None):
- super().__init__(llm=llm)
- self.prompt_fn = prompt_fn
-
- def complete(self, inp: str, **kwargs) -> str:
- inp, instruction = self.prompt_fn(inp)
- return self.llm.edit(inp, instruction, **kwargs)
-
- def parallel_complete(self, inps: List[Any]) -> List[str]:
- prompts = []
- instructions = []
- for inp in inps:
- prompt, instruction = self.prompt_fn(inp)
- prompts.append(prompt)
- instructions.append(instruction)
-
- return self.llm.parallel_edit(prompts, instructions)
-
-
-class InsertPrompter(Prompter):
- def __init__(self, prompt_fn: Callable[[Any], Tuple[str, str, str]], llm: LLM = None):
- super().__init__(llm=llm)
- self.prompt_fn = prompt_fn
-
- def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]:
- return self.prompt_fn(inp)
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 59af5f38..d580e2d2 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -10,6 +10,7 @@ from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullConten
from ...models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
from ...core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
from ...core.main import Step, SequentialStep
+from ...libs.llm.openai import MAX_TOKENS_FOR_MODEL
import difflib
@@ -172,16 +173,62 @@ class DefaultModelEditCodeStep(Step):
for rif in rif_with_contents:
await sdk.ide.setFileOpen(rif.filepath)
+ model_to_use = sdk.models.default
+
full_file_contents = await sdk.ide.readFile(rif.filepath)
- start_index, end_index = rif.range.indices_in_string(
- full_file_contents)
- segs = [full_file_contents[:start_index],
- full_file_contents[end_index:]]
+
+ full_file_contents_lst = full_file_contents.split("\n")
+
+ max_start_line = rif.range.start.line
+ min_end_line = rif.range.end.line
+ cur_start_line = 0
+ cur_end_line = len(full_file_contents_lst) - 1
+
+ def cut_context(model_to_use, total_tokens, cur_start_line, cur_end_line):
+
+ if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ while cur_end_line > min_end_line:
+ total_tokens -= model_to_use.count_tokens(full_file_contents_lst[cur_end_line])
+ cur_end_line -= 1
+ if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ return cur_start_line, cur_end_line
+
+ if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ while cur_start_line < max_start_line:
+ cur_start_line += 1
+ total_tokens -= model_to_use.count_tokens(full_file_contents_lst[cur_end_line])
+ if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ return cur_start_line, cur_end_line
+
+ return cur_start_line, cur_end_line
+
+ if model_to_use.name == "gpt-4":
+
+ total_tokens = model_to_use.count_tokens(full_file_contents)
+ cur_start_line, cur_end_line = cut_context(model_to_use, total_tokens, cur_start_line, cur_end_line)
+
+ elif model_to_use.name == "gpt-3.5-turbo" or model_to_use.name == "gpt-3.5-turbo-16k":
+
+ if sdk.models.gpt35.count_tokens(full_file_contents) > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
+
+ model_to_use = sdk.models.gpt3516k
+ total_tokens = model_to_use.count_tokens(full_file_contents)
+ cur_start_line, cur_end_line = cut_context(model_to_use, total_tokens, cur_start_line, cur_end_line)
+
+ else:
+
+ raise Exception("Unknown default model")
+
+ code_before = "".join(full_file_contents_lst[cur_start_line:max_start_line])
+ code_after = "".join(full_file_contents_lst[min_end_line:cur_end_line])
+
+ segs = [code_before, code_after]
prompt = self._prompt.format(
code=rif.contents, user_request=self.user_input, file_prefix=segs[0], file_suffix=segs[1])
- completion = str(await sdk.models.default.complete(prompt, with_history=await sdk.get_chat_context()))
+ completion = str(await model_to_use.complete(prompt, with_history=await sdk.get_chat_context()))
+
eot_token = "<|endoftext|>"
completion = completion.removesuffix(eot_token)
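
With the Literal in config.py widened, the 16k model can also be selected directly; a minimal sketch of the updated config (the import path is assumed from the continuedev/src layout, and the remaining fields are assumed to keep their defaults, as the ones shown above do):

from continuedev.core.config import ContinueConfig  # import path assumed from repo layout

config = ContinueConfig(default_model="gpt-3.5-turbo-16k")  # now an accepted value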