author     Ty Dunn <ty@tydunn.com>    2023-06-15 18:27:42 -0700
committer  Ty Dunn <ty@tydunn.com>    2023-06-15 18:27:42 -0700
commit     6899bfb475ca1ec423e542540fa2788393e412ac (patch)
tree       e48b7c67f409a2123ec978b1454f761d02e9a04c /continuedev
parent     cfd5227274048a372d134aa8d1f1826a0d3d8fff (diff)
algorithm for cutting lines implemented
Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/src/continuedev/core/config.py         |   4
-rw-r--r--  continuedev/src/continuedev/core/sdk.py            |  16
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py     |  11
-rw-r--r--  continuedev/src/continuedev/libs/llm/prompters.py  | 112
-rw-r--r--  continuedev/src/continuedev/steps/core/core.py     |  55
5 files changed, 75 insertions(+), 123 deletions(-)
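
The main change is in DefaultModelEditCodeStep (core.py, the last diff below): when a file is too large for the chosen model's context window, whole lines are cut from below the highlighted range first, then from above it, until the prompt fits. A minimal standalone sketch of that idea, with hypothetical names (count_tokens stands in for the tiktoken-based counter in libs/llm/openai.py):

    from typing import Callable, List

    def cut_lines(lines: List[str], keep_start: int, keep_end: int,
                  count_tokens: Callable[[str], int], max_tokens: int) -> List[str]:
        """Drop lines outside lines[keep_start:keep_end] until the text fits."""
        total = count_tokens("\n".join(lines))
        start, end = 0, len(lines)
        # Cut from the bottom first, never into the highlighted range...
        while total > max_tokens and end > keep_end:
            end -= 1
            total -= count_tokens(lines[end])
        # ...then from the top if the prompt is still over budget.
        while total > max_tokens and start < keep_start:
            total -= count_tokens(lines[start])
            start += 1
        return lines[start:end]
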
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 23be8133..6a811412 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -19,8 +19,8 @@ class ContinueConfig(BaseModel):
     steps_on_startup: Optional[Dict[str, Dict]] = {}
     server_url: Optional[str] = None
     allow_anonymous_telemetry: Optional[bool] = True
-    default_model: Literal["gpt-3.5-turbo",
-                           "gpt-4", "starcoder"] = 'gpt-4'
+    default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
+                           "gpt-4"] = 'gpt-4'
     slash_commands: Optional[List[SlashCommand]] = [
         # SlashCommand(
         #     name="pytest",
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 76f72d01..3e3c3bc5 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -45,6 +45,16 @@ class Models:
         return asyncio.get_event_loop().run_until_complete(load_gpt35())

     @cached_property
+    def gpt3516k(self):
+        async def load_gpt3516k():
+            api_key = await self.sdk.get_user_secret(
+                'OPENAI_API_KEY', 'Enter your OpenAI API key, OR press enter to try for free')
+            if api_key == "":
+                return ProxyServer(self.sdk.ide.unique_id, "gpt-3.5-turbo-16k")
+            return OpenAI(api_key=api_key, default_model="gpt-3.5-turbo-16k")
+        return asyncio.get_event_loop().run_until_complete(load_gpt3516k())
+
+    @cached_property
     def gpt4(self):
         async def load_gpt4():
             api_key = await self.sdk.get_user_secret(
@@ -59,6 +69,8 @@ class Models:
             return self.starcoder
         elif model_name == "gpt-3.5-turbo":
             return self.gpt35
+        elif model_name == "gpt-3.5-turbo-16k":
+            return self.gpt3516k
         elif model_name == "gpt-4":
             return self.gpt4
         else:
@@ -174,10 +186,10 @@ class ContinueSDK(AbstractContinueSDK):
         highlighted_code = await self.ide.getHighlightedCode()
         if len(highlighted_code) == 0:
             # Get the full contents of all open files
-            files = await self.ide.getOpenFiles()
+            files = await self.sdk.ide.getOpenFiles()
             contents = {}
             for file in files:
-                contents[file] = await self.ide.readFile(file)
+                contents[file] = await self.sdk.ide.readFile(file)
             highlighted_code = [RangeInFile.from_entire_file(
                 filepath, content) for filepath, content in contents.items()]
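
Like gpt35 and gpt4, the new gpt3516k loader sits behind functools.cached_property, so the interactive API-key lookup and model construction run at most once per Models instance. A self-contained illustration of that caching semantics (the class and names here are illustrative, not from the diff):

    from functools import cached_property

    class LazyModels:
        @cached_property
        def gpt3516k(self):
            print("building model once")  # runs only on first access
            return object()               # stand-in for OpenAI(...) or ProxyServer(...)

    models = LazyModels()
    a = models.gpt3516k  # builds and caches
    b = models.gpt3516k  # served from the cache
    assert a is b
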
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 180ea5f0..17d37035 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -10,12 +10,13 @@ from pydantic import BaseModel, validator
 import tiktoken

 MAX_TOKENS_FOR_MODEL = {
-    "gpt-3.5-turbo": 4097,
-    "gpt-4": 4097,
+    "gpt-3.5-turbo": 4096,
+    "gpt-3.5-turbo-16k": 16384,
+    "gpt-4": 8192
 }
 DEFAULT_MAX_TOKENS = 2048
 CHAT_MODELS = {
-    "gpt-3.5-turbo", "gpt-4"
+    "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4"
 }
@@ -24,7 +25,7 @@ class OpenAI(LLM):
     completion_count: int = 0
     default_model: str

-    def __init__(self, api_key: str, default_model: str = "gpt-3.5-turbo", system_message: str = None):
+    def __init__(self, api_key: str, default_model: str, system_message: str = None):
         self.api_key = api_key
         self.default_model = default_model
         self.system_message = system_message
@@ -51,7 +52,7 @@ class OpenAI(LLM):
         return chat_history

     def with_system_message(self, system_message: Union[str, None]):
-        return OpenAI(api_key=self.api_key, system_message=system_message)
+        return OpenAI(api_key=self.api_key, default_model=self.default_model, system_message=system_message)

     def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         self.completion_count += 1
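
The with_system_message change fixes a silent model downgrade: the old constructor defaulted to gpt-3.5-turbo, so the copy made to carry a system message forgot whichever model the caller had configured. A sketch of the behavior change (the API key is a placeholder):

    llm = OpenAI(api_key="sk-...", default_model="gpt-4")
    llm2 = llm.with_system_message("Reply concisely.")
    # before this patch: llm2.default_model == "gpt-3.5-turbo"  (constructor default)
    # after this patch:  llm2.default_model == "gpt-4"          (explicitly propagated)
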
diff --git a/continuedev/src/continuedev/libs/llm/prompters.py b/continuedev/src/continuedev/libs/llm/prompters.py
deleted file mode 100644
index 04e9885a..00000000
--- a/continuedev/src/continuedev/libs/llm/prompters.py
+++ /dev/null
@@ -1,112 +0,0 @@
-from typing import Any, Callable, List, Tuple, Union
-from ..llm import LLM
-from .openai import OpenAI
-
-
-def cls_method_to_str(cls_name: str, init: str, method: str) -> str:
-    """Convert class and method info to formatted code"""
-    return f"""class {cls_name}:
-{init}
-{method}"""
-
-
-# Prompter classes
-class Prompter:
-    def __init__(self, llm: LLM = None):
-        if llm is None:
-            self.llm = OpenAI()
-        else:
-            self.llm = llm
-
-    def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]:
-        "Takes input and returns prompt, prefix, suffix"
-        raise NotImplementedError
-
-    def complete(self, inp: Any, **kwargs) -> str:
-        prompt, prefix, suffix = self._compile_prompt(inp)
-        resp = self.llm.complete(prompt + prefix, suffix=suffix, **kwargs)
-        return prefix + resp + (suffix or "")
-
-    def __call__(self, inp: Any, **kwargs) -> str:
-        return self.complete(inp, **kwargs)
-
-    def parallel_complete(self, inps: List[Any]) -> List[str]:
-        prompts = []
-        prefixes = []
-        suffixes = []
-        for inp in inps:
-            prompt, prefix, suffix = self._compile_prompt(inp)
-            prompts.append(prompt)
-            prefixes.append(prefix)
-            suffixes.append(suffix)
-
-        resps = self.llm.parallel_complete(
-            [prompt + prefix for prompt, prefix in zip(prompts, prefixes)], suffixes=suffixes)
-        return [prefix + resp + (suffix or "") for prefix, resp, suffix in zip(prefixes, resps, suffixes)]
-
-
-class MixedPrompter(Prompter):
-    def __init__(self, prompters: List[Prompter], router: Callable[[Any], int], llm: LLM = None):
-        super().__init__(llm=llm)
-        self.prompters = prompters
-        self.router = router
-
-    def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]:
-        prompter = self.prompters[self.router(inp)]
-        return prompter._compile_prompt(inp)
-
-    def complete(self, inp: Any, **kwargs) -> str:
-        prompter = self.prompters[self.router(inp)]
-        return prompter.complete(inp, **kwargs)
-
-
-class SimplePrompter(Prompter):
-    def __init__(self, prompt_fn: Callable[[Any], str], llm: LLM = None):
-        super().__init__(llm=llm)
-        self.prompt_fn = prompt_fn
-
-    def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]:
-        return self.prompt_fn(inp), "", None
-
-
-class FormatStringPrompter(SimplePrompter):
-    """Pass a formatted string, and the input should be a dict with the keys to format"""
-
-    def __init__(self, prompt: str, llm: LLM = None):
-        super().__init__(lambda inp: prompt.format(**inp), llm=llm)
-
-
-class BasicCommentPrompter(SimplePrompter):
-    def __init__(self, comment: str, llm: LLM = None):
-        super().__init__(lambda inp: f"""{inp}
-
-# {comment}""", llm=llm)
-
-
-class EditPrompter(Prompter):
-    def __init__(self, prompt_fn: Callable[[Any], Tuple[str, str]], llm: LLM = None):
-        super().__init__(llm=llm)
-        self.prompt_fn = prompt_fn
-
-    def complete(self, inp: str, **kwargs) -> str:
-        inp, instruction = self.prompt_fn(inp)
-        return self.llm.edit(inp, instruction, **kwargs)
-
-    def parallel_complete(self, inps: List[Any]) -> List[str]:
-        prompts = []
-        instructions = []
-        for inp in inps:
-            prompt, instruction = self.prompt_fn(inp)
-            prompts.append(prompt)
-            instructions.append(instruction)
-
-        return self.llm.parallel_edit(prompts, instructions)
-
-
-class InsertPrompter(Prompter):
-    def __init__(self, prompt_fn: Callable[[Any], Tuple[str, str, str]], llm: LLM = None):
-        super().__init__(llm=llm)
-        self.prompt_fn = prompt_fn
-
-    def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]:
-        return self.prompt_fn(inp)
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index aee5bc1d..ee3ef9a7 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -174,16 +174,67 @@ class DefaultModelEditCodeStep(Step):
         for rif in rif_with_contents:
             await sdk.ide.setFileOpen(rif.filepath)
+            # May be swapped for the 16k variant below if the file is too large
+            model_to_use = sdk.models.default
+
             full_file_contents = await sdk.ide.readFile(rif.filepath)
+
+            full_file_contents_lst = full_file_contents.split("\n")
+
+            # Cutting must never eat into the highlighted range itself
+            max_start_line = rif.range.start.line
+            min_end_line = rif.range.end.line
+            cur_start_line = 0
+            cur_end_line = len(full_file_contents_lst)
+
+            if sdk.config.default_model == "gpt-4":
+
+                total_tokens = sdk.models.gpt4.count_tokens(full_file_contents)
+                if total_tokens > sdk.models.gpt4.max_tokens:
+                    # Cut lines from below the highlighted range first
+                    while cur_end_line > min_end_line:
+                        cur_end_line -= 1
+                        total_tokens -= sdk.models.gpt4.count_tokens(
+                            full_file_contents_lst[cur_end_line])
+                        if total_tokens < sdk.models.gpt4.max_tokens:
+                            break
+
+                if total_tokens > sdk.models.gpt4.max_tokens:
+                    # Still over budget: cut lines from above the range
+                    while cur_start_line < max_start_line:
+                        total_tokens -= sdk.models.gpt4.count_tokens(
+                            full_file_contents_lst[cur_start_line])
+                        cur_start_line += 1
+                        if total_tokens < sdk.models.gpt4.max_tokens:
+                            break
+
+            elif sdk.config.default_model == "gpt-3.5-turbo" or sdk.config.default_model == "gpt-3.5-turbo-16k":
+
+                if sdk.models.gpt35.count_tokens(full_file_contents) > sdk.models.gpt35.max_tokens:
+
+                    # Too large for the 4k window: upgrade to the 16k variant
+                    model_to_use = sdk.models.gpt3516k
+
+                    total_tokens = sdk.models.gpt3516k.count_tokens(
+                        full_file_contents)
+                    if total_tokens > sdk.models.gpt3516k.max_tokens:
+                        while cur_end_line > min_end_line:
+                            cur_end_line -= 1
+                            total_tokens -= sdk.models.gpt3516k.count_tokens(
+                                full_file_contents_lst[cur_end_line])
+                            if total_tokens < sdk.models.gpt3516k.max_tokens:
+                                break
+
+                    if total_tokens > sdk.models.gpt3516k.max_tokens:
+                        while cur_start_line < max_start_line:
+                            total_tokens -= sdk.models.gpt3516k.count_tokens(
+                                full_file_contents_lst[cur_start_line])
+                            cur_start_line += 1
+                            if total_tokens < sdk.models.gpt3516k.max_tokens:
+                                break
+            else:
+                raise Exception("Unknown default model")
+
             start_index, end_index = rif.range.indices_in_string(
                 full_file_contents)
+
             segs = [full_file_contents[:start_index],
-                full_file_contents[end_index:]]
+                    full_file_contents[end_index:]]
             prompt = self._prompt.format(
                 code=rif.contents, user_request=self.user_input, file_prefix=segs[0], file_suffix=segs[1])
-            completion = str(sdk.models.default.complete(prompt, with_history=await sdk.get_chat_context()))
+            completion = str(model_to_use.complete(prompt, with_history=await sdk.get_chat_context()))
             eot_token = "<|endoftext|>"
             completion = completion.removesuffix(eot_token)
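
Combined with the MAX_TOKENS_FOR_MODEL table in openai.py above, the model-selection rule in this step reduces to a small predicate. A hypothetical distillation (the function name is ours, the budgets are from the diff):

    MAX_TOKENS_FOR_MODEL = {"gpt-3.5-turbo": 4096,
                            "gpt-3.5-turbo-16k": 16384,
                            "gpt-4": 8192}

    def pick_model(total_tokens: int, configured: str) -> str:
        # A gpt-3.5 prompt that overflows the 4k window is retried on the 16k
        # variant; gpt-4 keeps its own window and relies on line cutting alone.
        if configured.startswith("gpt-3.5") and total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
            return "gpt-3.5-turbo-16k"
        return configured

    assert pick_model(6000, "gpt-3.5-turbo") == "gpt-3.5-turbo-16k"
    assert pick_model(6000, "gpt-4") == "gpt-4"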