Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py | 12 ++++++++----
-rw-r--r--  continuedev/src/continuedev/steps/core/core.py | 44 ++++++++++++++++++++++----------------------
2 files changed, 30 insertions(+), 26 deletions(-)
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 17d37035..136e86b4 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -9,12 +9,12 @@ from ..llm import LLM
 from pydantic import BaseModel, validator
 import tiktoken
+DEFAULT_MAX_TOKENS = 2048
 
 MAX_TOKENS_FOR_MODEL = {
-    "gpt-3.5-turbo": 4096,
-    "gpt-3.5-turbo-16k": 16384,
-    "gpt-4": 8192
+    "gpt-3.5-turbo": 4096 - DEFAULT_MAX_TOKENS,
+    "gpt-3.5-turbo-16k": 16384 - DEFAULT_MAX_TOKENS,
+    "gpt-4": 8192 - DEFAULT_MAX_TOKENS
 }
-DEFAULT_MAX_TOKENS = 2048
 CHAT_MODELS = {
     "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4"
 }
@@ -33,6 +33,10 @@ class OpenAI(LLM):
         openai.api_key = api_key
 
     @cached_property
+    def name(self):
+        return self.default_model
+
+    @cached_property
     def __encoding_for_model(self):
         aliases = {
             "gpt-3.5-turbo": "gpt3"
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index ec0007a2..de6fa29a 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -175,8 +175,8 @@ class DefaultModelEditCodeStep(Step):
         for rif in rif_with_contents:
             await sdk.ide.setFileOpen(rif.filepath)
-            model_to_use = sdk.config.default_model
-
+            model_to_use = sdk.models.default
+
 
             full_file_contents = await sdk.ide.readFile(rif.filepath)
 
             full_file_contents_lst = full_file_contents.split("\n")
@@ -184,45 +184,45 @@ class DefaultModelEditCodeStep(Step):
             max_start_line = rif.range.start.line
             min_end_line = rif.range.end.line
             cur_start_line = 0
-            cur_end_line = len(full_file_contents_lst)
+            cur_end_line = len(full_file_contents_lst) - 1
 
-            def cut_context(model_to_use, total_tokens):
-
-               if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use]:
+            def cut_context(model_to_use, total_tokens, cur_start_line, cur_end_line):
+
+                if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
                     while cur_end_line > min_end_line:
-                        total_tokens -= len(full_file_contents_lst[cur_end_line])
+                        total_tokens -= model_to_use.count_tokens(full_file_contents_lst[cur_end_line])
                         cur_end_line -= 1
-                        if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use]:
+                        if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
                             return cur_start_line, cur_end_line
-                    if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use]:
+                    if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
                         while cur_start_line < max_start_line:
                             cur_start_line += 1
-                            total_tokens -= len(full_file_contents_lst[cur_start_line])
-                            if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use]:
+                            total_tokens -= model_to_use.count_tokens(full_file_contents_lst[cur_end_line])
+                            if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
                                 return cur_start_line, cur_end_line
-                    return cur_start_line, cur_end_line
-               else:
+                return cur_start_line, cur_end_line
 
-            if model_to_use == "gpt-4":
+            if model_to_use.name == "gpt-4":
-                total_tokens = sdk.models.gpt4.count_tokens(full_file_contents)
-                cur_start_line, cur_end_line = cut_context(model_to_use, total_tokens)
+                total_tokens = model_to_use.count_tokens(full_file_contents)
+                cur_start_line, cur_end_line = cut_context(model_to_use, total_tokens, cur_start_line, cur_end_line)
-            elif model_to_use == "gpt-3.5-turbo" or model_to_use == "gpt-3.5-turbo-16k":
+            elif model_to_use.name == "gpt-3.5-turbo" or model_to_use.name == "gpt-3.5-turbo-16k":
                 if sdk.models.gpt35.count_tokens(full_file_contents) > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
-                    model_to_use = "gpt-3.5-turbo-16k"
-                    total_tokens = sdk.models.gpt3516k.count_tokens(full_file_contents)
-                    cur_start_line, cur_end_line = cut_context(model_to_use, total_tokens)
+                    model_to_use = sdk.models.gpt3516k
+                    total_tokens = model_to_use.count_tokens(full_file_contents)
+                    cur_start_line, cur_end_line = cut_context(model_to_use, total_tokens, cur_start_line, cur_end_line)
             else:
+                raise Exception("Unknown default model")
 
-            code_before = "".join(full_file_contents_lst[cur_end_line:])
-            code_after = "".join(full_file_contents_lst[:cur_start_line])
+            code_before = "".join(full_file_contents_lst[cur_start_line:max_start_line])
+            code_after = "".join(full_file_contents_lst[min_end_line:cur_end_line])
 
             segs = [code_before, code_after]
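Taken together, the change reserves DEFAULT_MAX_TOKENS of each model's context window for the completion and trims the surrounding file from the bottom first, then the top, using real token counts instead of character counts. Below is a minimal self-contained sketch of that trimming strategy; the standalone count_tokens helper and cut_context signature are illustrative assumptions, not the committed code, and the sketch subtracts the tokens of the line actually being dropped in both loops.

# Sketch of the token-budget trimming introduced above. Assumes tiktoken is
# installed; count_tokens() stands in for the diff's model_to_use.count_tokens().
import tiktoken

DEFAULT_MAX_TOKENS = 2048  # tokens reserved for the model's completion

# Context budget = window size minus the reserved completion tokens,
# mirroring MAX_TOKENS_FOR_MODEL after this change.
MAX_TOKENS_FOR_MODEL = {
    "gpt-3.5-turbo": 4096 - DEFAULT_MAX_TOKENS,
    "gpt-3.5-turbo-16k": 16384 - DEFAULT_MAX_TOKENS,
    "gpt-4": 8192 - DEFAULT_MAX_TOKENS,
}

def count_tokens(model: str, text: str) -> int:
    # Count tokens with the model's own encoding rather than len(text).
    return len(tiktoken.encoding_for_model(model).encode(text))

def cut_context(model: str, lines: list, max_start_line: int,
                min_end_line: int) -> tuple:
    # Shrink the (cur_start_line, cur_end_line) window until the file fits
    # the budget, never cutting into the highlighted range
    # (max_start_line..min_end_line).
    budget = MAX_TOKENS_FOR_MODEL[model]
    cur_start_line, cur_end_line = 0, len(lines) - 1
    total_tokens = count_tokens(model, "\n".join(lines))

    # First drop lines below the selection...
    while total_tokens > budget and cur_end_line > min_end_line:
        total_tokens -= count_tokens(model, lines[cur_end_line])
        cur_end_line -= 1

    # ...then, if still over budget, drop lines above it.
    while total_tokens > budget and cur_start_line < max_start_line:
        total_tokens -= count_tokens(model, lines[cur_start_line])
        cur_start_line += 1

    return cur_start_line, cur_end_line

With the returned bounds, the context kept around the selection is lines[cur_start_line:max_start_line] before it and lines[min_end_line:cur_end_line] after it, matching the corrected code_before/code_after slices at the end of the diff.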
