Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/src/continuedev/core/config.py         |   4
-rw-r--r--  continuedev/src/continuedev/core/sdk.py            |  16
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py     |  11
-rw-r--r--  continuedev/src/continuedev/libs/llm/prompters.py  | 112
-rw-r--r--  continuedev/src/continuedev/steps/core/core.py     |  55
5 files changed, 75 insertions, 123 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 23be8133..6a811412 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -19,8 +19,8 @@ class ContinueConfig(BaseModel):
     steps_on_startup: Optional[Dict[str, Dict]] = {}
     server_url: Optional[str] = None
     allow_anonymous_telemetry: Optional[bool] = True
-    default_model: Literal["gpt-3.5-turbo",
-                           "gpt-4", "starcoder"] = 'gpt-4'
+    default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
+                           "gpt-4"] = 'gpt-4'
     slash_commands: Optional[List[SlashCommand]] = [
         # SlashCommand(
         #     name="pytest",
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 76f72d01..3e3c3bc5 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -45,6 +45,16 @@ class Models:
         return asyncio.get_event_loop().run_until_complete(load_gpt35())
 
     @cached_property
+    def gpt3516k(self):
+        async def load_gpt3516k():
+            api_key = await self.sdk.get_user_secret(
+                'OPENAI_API_KEY', 'Enter your OpenAI API key, OR press enter to try for free')
+            if api_key == "":
+                return ProxyServer(self.sdk.ide.unique_id, "gpt-3.5-turbo-16k")
+            return OpenAI(api_key=api_key, default_model="gpt-3.5-turbo-16k")
+        return asyncio.get_event_loop().run_until_complete(load_gpt3516k())
+
+    @cached_property
     def gpt4(self):
         async def load_gpt4():
             api_key = await self.sdk.get_user_secret(
@@ -59,6 +69,8 @@ class Models:
             return self.starcoder
         elif model_name == "gpt-3.5-turbo":
             return self.gpt35
+        elif model_name == "gpt-3.5-turbo-16k":
+            return self.gpt3516k
         elif model_name == "gpt-4":
             return self.gpt4
         else:
@@ -174,10 +186,10 @@ class ContinueSDK(AbstractContinueSDK):
         highlighted_code = await self.ide.getHighlightedCode()
         if len(highlighted_code) == 0:
             # Get the full contents of all open files
-            files = await self.ide.getOpenFiles()
+            files = await self.sdk.ide.getOpenFiles()
             contents = {}
             for file in files:
-                contents[file] = await self.ide.readFile(file)
+                contents[file] = await self.sdk.ide.readFile(file)
 
             highlighted_code = [RangeInFile.from_entire_file(
                 filepath, content) for filepath, content in contents.items()]
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 180ea5f0..17d37035 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -10,12 +10,13 @@ from pydantic import BaseModel, validator
 import tiktoken
 
 MAX_TOKENS_FOR_MODEL = {
-    "gpt-3.5-turbo": 4097,
-    "gpt-4": 4097,
+    "gpt-3.5-turbo": 4096,
+    "gpt-3.5-turbo-16k": 16384,
+    "gpt-4": 8192
 }
 DEFAULT_MAX_TOKENS = 2048
 CHAT_MODELS = {
-    "gpt-3.5-turbo", "gpt-4"
+    "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4"
 }
 
 
@@ -24,7 +25,7 @@ class OpenAI(LLM):
     completion_count: int = 0
     default_model: str
 
-    def __init__(self, api_key: str, default_model: str = "gpt-3.5-turbo", system_message: str = None):
+    def __init__(self, api_key: str, default_model: str, system_message: str = None):
         self.api_key = api_key
         self.default_model = default_model
         self.system_message = system_message
@@ -51,7 +52,7 @@ class OpenAI(LLM):
         return chat_history
 
     def with_system_message(self, system_message: Union[str, None]):
-        return OpenAI(api_key=self.api_key, system_message=system_message)
+        return OpenAI(api_key=self.api_key, default_model=self.default_model, system_message=system_message)
 
     def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         self.completion_count += 1
diff --git a/continuedev/src/continuedev/libs/llm/prompters.py b/continuedev/src/continuedev/libs/llm/prompters.py
deleted file mode 100644
index 04e9885a..00000000
--- a/continuedev/src/continuedev/libs/llm/prompters.py
+++ /dev/null
@@ -1,112 +0,0 @@
-from typing import Any, Callable, List, Tuple, Union
-from ..llm import LLM
-from .openai import OpenAI
-
-
-def cls_method_to_str(cls_name: str, init: str, method: str) -> str:
-    """Convert class and method info to formatted code"""
-    return f"""class {cls_name}:
-{init}
-{method}"""
-
-
-# Prompter classes
-class Prompter:
-    def __init__(self, llm: LLM = None):
-        if llm is None:
-            self.llm = OpenAI()
-        else:
-            self.llm = llm
-
-    def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]:
-        "Takes input and returns prompt, prefix, suffix"
-        raise NotImplementedError
-
-    def complete(self, inp: Any, **kwargs) -> str:
-        prompt, prefix, suffix = self._compile_prompt(inp)
-        resp = self.llm.complete(prompt + prefix, suffix=suffix, **kwargs)
-        return prefix + resp + (suffix or "")
-
-    def __call__(self, inp: Any, **kwargs) -> str:
-        return self.complete(inp, **kwargs)
-
-    def parallel_complete(self, inps: List[Any]) -> List[str]:
-        prompts = []
-        prefixes = []
-        suffixes = []
-        for inp in inps:
-            prompt, prefix, suffix = self._compile_prompt(inp)
-            prompts.append(prompt)
-            prefixes.append(prefix)
-            suffixes.append(suffix)
-
-        resps = self.llm.parallel_complete(
-            [prompt + prefix for prompt, prefix in zip(prompts, prefixes)], suffixes=suffixes)
-        return [prefix + resp + (suffix or "") for prefix, resp, suffix in zip(prefixes, resps, suffixes)]
-
-
-class MixedPrompter(Prompter):
-    def __init__(self, prompters: List[Prompter], router: Callable[[Any], int], llm: LLM = None):
-        super().__init__(llm=llm)
-        self.prompters = prompters
-        self.router = router
-
-    def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]:
-        prompter = self.prompters[self.router(inp)]
-        return prompter._compile_prompt(inp)
-
-    def complete(self, inp: Any, **kwargs) -> str:
-        prompter = self.prompters[self.router(inp)]
-        return prompter.complete(inp, **kwargs)
-
-
-class SimplePrompter(Prompter):
-    def __init__(self, prompt_fn: Callable[[Any], str], llm: LLM = None):
-        super().__init__(llm=llm)
-        self.prompt_fn = prompt_fn
-
-    def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]:
-        return self.prompt_fn(inp), "", None
-
-
-class FormatStringPrompter(SimplePrompter):
-    """Pass a formatted string, and the input should be a dict with the keys to format"""
-
-    def __init__(self, prompt: str, llm: LLM = None):
-        super().__init__(lambda inp: prompt.format(**inp), llm=llm)
-
-
-class BasicCommentPrompter(SimplePrompter):
-    def __init__(self, comment: str, llm: LLM = None):
-        super().__init__(lambda inp: f"""{inp}
-
-# {comment}""", llm=llm)
-
-
-class EditPrompter(Prompter):
-    def __init__(self, prompt_fn: Callable[[Any], Tuple[str, str]], llm: LLM = None):
-        super().__init__(llm=llm)
-        self.prompt_fn = prompt_fn
-
-    def complete(self, inp: str, **kwargs) -> str:
-        inp, instruction = self.prompt_fn(inp)
-        return self.llm.edit(inp, instruction, **kwargs)
-
-    def parallel_complete(self, inps: List[Any]) -> List[str]:
-        prompts = []
-        instructions = []
-        for inp in inps:
-            prompt, instruction = self.prompt_fn(inp)
-            prompts.append(prompt)
-            instructions.append(instruction)
-
-        return self.llm.parallel_edit(prompts, instructions)
-
-
-class InsertPrompter(Prompter):
-    def __init__(self, prompt_fn: Callable[[Any], Tuple[str, str, str]], llm: LLM = None):
-        super().__init__(llm=llm)
-        self.prompt_fn = prompt_fn
-
-    def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]:
-        return self.prompt_fn(inp)
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index aee5bc1d..ee3ef9a7 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -174,16 +174,67 @@ class DefaultModelEditCodeStep(Step):
         for rif in rif_with_contents:
             await sdk.ide.setFileOpen(rif.filepath)
 
+            model_to_use = sdk.config.default_model
+
             full_file_contents = await sdk.ide.readFile(rif.filepath)
+
+            full_file_contents_lst = full_file_contents.split("\n")
+
+            max_start_line = rif.range.start.line
+            min_end_line = rif.range.end.line
+            cur_start_line = 0
+            cur_end_line = len(full_file_contents_lst)
+
+            if sdk.config.default_model == "gpt-4":
+
+                total_tokens = sdk.models.gpt4.count_tokens(full_file_contents)
+                if total_tokens > sdk.models.gpt4.max_tokens:
+                    while cur_end_line > min_end_line:
+                        total_tokens -= len(full_file_contents_lst[cur_end_line])
+                        cur_end_line -= 1
+                        if total_tokens < sdk.models.gpt4.max_tokens:
+                            break
+
+                if total_tokens > sdk.models.gpt4.max_tokens:
+                    while cur_start_line < max_start_line:
+                        cur_start_line += 1
+                        total_tokens -= len(full_file_contents_lst[cur_start_line])
+                        if total_tokens < sdk.models.gpt4.max_tokens:
+                            break
+
+            elif sdk.config.default_model == "gpt-3.5-turbo" or sdk.config.default_model == "gpt-3.5-turbo-16k":
+
+                if sdk.models.gpt35.count_tokens(full_file_contents) > sdk.models.gpt35.max_tokens:
+
+                    model_to_use = "gpt-3.5-turbo-16k"
+
+                    total_tokens = sdk.models.gpt3516k.count_tokens(full_file_contents)
+                    if total_tokens > sdk.models.gpt3516k.max_tokens:
+                        while cur_end_line > min_end_line:
+                            total_tokens -= len(full_file_contents_lst[cur_end_line])
+                            cur_end_line -= 1
+                            if total_tokens < sdk.models.gpt4.max_tokens:
+                                break
+
+                    if total_tokens > sdk.models.gpt3516k.max_tokens:
+                        while cur_start_line < max_start_line:
+                            total_tokens -= len(full_file_contents_lst[cur_start_line])
+                            cur_start_line += 1
+                            if total_tokens < sdk.models.gpt4.max_tokens:
+                                break
+            else:
+                raise Exception("Unknown default model")
+
             start_index, end_index = rif.range.indices_in_string(
                 full_file_contents)
+
             segs = [full_file_contents[:start_index],
-                    full_file_contents[end_index:]]
+                    full_file_contents[end_index:]]
 
             prompt = self._prompt.format(
                 code=rif.contents, user_request=self.user_input, file_prefix=segs[0], file_suffix=segs[1])
 
-            completion = str(sdk.models.default.complete(prompt, with_history=await sdk.get_chat_context()))
+            completion = str(model_to_use.complete(prompt, with_history=await sdk.get_chat_context()))
 
             eot_token = "<|endoftext|>"
             completion = completion.removesuffix(eot_token)
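
The block added to DefaultModelEditCodeStep above drops whole lines from the end of the file, then from the start, until the token count fits the chosen model, while keeping the highlighted range intact. Below is a minimal standalone sketch of that trimming idea, assuming a hypothetical count_tokens helper in place of the SDK's per-model counters; it is an illustration of the strategy, not the code from the commit.

from typing import List, Tuple


def count_tokens(text: str) -> int:
    # Hypothetical stand-in for sdk.models.<model>.count_tokens();
    # a real implementation would use tiktoken for the chosen model.
    return len(text) // 4


def trim_context(lines: List[str], range_start: int, range_end: int,
                 max_tokens: int) -> Tuple[int, int]:
    """Return (start, end) line indices of a window that fits max_tokens.

    Lines outside the highlighted range [range_start, range_end] are
    dropped from the end first, then from the start, mirroring the
    loops added to DefaultModelEditCodeStep above.
    """
    cur_start, cur_end = 0, len(lines)
    total_tokens = count_tokens("\n".join(lines))

    # Drop trailing lines that sit below the highlighted range.
    while total_tokens > max_tokens and cur_end - 1 > range_end:
        cur_end -= 1
        total_tokens -= count_tokens(lines[cur_end])

    # Then drop leading lines that sit above the highlighted range.
    while total_tokens > max_tokens and cur_start < range_start:
        total_tokens -= count_tokens(lines[cur_start])
        cur_start += 1

    return cur_start, cur_end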