 continuedev/src/continuedev/steps/core/core.py | 50 +++++++++++++++++++++++---------------------------
 1 file changed, 23 insertions(+), 27 deletions(-)
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 88dc8d72..ec0007a2 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -10,6 +10,7 @@ from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullConten
 from ...models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
 from ...core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
 from ...core.main import Step, SequentialStep
+from ...libs.llm.openai import MAX_TOKENS_FOR_MODEL
 import difflib
@@ -185,48 +186,43 @@ class DefaultModelEditCodeStep(Step):
         cur_start_line = 0
         cur_end_line = len(full_file_contents_lst)
 
-        if sdk.config.default_model == "gpt-4":
-
-            total_tokens = sdk.models.gpt4.count_tokens(full_file_contents)
-            if total_tokens > sdk.models.gpt4.max_tokens:
+        def cut_context(model_to_use, total_tokens):
+
+            if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use]:
                 while cur_end_line > min_end_line:
                     total_tokens -= len(full_file_contents_lst[cur_end_line])
                     cur_end_line -= 1
-                    if total_tokens < sdk.models.gpt4.max_tokens:
-                        break
+                    if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use]:
+                        return cur_start_line, cur_end_line
 
-            if total_tokens > sdk.models.gpt4.max_tokens:
+            if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use]:
                 while cur_start_line < max_start_line:
                     cur_start_line += 1
                     total_tokens -= len(full_file_contents_lst[cur_start_line])
-                    if total_tokens < sdk.models.gpt4.max_tokens:
-                        break
+                    if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use]:
+                        return cur_start_line, cur_end_line
+                return cur_start_line, cur_end_line
+            else:
+                return cur_start_line, cur_end_line
 
-        elif sdk.config.default_model == "gpt-3.5-turbo" or sdk.config.default_model == "gpt-3.5-turbo-16k":
+        if model_to_use == "gpt-4":
 
-            if sdk.models.gpt35.count_tokens(full_file_contents) > sdk.models.gpt35.max_tokens:
+            total_tokens = sdk.models.gpt4.count_tokens(full_file_contents)
+            cur_start_line, cur_end_line = cut_context(model_to_use, total_tokens)
 
-                model_to_use = "gpt-3.5-turbo-16k"
+        elif model_to_use == "gpt-3.5-turbo" or model_to_use == "gpt-3.5-turbo-16k":
+            if sdk.models.gpt35.count_tokens(full_file_contents) > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
+
+                model_to_use = "gpt-3.5-turbo-16k"
                 total_tokens = sdk.models.gpt3516k.count_tokens(full_file_contents)
-                if total_tokens > sdk.models.gpt3516k.max_tokens:
-                    while cur_end_line > min_end_line:
-                        total_tokens -= len(full_file_contents_lst[cur_end_line])
-                        cur_end_line -= 1
-                        if total_tokens < sdk.models.gpt4.max_tokens:
-                            break
-
-                if total_tokens > sdk.models.gpt3516k.max_tokens:
-                    while cur_start_line < max_start_line:
-                        total_tokens -= len(full_file_contents_lst[cur_start_line])
-                        cur_start_line += 1
-                        if total_tokens < sdk.models.gpt4.max_tokens:
-                            break
+                cur_start_line, cur_end_line = cut_context(model_to_use, total_tokens)
+
         else:
             raise Exception("Unknown default model")
 
-        code_before = "".join(full_file_contents_lst[min_end_line:])
-        code_after = "".join(full_file_contents_lst[:max_start_line])
+        code_before = "".join(full_file_contents_lst[cur_end_line:])
+        code_after = "".join(full_file_contents_lst[:cur_start_line])
         segs = [code_before, code_after]
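
Note on the refactor above: the duplicated trim-from-bottom/trim-from-top loops are unified into the nested cut_context helper, and the per-model max_tokens attributes are replaced by the shared MAX_TOKENS_FOR_MODEL table (model_to_use is presumably initialized from sdk.config.default_model earlier in the method, outside this hunk). Two details in the new code are worth flagging. First, cut_context assigns to cur_start_line and cur_end_line, which belong to the enclosing scope; Python treats names assigned inside a function as local, so the helper would need a nonlocal declaration (or explicit parameters) to actually update the window. Second, cur_end_line starts at len(full_file_contents_lst) and the bottom loop subtracts len(full_file_contents_lst[cur_end_line]) before decrementing, which would index one past the end on the first iteration. Below is a minimal standalone sketch of the intended logic with those two issues avoided; the token limits and the use of character counts as a token estimate are illustrative assumptions, not values from the patch.

    # Sketch only: limits are illustrative, not taken from the patch.
    MAX_TOKENS_FOR_MODEL = {
        "gpt-3.5-turbo": 4096,
        "gpt-3.5-turbo-16k": 16384,
        "gpt-4": 8192,
    }


    def cut_context(model_to_use, total_tokens, lines,
                    cur_start_line, cur_end_line,
                    min_end_line, max_start_line):
        # Shrink the window [cur_start_line, cur_end_line) until the token
        # estimate fits the model's limit: trim from the bottom first, then
        # from the top. The window is passed in and returned rather than
        # captured by closure, so no nonlocal declaration is needed.
        limit = MAX_TOKENS_FOR_MODEL[model_to_use]

        # Trim from the bottom; decrement before indexing so the initial
        # cur_end_line == len(lines) never indexes past the end.
        while total_tokens > limit and cur_end_line > min_end_line:
            cur_end_line -= 1
            total_tokens -= len(lines[cur_end_line])

        # Still over the limit: trim from the top, down to max_start_line.
        while total_tokens > limit and cur_start_line < max_start_line:
            total_tokens -= len(lines[cur_start_line])
            cur_start_line += 1

        return cur_start_line, cur_end_line

For example, with hypothetical inputs (character count standing in for the token count, as in the original code):

    lines = ["print('hello world')\n"] * 1000   # hypothetical file contents
    start, end = cut_context("gpt-3.5-turbo",
                             sum(len(line) for line in lines), lines,
                             0, len(lines),
                             min_end_line=400, max_start_line=300)
    # lines[start:end] is now a slice whose estimate fits the model's limit.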