Diffstat (limited to 'continuedev/src')
-rw-r--r--   continuedev/src/continuedev/core/config.py          |   4
-rw-r--r--   continuedev/src/continuedev/core/sdk.py             |  12
-rw-r--r--   continuedev/src/continuedev/libs/llm/openai.py      |  17
-rw-r--r--   continuedev/src/continuedev/libs/llm/prompters.py   | 112
-rw-r--r--   continuedev/src/continuedev/steps/core/core.py      |  57
5 files changed, 77 insertions, 125 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 1ee3a7f8..652320fb 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -25,8 +25,8 @@ class ContinueConfig(BaseModel):
     disallowed_steps: Optional[List[str]] = []
     server_url: Optional[str] = None
     allow_anonymous_telemetry: Optional[bool] = True
-    default_model: Literal["gpt-3.5-turbo",
-                           "gpt-4", "starcoder"] = 'gpt-4'
+    default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
+                           "gpt-4"] = 'gpt-4'
     slash_commands: Optional[List[SlashCommand]] = [
         # SlashCommand(
         #     name="pytest",
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 4ab2f027..d6acc404 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -45,6 +45,16 @@ class Models:
         return asyncio.get_event_loop().run_until_complete(load_gpt35())
 
     @cached_property
+    def gpt3516k(self):
+        async def load_gpt3516k():
+            api_key = await self.sdk.get_user_secret(
+                'OPENAI_API_KEY', 'Enter your OpenAI API key, OR press enter to try for free')
+            if api_key == "":
+                return ProxyServer(self.sdk.ide.unique_id, "gpt-3.5-turbo-16k")
+            return OpenAI(api_key=api_key, default_model="gpt-3.5-turbo-16k")
+        return asyncio.get_event_loop().run_until_complete(load_gpt3516k())
+
+    @cached_property
     def gpt4(self):
         async def load_gpt4():
             api_key = await self.sdk.get_user_secret(
@@ -59,6 +69,8 @@ class Models:
             return self.starcoder
         elif model_name == "gpt-3.5-turbo":
             return self.gpt35
+        elif model_name == "gpt-3.5-turbo-16k":
+            return self.gpt3516k
         elif model_name == "gpt-4":
             return self.gpt4
         else:
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 345af7b4..22c28b20 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -9,13 +9,14 @@ from ..llm import LLM
 from pydantic import BaseModel, validator
 import tiktoken
 
+DEFAULT_MAX_TOKENS = 2048
 MAX_TOKENS_FOR_MODEL = {
-    "gpt-3.5-turbo": 4097,
-    "gpt-4": 4097,
+    "gpt-3.5-turbo": 4096 - DEFAULT_MAX_TOKENS,
+    "gpt-3.5-turbo-16k": 16384 - DEFAULT_MAX_TOKENS,
+    "gpt-4": 8192 - DEFAULT_MAX_TOKENS
 }
-DEFAULT_MAX_TOKENS = 2048
 CHAT_MODELS = {
-    "gpt-3.5-turbo", "gpt-4"
+    "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4"
 }
 
@@ -24,7 +25,7 @@ class OpenAI(LLM):
     completion_count: int = 0
     default_model: str
 
-    def __init__(self, api_key: str, default_model: str = "gpt-3.5-turbo", system_message: str = None):
+    def __init__(self, api_key: str, default_model: str, system_message: str = None):
         self.api_key = api_key
         self.default_model = default_model
         self.system_message = system_message
@@ -32,6 +33,10 @@ class OpenAI(LLM):
         openai.api_key = api_key
 
     @cached_property
+    def name(self):
+        return self.default_model
+
+    @cached_property
     def __encoding_for_model(self):
         aliases = {
             "gpt-3.5-turbo": "gpt3"
@@ -76,7 +81,7 @@ class OpenAI(LLM):
         return chat_history
 
     def with_system_message(self, system_message: Union[str, None]):
-        return OpenAI(api_key=self.api_key, system_message=system_message)
+        return OpenAI(api_key=self.api_key, default_model=self.default_model, system_message=system_message)
 
     def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         self.completion_count += 1
diff --git a/continuedev/src/continuedev/libs/llm/prompters.py b/continuedev/src/continuedev/libs/llm/prompters.py
deleted file mode 100644
index 04e9885a..00000000
--- a/continuedev/src/continuedev/libs/llm/prompters.py
+++ /dev/null
@@ -1,112 +0,0 @@
-from typing import Any, Callable, List, Tuple, Union
-from ..llm import LLM
-from .openai import OpenAI
-
-
-def cls_method_to_str(cls_name: str, init: str, method: str) -> str:
-    """Convert class and method info to formatted code"""
-    return f"""class {cls_name}:
-{init}
-{method}"""
-
-
-# Prompter classes
-class Prompter:
-    def __init__(self, llm: LLM = None):
-        if llm is None:
-            self.llm = OpenAI()
-        else:
-            self.llm = llm
-
-    def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]:
-        "Takes input and returns prompt, prefix, suffix"
-        raise NotImplementedError
-
-    def complete(self, inp: Any, **kwargs) -> str:
-        prompt, prefix, suffix = self._compile_prompt(inp)
-        resp = self.llm.complete(prompt + prefix, suffix=suffix, **kwargs)
-        return prefix + resp + (suffix or "")
-
-    def __call__(self, inp: Any, **kwargs) -> str:
-        return self.complete(inp, **kwargs)
-
-    def parallel_complete(self, inps: List[Any]) -> List[str]:
-        prompts = []
-        prefixes = []
-        suffixes = []
-        for inp in inps:
-            prompt, prefix, suffix = self._compile_prompt(inp)
-            prompts.append(prompt)
-            prefixes.append(prefix)
-            suffixes.append(suffix)
-
-        resps = self.llm.parallel_complete(
-            [prompt + prefix for prompt, prefix in zip(prompts, prefixes)], suffixes=suffixes)
-        return [prefix + resp + (suffix or "") for prefix, resp, suffix in zip(prefixes, resps, suffixes)]
-
-
-class MixedPrompter(Prompter):
-    def __init__(self, prompters: List[Prompter], router: Callable[[Any], int], llm: LLM = None):
-        super().__init__(llm=llm)
-        self.prompters = prompters
-        self.router = router
-
-    def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]:
-        prompter = self.prompters[self.router(inp)]
-        return prompter._compile_prompt(inp)
-
-    def complete(self, inp: Any, **kwargs) -> str:
-        prompter = self.prompters[self.router(inp)]
-        return prompter.complete(inp, **kwargs)
-
-
-class SimplePrompter(Prompter):
-    def __init__(self, prompt_fn: Callable[[Any], str], llm: LLM = None):
-        super().__init__(llm=llm)
-        self.prompt_fn = prompt_fn
-
-    def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]:
-        return self.prompt_fn(inp), "", None
-
-
-class FormatStringPrompter(SimplePrompter):
-    """Pass a formatted string, and the input should be a dict with the keys to format"""
-
-    def __init__(self, prompt: str, llm: LLM = None):
-        super().__init__(lambda inp: prompt.format(**inp), llm=llm)
-
-
-class BasicCommentPrompter(SimplePrompter):
-    def __init__(self, comment: str, llm: LLM = None):
-        super().__init__(lambda inp: f"""{inp}
-
-# {comment}""", llm=llm)
-
-
-class EditPrompter(Prompter):
-    def __init__(self, prompt_fn: Callable[[Any], Tuple[str, str]], llm: LLM = None):
-        super().__init__(llm=llm)
-        self.prompt_fn = prompt_fn
-
-    def complete(self, inp: str, **kwargs) -> str:
-        inp, instruction = self.prompt_fn(inp)
-        return self.llm.edit(inp, instruction, **kwargs)
-
-    def parallel_complete(self, inps: List[Any]) -> List[str]:
-        prompts = []
-        instructions = []
-        for inp in inps:
-            prompt, instruction = self.prompt_fn(inp)
-            prompts.append(prompt)
-            instructions.append(instruction)
-
-        return self.llm.parallel_edit(prompts, instructions)
-
-
-class InsertPrompter(Prompter):
-    def __init__(self, prompt_fn: Callable[[Any], Tuple[str, str, str]], llm: LLM = None):
-        super().__init__(llm=llm)
-        self.prompt_fn = prompt_fn
-
-    def _compile_prompt(self, inp: Any) -> Tuple[str, str, Union[str, None]]:
-        return self.prompt_fn(inp)
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 330f60ed..417398b7 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -10,6 +10,7 @@ from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullConten
 from ...models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
 from ...core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
 from ...core.main import Step, SequentialStep
+from ...libs.llm.openai import MAX_TOKENS_FOR_MODEL
 
 import difflib
 
@@ -172,16 +173,62 @@ class DefaultModelEditCodeStep(Step):
         for rif in rif_with_contents:
             await sdk.ide.setFileOpen(rif.filepath)
 
+            model_to_use = sdk.models.default
+
             full_file_contents = await sdk.ide.readFile(rif.filepath)
-            start_index, end_index = rif.range.indices_in_string(
-                full_file_contents)
-            segs = [full_file_contents[:start_index],
-                    full_file_contents[end_index:]]
+
+            full_file_contents_lst = full_file_contents.split("\n")
+
+            max_start_line = rif.range.start.line
+            min_end_line = rif.range.end.line
+            cur_start_line = 0
+            cur_end_line = len(full_file_contents_lst) - 1
+
+            def cut_context(model_to_use, total_tokens, cur_start_line, cur_end_line):
+
+                if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+                    while cur_end_line > min_end_line:
+                        total_tokens -= model_to_use.count_tokens(full_file_contents_lst[cur_end_line])
+                        cur_end_line -= 1
+                        if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+                            return cur_start_line, cur_end_line
+
+                    if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+                        while cur_start_line < max_start_line:
+                            cur_start_line += 1
+                            total_tokens -= model_to_use.count_tokens(full_file_contents_lst[cur_end_line])
+                            if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+                                return cur_start_line, cur_end_line
+
+                return cur_start_line, cur_end_line
+
+            if model_to_use.name == "gpt-4":
+
+                total_tokens = model_to_use.count_tokens(full_file_contents)
+                cur_start_line, cur_end_line = cut_context(model_to_use, total_tokens, cur_start_line, cur_end_line)
+
+            elif model_to_use.name == "gpt-3.5-turbo" or model_to_use.name == "gpt-3.5-turbo-16k":
+
+                if sdk.models.gpt35.count_tokens(full_file_contents) > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
+
+                    model_to_use = sdk.models.gpt3516k
+                    total_tokens = model_to_use.count_tokens(full_file_contents)
+                    cur_start_line, cur_end_line = cut_context(model_to_use, total_tokens, cur_start_line, cur_end_line)
+
+            else:
+
+                raise Exception("Unknown default model")
+
+            code_before = "".join(full_file_contents_lst[cur_start_line:max_start_line])
+            code_after = "".join(full_file_contents_lst[min_end_line:cur_end_line])
+
+            segs = [code_before, code_after]
 
             prompt = self._prompt.format(
                 code=rif.contents, user_request=self.user_input, file_prefix=segs[0], file_suffix=segs[1])
-            completion = str(await sdk.models.default.complete(prompt, with_history=await sdk.get_chat_context()))
+            completion = str(await model_to_use.complete(prompt, with_history=await sdk.get_chat_context()))
+
             eot_token = "<|endoftext|>"
             completion = completion.removesuffix(eot_token)
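
Note on the context-trimming change in core.py: the commit's intent is to drop file context lines that fall after the edit range first, then lines before it, until the prompt fits the per-model budget left after reserving DEFAULT_MAX_TOKENS for the completion. The sketch below is only an illustration of that idea, not the committed code; the trim_to_budget name and the chars-per-token heuristic are assumptions, and the committed cut_context helper above differs in its loop details.

    # Illustrative sketch, assuming a simple stand-in token counter.
    MAX_TOKENS_FOR_MODEL = {
        "gpt-3.5-turbo": 4096 - 2048,
        "gpt-3.5-turbo-16k": 16384 - 2048,
        "gpt-4": 8192 - 2048,
    }

    def count_tokens(text: str) -> int:
        # Stand-in for the tiktoken-based counter used by the OpenAI class.
        return max(1, len(text) // 4)

    def trim_to_budget(lines, edit_start, edit_end, model_name):
        # Return (start, end) line indices whose total token count fits the
        # model's budget. Trailing context (after the edit range) is dropped
        # first, then leading context, so the edited region is never cut.
        budget = MAX_TOKENS_FOR_MODEL[model_name]
        total = sum(count_tokens(line) for line in lines)
        start, end = 0, len(lines) - 1

        while total > budget and end > edit_end:
            total -= count_tokens(lines[end])
            end -= 1

        while total > budget and start < edit_start:
            total -= count_tokens(lines[start])
            start += 1

        return start, end

    # Example: keep as much of a large file as fits gpt-3.5-turbo's budget
    # around an edit spanning lines 1200-1260 (hypothetical values).
    # start, end = trim_to_budget(file_lines, 1200, 1260, "gpt-3.5-turbo")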
