author     Nate Sesti <sestinj@gmail.com>   2023-07-30 17:47:17 -0700
committer  Nate Sesti <sestinj@gmail.com>   2023-07-30 17:47:17 -0700
commit     39076efbd74106ad59ad65e31d52b8d591c1d485 (patch)
tree       188dd70e4a8b4b1633c6d21bc5d4d81a206748c4 /continuedev/src
parent     798e94f62b2c64762e2e6f79645e9334013ac7a8 (diff)
download   sncontinue-39076efbd74106ad59ad65e31d52b8d591c1d485.tar.gz
           sncontinue-39076efbd74106ad59ad65e31d52b8d591c1d485.tar.bz2
           sncontinue-39076efbd74106ad59ad65e31d52b8d591c1d485.zip
refactor: :recycle: clean up LLM-specific constants from util files
Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/core/context.py                  6
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py              5
-rw-r--r--  continuedev/src/continuedev/libs/llm/anthropic.py            12
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py                 10
-rw-r--r--  continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py    4
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py               24
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py         17
-rw-r--r--  continuedev/src/continuedev/libs/llm/utils.py                35
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py        82
-rw-r--r--  continuedev/src/continuedev/plugins/steps/core/core.py       15
10 files changed, 102 insertions, 108 deletions
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index f81fa57a..8afbd610 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -169,12 +169,6 @@ class ContextManager:
async with Client('http://localhost:7700') as search_client:
await search_client.index(SEARCH_INDEX_NAME).add_documents(documents)
- # def compile_chat_messages(self, max_tokens: int) -> List[Dict]:
- # """
- # Compiles the chat prompt into a single string.
- # """
- # return compile_chat_messages(self.model, self.chat_history, max_tokens, self.prompt, self.functions, self.system_message)
-
async def select_context_item(self, id: str, query: str):
"""
Selects the ContextItem with the given id.
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 58572634..96e88383 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -43,3 +43,8 @@ class LLM(BaseModel, ABC):
def count_tokens(self, text: str):
"""Return the number of tokens in the given text."""
raise NotImplementedError
+
+ @abstractproperty
+ def context_length(self) -> int:
+ """Return the context length of the LLM in tokens, as counted by count_tokens."""
+ raise NotImplementedError
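
The abstract context_length property added above means each LLM subclass now reports its own context window (in the same token units as count_tokens) rather than reading a shared MAX_TOKENS_FOR_MODEL table. A minimal sketch of a conforming subclass, assuming a module placed alongside the others in libs/llm; only the two token-related members are shown, and the class name and window size are illustrative, not part of this commit:

    from ..llm import LLM
    from ..util.count_tokens import count_tokens

    class ExampleLLM(LLM):
        model: str = "example-model"  # hypothetical model name

        def count_tokens(self, text: str):
            return count_tokens(self.model, text)

        @property
        def context_length(self) -> int:
            # Window size in tokens, as counted by count_tokens above.
            return 4096  # illustrative value
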
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index c9c8e9db..4444fd1b 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -5,7 +5,7 @@ from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
from ...core.main import ChatMessage
from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
class AnthropicLLM(LLM):
@@ -46,6 +46,12 @@ class AnthropicLLM(LLM):
def count_tokens(self, text: str):
return count_tokens(self.model, text)
+ @property
+ def context_length(self):
+ if self.model == "claude-2":
+ return 100000
+ raise Exception(f"Unknown Anthropic model {self.model}")
+
def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
prompt = ""
@@ -77,7 +83,7 @@ class AnthropicLLM(LLM):
args = self._transform_args(args)
messages = compile_chat_messages(
- args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message)
+ args["model"], messages, self.context_length, self.context_length, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message)
async for chunk in await self._async_client.completions.create(
prompt=self.__messages_to_prompt(messages),
**args
@@ -92,7 +98,7 @@ class AnthropicLLM(LLM):
args = self._transform_args(args)
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message)
+ args["model"], with_history, self.context_length, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message)
resp = (await self._async_client.completions.create(
prompt=self.__messages_to_prompt(messages),
**args
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 990f35bc..7fa51e34 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -29,6 +29,10 @@ class GGML(LLM):
return "ggml"
@property
+ def context_length(self):
+ return 2048
+
+ @property
def default_args(self):
return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
@@ -42,7 +46,7 @@ class GGML(LLM):
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+ self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
# TODO move to single self.session variable (proxy setting etc)
async with self._client_session as session:
@@ -60,7 +64,7 @@ class GGML(LLM):
async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
args["stream"] = True
async with self._client_session as session:
@@ -87,7 +91,7 @@ class GGML(LLM):
async with self._client_session as session:
async with session.post(f"{SERVER_URL}/v1/completions", json={
- "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+ "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
**args
}) as resp:
try:
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
index 121ae99e..f5b3c18c 100644
--- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
+++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
@@ -19,6 +19,10 @@ class MaybeProxyOpenAI(LLM):
def name(self):
return self.llm.name
+ @property
+ def context_length(self):
+ return self.llm.context_length
+
async def start(self, *, api_key: Optional[str] = None, **kwargs):
if api_key is None or api_key.strip() == "":
self.llm = ProxyServer(
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index de02a614..deb6df4c 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -6,7 +6,17 @@ from pydantic import BaseModel
from ...core.main import ChatMessage
import openai
from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
+
+CHAT_MODELS = {
+ "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
+}
+MAX_TOKENS_FOR_MODEL = {
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k": 16384,
+ "gpt-4": 8192,
+}
class AzureInfo(BaseModel):
@@ -44,6 +54,10 @@ class OpenAI(LLM):
return self.model
@property
+ def context_length(self):
+ return MAX_TOKENS_FOR_MODEL[self.model]
+
+ @property
def default_args(self):
args = {**DEFAULT_ARGS, "model": self.model}
if self.azure_info is not None:
@@ -60,7 +74,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
completion = ""
async for chunk in await openai.ChatCompletion.acreate(
@@ -93,7 +107,7 @@ class OpenAI(LLM):
del args["functions"]
messages = compile_chat_messages(
- args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
completion = ""
async for chunk in await openai.ChatCompletion.acreate(
@@ -110,7 +124,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
resp = (await openai.ChatCompletion.acreate(
messages=messages,
@@ -119,7 +133,7 @@ class OpenAI(LLM):
self.write_log(f"Completion: \n\n{resp}")
else:
prompt = prune_raw_prompt_from_top(
- args["model"], prompt, args["max_tokens"])
+ args["model"], self.context_length, prompt, args["max_tokens"])
self.write_log(f"Prompt:\n\n{prompt}")
resp = (await openai.Completion.acreate(
prompt=prompt,
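With the table now local to openai.py, the window is resolved per instance rather than through the util module. A hedged usage sketch (constructor arguments are illustrative, not taken from the diff):

    llm = OpenAI(model="gpt-3.5-turbo-16k")  # constructor fields illustrative
    llm.context_length                       # 16384, from MAX_TOKENS_FOR_MODEL
    llm.default_args["max_tokens"]           # 2048, i.e. DEFAULT_MAX_TOKENS, unless overridden

Note that a model missing from MAX_TOKENS_FOR_MODEL would raise a KeyError here, unlike the previous MAX_TOKENS_FOR_MODEL.get(model, DEFAULT_MAX_TOKENS) fallback in the util code.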
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 1c942523..56b123db 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -15,6 +15,13 @@ ssl_context = ssl.create_default_context(cafile=ca_bundle_path)
# SERVER_URL = "http://127.0.0.1:8080"
SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app"
+MAX_TOKENS_FOR_MODEL = {
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k": 16384,
+ "gpt-4": 8192,
+}
+
class ProxyServer(LLM):
model: str
@@ -41,6 +48,10 @@ class ProxyServer(LLM):
return self.model
@property
+ def context_length(self):
+ return MAX_TOKENS_FOR_MODEL[self.model]
+
+ @property
def default_args(self):
return {**DEFAULT_ARGS, "model": self.model}
@@ -55,7 +66,7 @@ class ProxyServer(LLM):
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
async with self._client_session as session:
async with session.post(f"{SERVER_URL}/complete", json={
@@ -72,7 +83,7 @@ class ProxyServer(LLM):
async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
async with self._client_session as session:
@@ -107,7 +118,7 @@ class ProxyServer(LLM):
async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+ self.model, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
async with self._client_session as session:
diff --git a/continuedev/src/continuedev/libs/llm/utils.py b/continuedev/src/continuedev/libs/llm/utils.py
deleted file mode 100644
index 4ea45b7b..00000000
--- a/continuedev/src/continuedev/libs/llm/utils.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from transformers import GPT2TokenizerFast
-
-gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-def count_tokens(text: str) -> int:
- return len(gpt2_tokenizer.encode(text))
-
-# TODO move this to LLM class itself (especially as prices may change in the future)
-prices = {
- # All prices are per 1k tokens
- "fine-tune-train": {
- "davinci": 0.03,
- "curie": 0.03,
- "babbage": 0.0006,
- "ada": 0.0004,
- },
- "completion": {
- "davinci": 0.02,
- "curie": 0.002,
- "babbage": 0.0005,
- "ada": 0.0004,
- },
- "fine-tune-completion": {
- "davinci": 0.12,
- "curie": 0.012,
- "babbage": 0.0024,
- "ada": 0.0016,
- },
- "embedding": {
- "ada": 0.0004
- }
-}
-
-def get_price(text: str, model: str="davinci", task: str="completion") -> float:
- return count_tokens(text) * prices[task][model] / 1000
\ No newline at end of file
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index f6c7cb00..6add7b1a 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -2,6 +2,7 @@ import json
from typing import Dict, List, Union
from ...core.main import ChatMessage
from .templating import render_templated_string
+from ...libs.llm import LLM
import tiktoken
# TODO move many of these into specific LLM.properties() function that
@@ -13,36 +14,35 @@ aliases = {
"claude-2": "gpt-3.5-turbo",
}
DEFAULT_MAX_TOKENS = 2048
-MAX_TOKENS_FOR_MODEL = {
- "gpt-3.5-turbo": 4096,
- "gpt-3.5-turbo-0613": 4096,
- "gpt-3.5-turbo-16k": 16384,
- "gpt-4": 8192,
- "ggml": 2048,
- "claude-2": 100000
-}
-CHAT_MODELS = {
- "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
-}
DEFAULT_ARGS = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1,
"frequency_penalty": 0, "presence_penalty": 0}
-def encoding_for_model(model: str):
- return tiktoken.encoding_for_model(aliases.get(model, model))
+def encoding_for_model(model_name: str):
+ try:
+ return tiktoken.encoding_for_model(aliases.get(model_name, model_name))
+ except:
+ return tiktoken.encoding_for_model("gpt-3.5-turbo")
-def count_tokens(model: str, text: Union[str, None]):
+def count_tokens(model_name: str, text: Union[str, None]):
if text is None:
return 0
- encoding = encoding_for_model(model)
+ encoding = encoding_for_model(model_name)
return len(encoding.encode(text, disallowed_special=()))
-def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: int):
- max_tokens = MAX_TOKENS_FOR_MODEL.get(
- model, DEFAULT_MAX_TOKENS) - tokens_for_completion
- encoding = encoding_for_model(model)
+def count_chat_message_tokens(model_name: str, chat_message: ChatMessage) -> int:
+ # Doing simpler, safer version of what is here:
+ # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+ # every message follows <|start|>{role/name}\n{content}<|end|>\n
+ TOKENS_PER_MESSAGE = 4
+ return count_tokens(model_name, chat_message.content) + TOKENS_PER_MESSAGE
+
+
+def prune_raw_prompt_from_top(model_name: str, context_length: int, prompt: str, tokens_for_completion: int):
+ max_tokens = context_length - tokens_for_completion
+ encoding = encoding_for_model(model_name)
tokens = encoding.encode(prompt, disallowed_special=())
if len(tokens) <= max_tokens:
return prompt
@@ -50,53 +50,45 @@ def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: in
return encoding.decode(tokens[-max_tokens:])
-def count_chat_message_tokens(model: str, chat_message: ChatMessage) -> int:
- # Doing simpler, safer version of what is here:
- # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
- # every message follows <|start|>{role/name}\n{content}<|end|>\n
- TOKENS_PER_MESSAGE = 4
- return count_tokens(model, chat_message.content) + TOKENS_PER_MESSAGE
-
-
-def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
+def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context_length: int, tokens_for_completion: int):
total_tokens = tokens_for_completion + \
- sum(count_chat_message_tokens(model, message)
+ sum(count_chat_message_tokens(model_name, message)
for message in chat_history)
# 1. Replace beyond last 5 messages with summary
i = 0
- while total_tokens > max_tokens and i < len(chat_history) - 5:
+ while total_tokens > context_length and i < len(chat_history) - 5:
message = chat_history[0]
- total_tokens -= count_tokens(model, message.content)
- total_tokens += count_tokens(model, message.summary)
+ total_tokens -= count_tokens(model_name, message.content)
+ total_tokens += count_tokens(model_name, message.summary)
message.content = message.summary
i += 1
# 2. Remove entire messages until the last 5
- while len(chat_history) > 5 and total_tokens > max_tokens and len(chat_history) > 0:
+ while len(chat_history) > 5 and total_tokens > context_length and len(chat_history) > 0:
message = chat_history.pop(0)
- total_tokens -= count_tokens(model, message.content)
+ total_tokens -= count_tokens(model_name, message.content)
# 3. Truncate message in the last 5, except last 1
i = 0
- while total_tokens > max_tokens and len(chat_history) > 0 and i < len(chat_history) - 1:
+ while total_tokens > context_length and len(chat_history) > 0 and i < len(chat_history) - 1:
message = chat_history[i]
- total_tokens -= count_tokens(model, message.content)
- total_tokens += count_tokens(model, message.summary)
+ total_tokens -= count_tokens(model_name, message.content)
+ total_tokens += count_tokens(model_name, message.summary)
message.content = message.summary
i += 1
# 4. Remove entire messages in the last 5, except last 1
- while total_tokens > max_tokens and len(chat_history) > 1:
+ while total_tokens > context_length and len(chat_history) > 1:
message = chat_history.pop(0)
- total_tokens -= count_tokens(model, message.content)
+ total_tokens -= count_tokens(model_name, message.content)
# 5. Truncate last message
- if total_tokens > max_tokens and len(chat_history) > 0:
+ if total_tokens > context_length and len(chat_history) > 0:
message = chat_history[0]
message.content = prune_raw_prompt_from_top(
- model, message.content, tokens_for_completion)
- total_tokens = max_tokens
+ model_name, context_length, message.content, tokens_for_completion)
+ total_tokens = context_length
return chat_history
@@ -105,7 +97,7 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
TOKEN_BUFFER_FOR_SAFETY = 100
-def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
+def compile_chat_messages(model_name: str, msgs: Union[List[ChatMessage], None], context_length: int, max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
"""
The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message
"""
@@ -129,10 +121,10 @@ def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_
function_tokens = 0
if functions is not None:
for function in functions:
- function_tokens += count_tokens(model, json.dumps(function))
+ function_tokens += count_tokens(model_name, json.dumps(function))
msgs_copy = prune_chat_history(
- model, msgs_copy, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
+ model_name, msgs_copy, context_length, function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
history = [msg.to_dict(with_functions=functions is not None)
for msg in msgs_copy]
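
compile_chat_messages now takes the window explicitly: the caller passes both context_length and max_tokens, and the history is pruned until it fits within context_length minus (function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY). A sketch of the new call shape, with illustrative values; in practice the LLM instance supplies context_length:

    messages = compile_chat_messages(
        model_name="gpt-4",
        msgs=chat_history,        # List[ChatMessage]
        context_length=8192,      # previously looked up in MAX_TOKENS_FOR_MODEL
        max_tokens=1024,          # reserved for the completion
        prompt="Write a docstring for this function.",  # illustrative
        functions=None,
        system_message=None,
    )
    # prune_chat_history trims until roughly:
    #   sum(per-message tokens) + max_tokens + TOKEN_BUFFER_FOR_SAFETY (100) <= context_length
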
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index b9f27fe5..4c5303fb 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -12,7 +12,7 @@ from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullConte
from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
from ....core.main import ChatMessage, ContinueCustomException, Step, SequentialStep
-from ....libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS
+from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
from ....libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes
import difflib
@@ -182,8 +182,7 @@ class DefaultModelEditCodeStep(Step):
# We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
# Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
model_to_use = sdk.models.default
- max_tokens = int(MAX_TOKENS_FOR_MODEL.get(
- model_to_use.name, DEFAULT_MAX_TOKENS) / 2)
+ max_tokens = int(model_to_use.context_length / 2)
TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE:
@@ -201,7 +200,7 @@ class DefaultModelEditCodeStep(Step):
# If using 3.5 and overflows, upgrade to 3.5.16k
if model_to_use.name == "gpt-3.5-turbo":
- if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
+ if total_tokens > model_to_use.context_length:
model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613")
await sdk.start_model(model_to_use)
@@ -213,20 +212,20 @@ class DefaultModelEditCodeStep(Step):
cur_start_line = 0
cur_end_line = len(full_file_contents_lst) - 1
- if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens > model_to_use.context_length:
while cur_end_line > min_end_line:
total_tokens -= model_to_use.count_tokens(
full_file_contents_lst[cur_end_line])
cur_end_line -= 1
- if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens < model_to_use.context_length:
break
- if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens > model_to_use.context_length:
while cur_start_line < max_start_line:
cur_start_line += 1
total_tokens -= model_to_use.count_tokens(
full_file_contents_lst[cur_start_line])
- if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens < model_to_use.context_length:
break
# Now use the found start/end lines to get the prefix and suffix strings
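
The loops above trim whole lines from the end of the file, then from the start, until the token estimate fits the model's window. A simplified, self-contained sketch of that pattern (the function name and bounds are illustrative, not the step's actual helpers):

    def trim_to_context(lines, count_tokens, context_length, min_end_line, max_start_line):
        # Estimate the total, then drop suffix lines, then prefix lines,
        # stopping as soon as the estimate fits the window.
        total_tokens = sum(count_tokens(line) for line in lines)
        end = len(lines) - 1
        while total_tokens > context_length and end > min_end_line:
            total_tokens -= count_tokens(lines[end])
            end -= 1
        start = 0
        while total_tokens > context_length and start < max_start_line:
            total_tokens -= count_tokens(lines[start])
            start += 1
        return lines[start:end + 1]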