author     Ty Dunn <ty@tydunn.com>    2023-06-16 12:47:20 -0700
committer  Ty Dunn <ty@tydunn.com>    2023-06-16 12:47:20 -0700
commit     a72a163b461cfa5f610898aad40ca8c7ff09f0ab (patch)
tree       8f1a20a4f7d65326cf2514b44137568fe6256ad6 /continuedev/src
parent     a5a5a08d23adbde9be0d67cc0f46e9742de01f96 (diff)
getting it working
Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py  12
-rw-r--r--  continuedev/src/continuedev/steps/core/core.py  44
2 files changed, 30 insertions(+), 26 deletions(-)
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 17d37035..136e86b4 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -9,12 +9,12 @@ from ..llm import LLM
 from pydantic import BaseModel, validator
 import tiktoken
 
+DEFAULT_MAX_TOKENS = 2048
 MAX_TOKENS_FOR_MODEL = {
-    "gpt-3.5-turbo": 4096,
-    "gpt-3.5-turbo-16k": 16384,
-    "gpt-4": 8192
+    "gpt-3.5-turbo": 4096 - DEFAULT_MAX_TOKENS,
+    "gpt-3.5-turbo-16k": 16384 - DEFAULT_MAX_TOKENS,
+    "gpt-4": 8192 - DEFAULT_MAX_TOKENS
 }
-DEFAULT_MAX_TOKENS = 2048
 CHAT_MODELS = {
     "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4"
 }
@@ -33,6 +33,10 @@ class OpenAI(LLM):
         openai.api_key = api_key
 
     @cached_property
+    def name(self):
+        return self.default_model
+
+    @cached_property
     def __encoding_for_model(self):
         aliases = {
             "gpt-3.5-turbo": "gpt3"
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index ec0007a2..de6fa29a 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -175,8 +175,8 @@ class DefaultModelEditCodeStep(Step):
         for rif in rif_with_contents:
             await sdk.ide.setFileOpen(rif.filepath)
 
-            model_to_use = sdk.config.default_model
-
+            model_to_use = sdk.models.default
+
             full_file_contents = await sdk.ide.readFile(rif.filepath)
             full_file_contents_lst = full_file_contents.split("\n")
 
@@ -184,45 +184,45 @@ class DefaultModelEditCodeStep(Step):
             max_start_line = rif.range.start.line
             min_end_line = rif.range.end.line
             cur_start_line = 0
-            cur_end_line = len(full_file_contents_lst)
+            cur_end_line = len(full_file_contents_lst) - 1
 
-            def cut_context(model_to_use, total_tokens):
-
-                if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use]:
+            def cut_context(model_to_use, total_tokens, cur_start_line, cur_end_line):
+
+                if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
                     while cur_end_line > min_end_line:
-                        total_tokens -= len(full_file_contents_lst[cur_end_line])
+                        total_tokens -= model_to_use.count_tokens(full_file_contents_lst[cur_end_line])
                         cur_end_line -= 1
-                        if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use]:
+                        if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
                             return cur_start_line, cur_end_line
 
-                    if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use]:
+                    if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
                         while cur_start_line < max_start_line:
                             cur_start_line += 1
-                            total_tokens -= len(full_file_contents_lst[cur_start_line])
-                            if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use]:
+                            total_tokens -= model_to_use.count_tokens(full_file_contents_lst[cur_end_line])
+                            if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
                                 return cur_start_line, cur_end_line
 
-                    return cur_start_line, cur_end_line
-                else:
+                return cur_start_line, cur_end_line
 
-            if model_to_use == "gpt-4":
+            if model_to_use.name == "gpt-4":
-                total_tokens = sdk.models.gpt4.count_tokens(full_file_contents)
-                cur_start_line, cur_end_line = cut_context(model_to_use, total_tokens)
+                total_tokens = model_to_use.count_tokens(full_file_contents)
+                cur_start_line, cur_end_line = cut_context(model_to_use, total_tokens, cur_start_line, cur_end_line)
-            elif model_to_use == "gpt-3.5-turbo" or model_to_use == "gpt-3.5-turbo-16k":
+            elif model_to_use.name == "gpt-3.5-turbo" or model_to_use.name == "gpt-3.5-turbo-16k":
                 if sdk.models.gpt35.count_tokens(full_file_contents) > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
-                    model_to_use = "gpt-3.5-turbo-16k"
-                    total_tokens = sdk.models.gpt3516k.count_tokens(full_file_contents)
-                    cur_start_line, cur_end_line = cut_context(model_to_use, total_tokens)
+                    model_to_use = sdk.models.gpt3516k
+                    total_tokens = model_to_use.count_tokens(full_file_contents)
+                    cur_start_line, cur_end_line = cut_context(model_to_use, total_tokens, cur_start_line, cur_end_line)
             else:
+                raise Exception("Unknown default model")
 
-            code_before = "".join(full_file_contents_lst[cur_end_line:])
-            code_after = "".join(full_file_contents_lst[:cur_start_line])
+            code_before = "".join(full_file_contents_lst[cur_start_line:max_start_line])
+            code_after = "".join(full_file_contents_lst[min_end_line:cur_end_line])
 
             segs = [code_before, code_after]
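
A note on the openai.py hunk: DEFAULT_MAX_TOKENS now has to be defined before MAX_TOKENS_FOR_MODEL, because each entry in that table becomes the model's context window minus the tokens reserved for the completion, i.e. the budget left for the prompt. A minimal sketch of the arithmetic; the CONTEXT_WINDOWS name is illustrative, not from the patch, which hard-codes the values inline:

# Sketch: prompt budget = context window - tokens reserved for the completion.
# CONTEXT_WINDOWS is a hypothetical helper table for illustration only.
DEFAULT_MAX_TOKENS = 2048  # tokens reserved for the model's reply

CONTEXT_WINDOWS = {
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-16k": 16384,
    "gpt-4": 8192,
}

MAX_TOKENS_FOR_MODEL = {
    model: window - DEFAULT_MAX_TOKENS
    for model, window in CONTEXT_WINDOWS.items()
}

# e.g. gpt-4: 8192 - 2048 = 6144 tokens available for the prompt
assert MAX_TOKENS_FOR_MODEL["gpt-4"] == 6144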
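
On the core.py hunk: cut_context now receives cur_start_line and cur_end_line as parameters instead of assigning to names from the enclosing scope (plain assignment inside a nested function creates new locals unless declared nonlocal), and lines are measured with model_to_use.count_tokens rather than character length. A simplified standalone sketch of the trimming strategy, assuming a count_tokens callable is passed in; unlike the patch, whose second loop still indexes cur_end_line, this sketch counts the line actually being dropped from the top:

from typing import Callable, List, Tuple

def cut_context(
    lines: List[str],
    total_tokens: int,
    budget: int,
    cur_start_line: int,
    cur_end_line: int,
    max_start_line: int,
    min_end_line: int,
    count_tokens: Callable[[str], int],
) -> Tuple[int, int]:
    # Drop lines from the bottom of the file, never past the end of the
    # user's selection (min_end_line).
    while total_tokens > budget and cur_end_line > min_end_line:
        total_tokens -= count_tokens(lines[cur_end_line])
        cur_end_line -= 1
    # Still over budget: drop lines from the top, never past the start
    # of the selection (max_start_line).
    while total_tokens > budget and cur_start_line < max_start_line:
        total_tokens -= count_tokens(lines[cur_start_line])
        cur_start_line += 1
    return cur_start_line, cur_end_line

# Usage sketch with a crude stand-in for tiktoken-based counting.
lines = ["def f():", "    return 1"] * 500
count = lambda s: max(1, len(s) // 4)
start, end = cut_context(lines, sum(map(count, lines)), 6144,
                         0, len(lines) - 1, 100, 900, count)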