diff options
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/core/config.py | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 16 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 11 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/prompters.py | 112 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/core/core.py | 55 | 
5 files changed, 75 insertions, 123 deletions
| diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 23be8133..6a811412 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -19,8 +19,8 @@ class ContinueConfig(BaseModel):      steps_on_startup: Optional[Dict[str, Dict]] = {}      server_url: Optional[str] = None      allow_anonymous_telemetry: Optional[bool] = True -    default_model: Literal["gpt-3.5-turbo", -                           "gpt-4", "starcoder"] = 'gpt-4' +    default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k", +                           "gpt-4"] = 'gpt-4'      slash_commands: Optional[List[SlashCommand]] = [          # SlashCommand(          #     name="pytest", diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 76f72d01..3e3c3bc5 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -45,6 +45,16 @@ class Models:          return asyncio.get_event_loop().run_until_complete(load_gpt35())      @cached_property +    def gpt3516k(self): +        async def load_gpt3516k(): +            api_key = await self.sdk.get_user_secret( +                'OPENAI_API_KEY', 'Enter your OpenAI API key, OR press enter to try for free') +            if api_key == "": +                return ProxyServer(self.sdk.ide.unique_id, "gpt-3.5-turbo-16k") +            return OpenAI(api_key=api_key, default_model="gpt-3.5-turbo-16k") +        return asyncio.get_event_loop().run_until_complete(load_gpt3516k()) + +    @cached_property      def gpt4(self):          async def load_gpt4():              api_key = await self.sdk.get_user_secret( @@ -59,6 +69,8 @@ class Models:              return self.starcoder          elif model_name == "gpt-3.5-turbo":              return self.gpt35 +        elif model_name == "gpt-3.5-turbo-16k": +            return self.gpt3516k          elif model_name == "gpt-4":              return self.gpt4          else: @@ -174,10 +186,10 @@ class ContinueSDK(AbstractContinueSDK):          highlighted_code = await self.ide.getHighlightedCode()          if len(highlighted_code) == 0:              # Get the full contents of all open files -            files = await self.ide.getOpenFiles() +            files = await self.sdk.ide.getOpenFiles()              contents = {}              for file in files: -                contents[file] = await self.ide.readFile(file) +                contents[file] = await self.sdk.ide.readFile(file)              highlighted_code = [RangeInFile.from_entire_file(                  filepath, content) for filepath, content in contents.items()] diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 180ea5f0..17d37035 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -10,12 +10,13 @@ from pydantic import BaseModel, validator  import tiktoken  MAX_TOKENS_FOR_MODEL = { -    "gpt-3.5-turbo": 4097, -    "gpt-4": 4097, +    "gpt-3.5-turbo": 4096, +    "gpt-3.5-turbo-16k": 16384, +    "gpt-4": 8192  }  DEFAULT_MAX_TOKENS = 2048  CHAT_MODELS = { -    "gpt-3.5-turbo", "gpt-4" +    "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4"  } @@ -24,7 +25,7 @@ class OpenAI(LLM):      completion_count: int = 0      default_model: str -    def __init__(self, api_key: str, default_model: str = "gpt-3.5-turbo", system_message: str = None): +    def __init__(self, api_key: str, default_model: str, system_message: str = None):          self.api_key = api_key          self.default_model = default_model          self.system_message = system_message @@ -51,7 +52,7 @@ class OpenAI(LLM):          return chat_history      def with_system_message(self, system_message: Union[str, None]): -        return OpenAI(api_key=self.api_key, system_message=system_message) +        return OpenAI(api_key=self.api_key, default_model=self.default_model, system_message=system_message)      def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:          self.completion_count += 1 diff --git a/continuedev/src/continuedev/libs/llm/prompters.py b/continuedev/src/continuedev/libs/llm/prompters.py deleted file mode 100644 index 04e9885a..00000000 --- a/continuedev/src/continuedev/libs/llm/prompters.py +++ /dev/null @@ -1,112 +0,0 @@ -from typing import Any, Callable, List, Tuple, Union -from ..llm import LLM -from .openai import OpenAI - - -def cls_method_to_str(cls_name: str, init: str, method: str) -> str: -    """Convert class and method info to formatted code""" -    return f"""class {cls_name}: -{init} -{method}""" - - -# Prompter classes -class Prompter: -    def __init__(self, llm: LLM = None): -        if llm is None: -            self.llm = OpenAI() -        else: -            self.llm = llm - -    def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]: -        "Takes input and returns prompt, prefix, suffix" -        raise NotImplementedError - -    def complete(self, inp: Any, **kwargs) -> str: -        prompt, prefix, suffix = self._compile_prompt(inp) -        resp = self.llm.complete(prompt + prefix, suffix=suffix, **kwargs) -        return prefix + resp + (suffix or "") - -    def __call__(self, inp: Any, **kwargs) -> str: -        return self.complete(inp, **kwargs) - -    def parallel_complete(self, inps: List[Any]) -> List[str]: -        prompts = [] -        prefixes = [] -        suffixes = [] -        for inp in inps: -            prompt, prefix, suffix = self._compile_prompt(inp) -            prompts.append(prompt) -            prefixes.append(prefix) -            suffixes.append(suffix) - -        resps = self.llm.parallel_complete( -            [prompt + prefix for prompt, prefix in zip(prompts, prefixes)], suffixes=suffixes) -        return [prefix + resp + (suffix or "") for prefix, resp, suffix in zip(prefixes, resps, suffixes)] - - -class MixedPrompter(Prompter): -    def __init__(self, prompters: List[Prompter], router: Callable[[Any], int], llm: LLM = None): -        super().__init__(llm=llm) -        self.prompters = prompters -        self.router = router - -    def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]: -        prompter = self.prompters[self.router(inp)] -        return prompter._compile_prompt(inp) - -    def complete(self, inp: Any, **kwargs) -> str: -        prompter = self.prompters[self.router(inp)] -        return prompter.complete(inp, **kwargs) - - -class SimplePrompter(Prompter): -    def __init__(self, prompt_fn: Callable[[Any], str], llm: LLM = None): -        super().__init__(llm=llm) -        self.prompt_fn = prompt_fn - -    def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]: -        return self.prompt_fn(inp), "", None - - -class FormatStringPrompter(SimplePrompter): -    """Pass a formatted string, and the input should be a dict with the keys to format""" - -    def __init__(self, prompt: str, llm: LLM = None): -        super().__init__(lambda inp: prompt.format(**inp), llm=llm) - - -class BasicCommentPrompter(SimplePrompter): -    def __init__(self, comment: str, llm: LLM = None): -        super().__init__(lambda inp: f"""{inp} - -# {comment}""", llm=llm) - - -class EditPrompter(Prompter): -    def __init__(self, prompt_fn: Callable[[Any], Tuple[str, str]], llm: LLM = None): -        super().__init__(llm=llm) -        self.prompt_fn = prompt_fn - -    def complete(self, inp: str, **kwargs) -> str: -        inp, instruction = self.prompt_fn(inp) -        return self.llm.edit(inp, instruction, **kwargs) - -    def parallel_complete(self, inps: List[Any]) -> List[str]: -        prompts = [] -        instructions = [] -        for inp in inps: -            prompt, instruction = self.prompt_fn(inp) -            prompts.append(prompt) -            instructions.append(instruction) - -        return self.llm.parallel_edit(prompts, instructions) - - -class InsertPrompter(Prompter): -    def __init__(self, prompt_fn: Callable[[Any], Tuple[str, str, str]], llm: LLM = None): -        super().__init__(llm=llm) -        self.prompt_fn = prompt_fn - -    def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]: -        return self.prompt_fn(inp) diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py index aee5bc1d..ee3ef9a7 100644 --- a/continuedev/src/continuedev/steps/core/core.py +++ b/continuedev/src/continuedev/steps/core/core.py @@ -174,16 +174,67 @@ class DefaultModelEditCodeStep(Step):          for rif in rif_with_contents:              await sdk.ide.setFileOpen(rif.filepath) +            model_to_use = sdk.config.default_model +              full_file_contents = await sdk.ide.readFile(rif.filepath) + +            full_file_contents_lst = full_file_contents.split("\n") + +            max_start_line = rif.range.start.line +            min_end_line = rif.range.end.line +            cur_start_line = 0 +            cur_end_line = len(full_file_contents_lst) + +            if sdk.config.default_model == "gpt-4": + +                total_tokens = sdk.models.gpt4.count_tokens(full_file_contents) +                if total_tokens > sdk.models.gpt4.max_tokens: +                    while cur_end_line > min_end_line: +                        total_tokens -= len(full_file_contents_lst[cur_end_line]) +                        cur_end_line -= 1 +                        if total_tokens < sdk.models.gpt4.max_tokens: +                            break +                     +                    if total_tokens > sdk.models.gpt4.max_tokens: +                        while cur_start_line < max_start_line: +                            cur_start_line += 1 +                            total_tokens -= len(full_file_contents_lst[cur_start_line]) +                            if total_tokens < sdk.models.gpt4.max_tokens: +                                break + +            elif sdk.config.default_model == "gpt-3.5-turbo" or sdk.config.default_model == "gpt-3.5-turbo-16k": + +                if sdk.models.gpt35.count_tokens(full_file_contents) > sdk.models.gpt35.max_tokens: + +                    model_to_use = "gpt-3.5-turbo-16k" + +                    total_tokens = sdk.models.gpt3516k.count_tokens(full_file_contents) +                    if total_tokens > sdk.models.gpt3516k.max_tokens: +                        while cur_end_line > min_end_line: +                            total_tokens -= len(full_file_contents_lst[cur_end_line]) +                            cur_end_line -= 1 +                            if total_tokens < sdk.models.gpt4.max_tokens: +                                break +                         +                        if total_tokens > sdk.models.gpt3516k.max_tokens: +                            while cur_start_line < max_start_line: +                                total_tokens -= len(full_file_contents_lst[cur_start_line]) +                                cur_start_line += 1 +                                if total_tokens < sdk.models.gpt4.max_tokens: +                                    break +            else: +                raise Exception("Unknown default model") +              start_index, end_index = rif.range.indices_in_string(                  full_file_contents) +                          segs = [full_file_contents[:start_index], -                    full_file_contents[end_index:]] +                full_file_contents[end_index:]]              prompt = self._prompt.format(                  code=rif.contents, user_request=self.user_input, file_prefix=segs[0], file_suffix=segs[1]) -            completion = str(sdk.models.default.complete(prompt, with_history=await sdk.get_chat_context())) +            completion = str(model_to_use.complete(prompt, with_history=await sdk.get_chat_context()))              eot_token = "<|endoftext|>"              completion = completion.removesuffix(eot_token) | 
