-rw-r--r--  continuedev/src/continuedev/steps/core/core.py  50
1 file changed, 23 insertions, 27 deletions
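
Judging from the diff alone, this change folds the duplicated per-model context-trimming branches in DefaultModelEditCodeStep into a single cut_context helper driven by the imported MAX_TOKENS_FOR_MODEL table, and slices the prompt using the computed cur_start_line/cur_end_line instead of the fixed max_start_line/min_end_line bounds. A standalone sketch of the trimming logic follows the diff.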
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 88dc8d72..ec0007a2 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -10,6 +10,7 @@ from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents
 from ...models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
 from ...core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
 from ...core.main import Step, SequentialStep
+from ...libs.llm.openai import MAX_TOKENS_FOR_MODEL
 import difflib
@@ -185,48 +186,43 @@ class DefaultModelEditCodeStep(Step):
         cur_start_line = 0
         cur_end_line = len(full_file_contents_lst)
 
-        if sdk.config.default_model == "gpt-4":
-
-            total_tokens = sdk.models.gpt4.count_tokens(full_file_contents)
-            if total_tokens > sdk.models.gpt4.max_tokens:
+        def cut_context(model_to_use, total_tokens):
+
+            if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use]:
                 while cur_end_line > min_end_line:
                     total_tokens -= len(full_file_contents_lst[cur_end_line])
                     cur_end_line -= 1
-                    if total_tokens < sdk.models.gpt4.max_tokens:
-                        break
+                    if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use]:
+                        return cur_start_line, cur_end_line
 
-                if total_tokens > sdk.models.gpt4.max_tokens:
+                if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use]:
                     while cur_start_line < max_start_line:
                         cur_start_line += 1
                         total_tokens -= len(full_file_contents_lst[cur_start_line])
-                        if total_tokens < sdk.models.gpt4.max_tokens:
-                            break
+                        if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use]:
+                            return cur_start_line, cur_end_line
+                return cur_start_line, cur_end_line
+            else:
+                return cur_start_line, cur_end_line
 
-        elif sdk.config.default_model == "gpt-3.5-turbo" or sdk.config.default_model == "gpt-3.5-turbo-16k":
-            if sdk.models.gpt35.count_tokens(full_file_contents) > sdk.models.gpt35.max_tokens:
-                model_to_use = "gpt-3.5-turbo-16k"
+        if model_to_use == "gpt-4":
+            total_tokens = sdk.models.gpt4.count_tokens(full_file_contents)
+            cur_start_line, cur_end_line = cut_context(model_to_use, total_tokens)
+        elif model_to_use == "gpt-3.5-turbo" or model_to_use == "gpt-3.5-turbo-16k":
+            if sdk.models.gpt35.count_tokens(full_file_contents) > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
+
+                model_to_use = "gpt-3.5-turbo-16k"
                 total_tokens = sdk.models.gpt3516k.count_tokens(full_file_contents)
-                if total_tokens > sdk.models.gpt3516k.max_tokens:
-                    while cur_end_line > min_end_line:
-                        total_tokens -= len(full_file_contents_lst[cur_end_line])
-                        cur_end_line -= 1
-                        if total_tokens < sdk.models.gpt4.max_tokens:
-                            break
-
-                    if total_tokens > sdk.models.gpt3516k.max_tokens:
-                        while cur_start_line < max_start_line:
-                            total_tokens -= len(full_file_contents_lst[cur_start_line])
-                            cur_start_line += 1
-                            if total_tokens < sdk.models.gpt4.max_tokens:
-                                break
+                cur_start_line, cur_end_line = cut_context(model_to_use, total_tokens)
+
         else:
             raise Exception("Unknown default model")
 
-        code_before = "".join(full_file_contents_lst[min_end_line:])
-        code_after = "".join(full_file_contents_lst[:max_start_line])
+        code_before = "".join(full_file_contents_lst[cur_end_line:])
+        code_after = "".join(full_file_contents_lst[:cur_start_line])
         segs = [code_before, code_after]
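
For reference outside diff syntax, here is a minimal, self-contained sketch of the trimming strategy cut_context implements: drop lines from the end of the file, then from the start, until the prompt fits the model's window. Everything below is illustrative rather than the committed code: the limits table and the whitespace-based count_tokens are stand-ins for continuedev's real MAX_TOKENS_FOR_MODEL and tokenizer. The sketch also works around two caveats visible in the diff: the committed helper rebinds cur_start_line/cur_end_line from the enclosing scope without a nonlocal declaration, so Python treats them as unbound locals, and both old and new versions index full_file_contents_lst[cur_end_line] before decrementing, which is out of bounds on the first pass when cur_end_line starts at len(full_file_contents_lst).

# A minimal sketch of the cut_context trimming strategy (not the committed
# code verbatim). Assumptions: the limits below are illustrative stand-ins
# for MAX_TOKENS_FOR_MODEL, and count_tokens() is a crude whitespace
# approximation of a real tokenizer such as tiktoken.

MAX_TOKENS_FOR_MODEL = {
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-16k": 16384,
    "gpt-4": 8192,
}


def count_tokens(text: str) -> int:
    # Stand-in for a real tokenizer.
    return len(text.split())


def cut_context(lines, model_to_use, max_start_line, min_end_line):
    """Shrink [cur_start_line, cur_end_line) until the text fits the window."""
    cur_start_line = 0
    cur_end_line = len(lines)
    total_tokens = count_tokens("".join(lines))
    limit = MAX_TOKENS_FOR_MODEL[model_to_use]

    # First trim lines from the bottom, but never past min_end_line ...
    while total_tokens > limit and cur_end_line > min_end_line:
        cur_end_line -= 1  # decrement first so the index stays in bounds
        total_tokens -= count_tokens(lines[cur_end_line])

    # ... then, if still over the limit, trim from the top.
    while total_tokens > limit and cur_start_line < max_start_line:
        total_tokens -= count_tokens(lines[cur_start_line])
        cur_start_line += 1

    return cur_start_line, cur_end_line


if __name__ == "__main__":
    lines = [("x " * 200) + "\n" for _ in range(100)]  # ~20k toy tokens
    start, end = cut_context(lines, "gpt-4", max_start_line=40, min_end_line=60)
    print(start, end, sum(count_tokens(l) for l in lines[start:end]))

Passing the bounds in and returning them, rather than mutating enclosing-scope variables, sidesteps the rebinding pitfall and keeps the helper reusable across both model branches, which appears to be the intent of the refactor.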