| Field | Value | Date |
|---|---|---|
| author | Nate Sesti <sestinj@gmail.com> | 2023-07-30 17:47:17 -0700 |
| committer | Nate Sesti <sestinj@gmail.com> | 2023-07-30 17:47:17 -0700 |
| commit | 39076efbd74106ad59ad65e31d52b8d591c1d485 (patch) | |
| tree | 188dd70e4a8b4b1633c6d21bc5d4d81a206748c4 /continuedev/src | |
| parent | 798e94f62b2c64762e2e6f79645e9334013ac7a8 (diff) | |
| download | sncontinue-39076efbd74106ad59ad65e31d52b8d591c1d485.tar.gz sncontinue-39076efbd74106ad59ad65e31d52b8d591c1d485.tar.bz2 sncontinue-39076efbd74106ad59ad65e31d52b8d591c1d485.zip | |
refactor: :recycle: clean up LLM-specific constants from util files
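
The core of the refactor is that per-model constants such as `MAX_TOKENS_FOR_MODEL` and `CHAT_MODELS` stop living in the shared `count_tokens` util and move next to the LLM classes, which now expose a `context_length` property. The sketch below is a simplified illustration of that contract, not the full classes from the diff: the real `LLM` base also inherits from pydantic's `BaseModel` and uses `@abstractproperty`, and the word-split tokenizer here is only a stand-in for the real `count_tokens` implementations.

```python
# Simplified sketch of the contract this commit introduces: each LLM subclass
# reports its own context window, so shared utilities no longer need a global
# MAX_TOKENS_FOR_MODEL table. The real base class also inherits pydantic's
# BaseModel and uses @abstractproperty; the stacked @property/@abstractmethod
# form below is the modern equivalent.
from abc import ABC, abstractmethod


class LLM(ABC):
    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in the given text."""
        raise NotImplementedError

    @property
    @abstractmethod
    def context_length(self) -> int:
        """Return the context length of the LLM in tokens, as counted by count_tokens."""
        raise NotImplementedError


class GGML(LLM):
    """Local ggml models get a fixed 2048-token window in this patch."""

    def count_tokens(self, text: str) -> int:
        return len(text.split())  # stand-in tokenizer, not the real implementation

    @property
    def context_length(self) -> int:
        return 2048


if __name__ == "__main__":
    print(GGML().context_length)  # 2048
```

With this in place, adding a new backend means implementing `context_length` on that class instead of registering the model in a shared constants table.
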
Diffstat (limited to 'continuedev/src')
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | continuedev/src/continuedev/core/context.py | 6 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/__init__.py | 5 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/anthropic.py | 12 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py | 10 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py | 4 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 24 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/proxy_server.py | 17 |
| -rw-r--r-- | continuedev/src/continuedev/libs/util/count_tokens.py | 82 |
| -rw-r--r-- | continuedev/src/continuedev/plugins/steps/core/core.py | 15 |
10 files changed, 102 insertions, 108 deletions
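
The companion change in `continuedev/src/continuedev/libs/util/count_tokens.py` threads that value through the token-budgeting helpers as an explicit `context_length` parameter, and gives `encoding_for_model` a fallback encoding. The sketch below is a rough, self-contained approximation of those two helpers as they appear in the hunks further down, not a drop-in copy of the module: it assumes `tiktoken` is installed and trims the alias table to the entries visible in the diff.

```python
# Rough sketch of the reworked pruning helper in count_tokens.py: the caller
# now passes the model's context_length explicitly instead of the function
# looking it up in a module-level table. The bare-except fallback mirrors the
# new encoding_for_model in the diff below.
import tiktoken

aliases = {"ggml": "gpt-3.5-turbo", "claude-2": "gpt-3.5-turbo"}


def encoding_for_model(model_name: str):
    try:
        return tiktoken.encoding_for_model(aliases.get(model_name, model_name))
    except:
        return tiktoken.encoding_for_model("gpt-3.5-turbo")


def prune_raw_prompt_from_top(model_name: str, context_length: int, prompt: str, tokens_for_completion: int) -> str:
    # Keep only as many trailing tokens as fit once the completion is reserved.
    max_tokens = context_length - tokens_for_completion
    encoding = encoding_for_model(model_name)
    tokens = encoding.encode(prompt, disallowed_special=())
    if len(tokens) <= max_tokens:
        return prompt
    return encoding.decode(tokens[-max_tokens:])


if __name__ == "__main__":
    pruned = prune_raw_prompt_from_top("gpt-4", 8192, "some long prompt " * 1000, tokens_for_completion=1024)
    print(len(pruned))
```
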
```diff
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index f81fa57a..8afbd610 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -169,12 +169,6 @@ class ContextManager:
         async with Client('http://localhost:7700') as search_client:
             await search_client.index(SEARCH_INDEX_NAME).add_documents(documents)
 
-    # def compile_chat_messages(self, max_tokens: int) -> List[Dict]:
-    #     """
-    #     Compiles the chat prompt into a single string.
-    #     """
-    #     return compile_chat_messages(self.model, self.chat_history, max_tokens, self.prompt, self.functions, self.system_message)
-
     async def select_context_item(self, id: str, query: str):
         """
         Selects the ContextItem with the given id.
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 58572634..96e88383 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -43,3 +43,8 @@ class LLM(BaseModel, ABC):
     def count_tokens(self, text: str):
         """Return the number of tokens in the given text."""
         raise NotImplementedError
+
+    @abstractproperty
+    def context_length(self) -> int:
+        """Return the context length of the LLM in tokens, as counted by count_tokens."""
+        raise NotImplementedError
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index c9c8e9db..4444fd1b 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -5,7 +5,7 @@ from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
 from ...core.main import ChatMessage
 from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic
 from ..llm import LLM
-from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
 
 
 class AnthropicLLM(LLM):
@@ -46,6 +46,12 @@ class AnthropicLLM(LLM):
     def count_tokens(self, text: str):
         return count_tokens(self.model, text)
 
+    @property
+    def context_length(self):
+        if self.model == "claude-2":
+            return 100000
+        raise Exception(f"Unknown Anthropic model {self.model}")
+
     def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
         prompt = ""
 
@@ -77,7 +83,7 @@ class AnthropicLLM(LLM):
         args = self._transform_args(args)
 
         messages = compile_chat_messages(
-            args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message)
+            args["model"], messages, self.context_length, self.context_length, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message)
         async for chunk in await self._async_client.completions.create(
             prompt=self.__messages_to_prompt(messages),
             **args
@@ -92,7 +98,7 @@ class AnthropicLLM(LLM):
         args = self._transform_args(args)
 
         messages = compile_chat_messages(
-            args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message)
+            args["model"], with_history, self.context_length, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message)
         resp = (await self._async_client.completions.create(
             prompt=self.__messages_to_prompt(messages),
             **args
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 990f35bc..7fa51e34 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -29,6 +29,10 @@ class GGML(LLM):
         return "ggml"
 
     @property
+    def context_length(self):
+        return 2048
+
+    @property
     def default_args(self):
         return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
 
@@ -42,7 +46,7 @@ class GGML(LLM):
         args = {**self.default_args, **kwargs}
 
         messages = compile_chat_messages(
-            self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+            self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
 
         # TODO move to single self.session variable (proxy setting etc)
         async with self._client_session as session:
@@ -60,7 +64,7 @@ class GGML(LLM):
     async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+            self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
         args["stream"] = True
 
         async with self._client_session as session:
@@ -87,7 +91,7 @@ class GGML(LLM):
 
         async with self._client_session as session:
             async with session.post(f"{SERVER_URL}/v1/completions", json={
-                "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+                "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
                 **args
             }) as resp:
                 try:
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
index 121ae99e..f5b3c18c 100644
--- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
+++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
@@ -19,6 +19,10 @@ class MaybeProxyOpenAI(LLM):
     def name(self):
         return self.llm.name
 
+    @property
+    def context_length(self):
+        return self.llm.context_length
+
     async def start(self, *, api_key: Optional[str] = None, **kwargs):
         if api_key is None or api_key.strip() == "":
             self.llm = ProxyServer(
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index de02a614..deb6df4c 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -6,7 +6,17 @@ from pydantic import BaseModel
 from ...core.main import ChatMessage
 import openai
 from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
+
+CHAT_MODELS = {
+    "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
+}
+MAX_TOKENS_FOR_MODEL = {
+    "gpt-3.5-turbo": 4096,
+    "gpt-3.5-turbo-0613": 4096,
+    "gpt-3.5-turbo-16k": 16384,
+    "gpt-4": 8192,
+}
 
 
 class AzureInfo(BaseModel):
@@ -44,6 +54,10 @@ class OpenAI(LLM):
         return self.model
 
     @property
+    def context_length(self):
+        return MAX_TOKENS_FOR_MODEL[self.model]
+
+    @property
     def default_args(self):
         args = {**DEFAULT_ARGS, "model": self.model}
         if self.azure_info is not None:
@@ -60,7 +74,7 @@ class OpenAI(LLM):
 
         if args["model"] in CHAT_MODELS:
             messages = compile_chat_messages(
-                args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+                args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
             self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
             completion = ""
             async for chunk in await openai.ChatCompletion.acreate(
@@ -93,7 +107,7 @@ class OpenAI(LLM):
             del args["functions"]
 
         messages = compile_chat_messages(
-            args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+            args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
         completion = ""
         async for chunk in await openai.ChatCompletion.acreate(
@@ -110,7 +124,7 @@ class OpenAI(LLM):
 
         if args["model"] in CHAT_MODELS:
             messages = compile_chat_messages(
-                args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+                args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
             self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
             resp = (await openai.ChatCompletion.acreate(
                 messages=messages,
@@ -119,7 +133,7 @@ class OpenAI(LLM):
             self.write_log(f"Completion: \n\n{resp}")
         else:
             prompt = prune_raw_prompt_from_top(
-                args["model"], prompt, args["max_tokens"])
+                args["model"], self.context_length, prompt, args["max_tokens"])
             self.write_log(f"Prompt:\n\n{prompt}")
             resp = (await openai.Completion.acreate(
                 prompt=prompt,
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 1c942523..56b123db 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -15,6 +15,13 @@ ssl_context = ssl.create_default_context(cafile=ca_bundle_path)
 # SERVER_URL = "http://127.0.0.1:8080"
 SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app"
 
+MAX_TOKENS_FOR_MODEL = {
+    "gpt-3.5-turbo": 4096,
+    "gpt-3.5-turbo-0613": 4096,
+    "gpt-3.5-turbo-16k": 16384,
+    "gpt-4": 8192,
+}
+
 
 class ProxyServer(LLM):
     model: str
@@ -41,6 +48,10 @@ class ProxyServer(LLM):
         return self.model
 
     @property
+    def context_length(self):
+        return MAX_TOKENS_FOR_MODEL[self.model]
+
+    @property
     def default_args(self):
         return {**DEFAULT_ARGS, "model": self.model}
 
@@ -55,7 +66,7 @@ class ProxyServer(LLM):
         args = {**self.default_args, **kwargs}
 
         messages = compile_chat_messages(
-            args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+            args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
         async with self._client_session as session:
             async with session.post(f"{SERVER_URL}/complete", json={
@@ -72,7 +83,7 @@ class ProxyServer(LLM):
     async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+            args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
 
         async with self._client_session as session:
@@ -107,7 +118,7 @@ class ProxyServer(LLM):
     async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+            self.model, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
 
         async with self._client_session as session:
diff --git a/continuedev/src/continuedev/libs/llm/utils.py b/continuedev/src/continuedev/libs/llm/utils.py
deleted file mode 100644
index 4ea45b7b..00000000
--- a/continuedev/src/continuedev/libs/llm/utils.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from transformers import GPT2TokenizerFast
-
-gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-def count_tokens(text: str) -> int:
-    return len(gpt2_tokenizer.encode(text))
-
-
-# TODO move this to LLM class itself (especially as prices may change in the future)
-prices = {
-    # All prices are per 1k tokens
-    "fine-tune-train": {
-        "davinci": 0.03,
-        "curie": 0.03,
-        "babbage": 0.0006,
-        "ada": 0.0004,
-    },
-    "completion": {
-        "davinci": 0.02,
-        "curie": 0.002,
-        "babbage": 0.0005,
-        "ada": 0.0004,
-    },
-    "fine-tune-completion": {
-        "davinci": 0.12,
-        "curie": 0.012,
-        "babbage": 0.0024,
-        "ada": 0.0016,
-    },
-    "embedding": {
-        "ada": 0.0004
-    }
-}
-
-def get_price(text: str, model: str = "davinci", task: str = "completion") -> float:
-    return count_tokens(text) * prices[task][model] / 1000
\ No newline at end of file
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index f6c7cb00..6add7b1a 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -2,6 +2,7 @@ import json
 from typing import Dict, List, Union
 from ...core.main import ChatMessage
 from .templating import render_templated_string
+from ...libs.llm import LLM
 import tiktoken
 
 # TODO move many of these into specific LLM.properties() function that
@@ -13,36 +14,35 @@ aliases = {
     "claude-2": "gpt-3.5-turbo",
 }
 DEFAULT_MAX_TOKENS = 2048
-MAX_TOKENS_FOR_MODEL = {
-    "gpt-3.5-turbo": 4096,
-    "gpt-3.5-turbo-0613": 4096,
-    "gpt-3.5-turbo-16k": 16384,
-    "gpt-4": 8192,
-    "ggml": 2048,
-    "claude-2": 100000
-}
-CHAT_MODELS = {
-    "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
-}
 DEFAULT_ARGS = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1,
                 "frequency_penalty": 0, "presence_penalty": 0}
 
 
-def encoding_for_model(model: str):
-    return tiktoken.encoding_for_model(aliases.get(model, model))
+def encoding_for_model(model_name: str):
+    try:
+        return tiktoken.encoding_for_model(aliases.get(model_name, model_name))
+    except:
+        return tiktoken.encoding_for_model("gpt-3.5-turbo")
 
 
-def count_tokens(model: str, text: Union[str, None]):
+def count_tokens(model_name: str, text: Union[str, None]):
     if text is None:
         return 0
-    encoding = encoding_for_model(model)
+    encoding = encoding_for_model(model_name)
     return len(encoding.encode(text, disallowed_special=()))
 
 
-def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: int):
-    max_tokens = MAX_TOKENS_FOR_MODEL.get(
-        model, DEFAULT_MAX_TOKENS) - tokens_for_completion
-    encoding = encoding_for_model(model)
+def count_chat_message_tokens(model_name: str, chat_message: ChatMessage) -> int:
+    # Doing simpler, safer version of what is here:
+    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+    # every message follows <|start|>{role/name}\n{content}<|end|>\n
+    TOKENS_PER_MESSAGE = 4
+    return count_tokens(model_name, chat_message.content) + TOKENS_PER_MESSAGE
+
+
+def prune_raw_prompt_from_top(model_name: str, context_length: int, prompt: str, tokens_for_completion: int):
+    max_tokens = context_length - tokens_for_completion
+    encoding = encoding_for_model(model_name)
     tokens = encoding.encode(prompt, disallowed_special=())
     if len(tokens) <= max_tokens:
         return prompt
@@ -50,53 +50,45 @@ def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: in
     return encoding.decode(tokens[-max_tokens:])
 
 
-def count_chat_message_tokens(model: str, chat_message: ChatMessage) -> int:
-    # Doing simpler, safer version of what is here:
-    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-    # every message follows <|start|>{role/name}\n{content}<|end|>\n
-    TOKENS_PER_MESSAGE = 4
-    return count_tokens(model, chat_message.content) + TOKENS_PER_MESSAGE
-
-
-def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
+def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context_length: int, tokens_for_completion: int):
     total_tokens = tokens_for_completion + \
-        sum(count_chat_message_tokens(model, message)
+        sum(count_chat_message_tokens(model_name, message)
             for message in chat_history)
 
     # 1. Replace beyond last 5 messages with summary
     i = 0
-    while total_tokens > max_tokens and i < len(chat_history) - 5:
+    while total_tokens > context_length and i < len(chat_history) - 5:
         message = chat_history[0]
-        total_tokens -= count_tokens(model, message.content)
-        total_tokens += count_tokens(model, message.summary)
+        total_tokens -= count_tokens(model_name, message.content)
+        total_tokens += count_tokens(model_name, message.summary)
        message.content = message.summary
         i += 1
 
     # 2. Remove entire messages until the last 5
-    while len(chat_history) > 5 and total_tokens > max_tokens and len(chat_history) > 0:
+    while len(chat_history) > 5 and total_tokens > context_length and len(chat_history) > 0:
         message = chat_history.pop(0)
-        total_tokens -= count_tokens(model, message.content)
+        total_tokens -= count_tokens(model_name, message.content)
 
     # 3. Truncate message in the last 5, except last 1
     i = 0
-    while total_tokens > max_tokens and len(chat_history) > 0 and i < len(chat_history) - 1:
+    while total_tokens > context_length and len(chat_history) > 0 and i < len(chat_history) - 1:
         message = chat_history[i]
-        total_tokens -= count_tokens(model, message.content)
-        total_tokens += count_tokens(model, message.summary)
+        total_tokens -= count_tokens(model_name, message.content)
+        total_tokens += count_tokens(model_name, message.summary)
         message.content = message.summary
         i += 1
 
     # 4. Remove entire messages in the last 5, except last 1
-    while total_tokens > max_tokens and len(chat_history) > 1:
+    while total_tokens > context_length and len(chat_history) > 1:
         message = chat_history.pop(0)
-        total_tokens -= count_tokens(model, message.content)
+        total_tokens -= count_tokens(model_name, message.content)
 
     # 5. Truncate last message
-    if total_tokens > max_tokens and len(chat_history) > 0:
+    if total_tokens > context_length and len(chat_history) > 0:
         message = chat_history[0]
         message.content = prune_raw_prompt_from_top(
-            model, message.content, tokens_for_completion)
-        total_tokens = max_tokens
+            model_name, context_length, message.content, tokens_for_completion)
+        total_tokens = context_length
 
     return chat_history
 
@@ -105,7 +97,7 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
 TOKEN_BUFFER_FOR_SAFETY = 100
 
 
-def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
+def compile_chat_messages(model_name: str, msgs: Union[List[ChatMessage], None], context_length: int, max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
     """
     The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message
     """
@@ -129,10 +121,10 @@ def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_
     function_tokens = 0
     if functions is not None:
         for function in functions:
-            function_tokens += count_tokens(model, json.dumps(function))
+            function_tokens += count_tokens(model_name, json.dumps(function))
 
     msgs_copy = prune_chat_history(
-        model, msgs_copy, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
+        model_name, msgs_copy, context_length, function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
 
     history = [msg.to_dict(with_functions=functions is not None)
                for msg in msgs_copy]
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index b9f27fe5..4c5303fb 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -12,7 +12,7 @@ from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullConte
 from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
 from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
 from ....core.main import ChatMessage, ContinueCustomException, Step, SequentialStep
-from ....libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS
+from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
 from ....libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes
 import difflib
 
@@ -182,8 +182,7 @@ class DefaultModelEditCodeStep(Step):
         # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
         # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
         model_to_use = sdk.models.default
-        max_tokens = int(MAX_TOKENS_FOR_MODEL.get(
-            model_to_use.name, DEFAULT_MAX_TOKENS) / 2)
+        max_tokens = int(model_to_use.context_length / 2)
 
         TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
         if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE:
@@ -201,7 +200,7 @@ class DefaultModelEditCodeStep(Step):
 
         # If using 3.5 and overflows, upgrade to 3.5.16k
         if model_to_use.name == "gpt-3.5-turbo":
-            if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
+            if total_tokens > model_to_use.context_length:
                 model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613")
                 await sdk.start_model(model_to_use)
 
@@ -213,20 +212,20 @@ class DefaultModelEditCodeStep(Step):
         cur_start_line = 0
         cur_end_line = len(full_file_contents_lst) - 1
 
-        if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+        if total_tokens > model_to_use.context_length:
             while cur_end_line > min_end_line:
                 total_tokens -= model_to_use.count_tokens(
                     full_file_contents_lst[cur_end_line])
                 cur_end_line -= 1
-                if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+                if total_tokens < model_to_use.context_length:
                     break
 
-            if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+            if total_tokens > model_to_use.context_length:
                 while cur_start_line < max_start_line:
                     cur_start_line += 1
                     total_tokens -= model_to_use.count_tokens(
                         full_file_contents_lst[cur_start_line])
-                    if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+                    if total_tokens < model_to_use.context_length:
                         break
 
         # Now use the found start/end lines to get the prefix and suffix strings
```
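
On the calling side, steps now ask the model object for its budget instead of indexing a global table, as in the `DefaultModelEditCodeStep` hunks above. The sketch below illustrates that call pattern with a hypothetical `StubModel`; the class, its crude token estimate, and the numbers in `__main__` are illustrative stand-ins, not code from the patch.

```python
# Caller-side view of the refactor: budgets are derived from the model object
# itself rather than from MAX_TOKENS_FOR_MODEL[model_to_use.name].
from dataclasses import dataclass


@dataclass
class StubModel:
    name: str
    context_length: int

    def count_tokens(self, text: str) -> int:
        return len(text) // 4  # crude approximation, for the sketch only


def prompt_budget(model) -> int:
    # Mirrors `max_tokens = int(model_to_use.context_length / 2)` in core.py.
    return int(model.context_length / 2)


def overflows(model, total_tokens: int) -> bool:
    # Mirrors the `total_tokens > model_to_use.context_length` checks.
    return total_tokens > model.context_length


if __name__ == "__main__":
    model = StubModel(name="gpt-3.5-turbo", context_length=4096)
    print(prompt_budget(model))    # 2048
    print(overflows(model, 5000))  # True -> the step would switch models or trim lines
```
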