Diffstat (limited to 'continuedev/src')

 continuedev/src/continuedev/libs/llm/openai.py        |  4 ++--
 continuedev/src/continuedev/libs/util/count_tokens.py | 14 ++++++++++++--
 continuedev/src/continuedev/steps/core/core.py        | 22 ++++++----------------
 3 files changed, 20 insertions(+), 20 deletions(-)
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index f0b2e6d8..3024ae61 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -4,7 +4,7 @@ from typing import Any, Coroutine, Dict, Generator, List, Union
 from ...core.main import ChatMessage
 import openai
 from ..llm import LLM
-from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens
+from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
 
 
 class OpenAI(LLM):
@@ -72,7 +72,7 @@ class OpenAI(LLM):
             )).choices[0].message.content
         else:
             resp = (await openai.Completion.acreate(
-                prompt=prompt,
+                prompt=prune_raw_prompt_from_top(args["model"], prompt),
                 **args,
             )).choices[0].text
 
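
Only the raw-completion branch needed this guard: the chat branch above already bounds its input through compile_chat_messages. A minimal sketch of the resulting call-site pattern, assuming the legacy (pre-1.0) openai SDK that this file targets; the `complete` wrapper is a stand-in, and `args` carries the model name exactly as in the diff:

# Sketch only: mirrors the patched call site above; the legacy
# openai.Completion.acreate API (openai < 1.0) is assumed.
import openai

from continuedev.libs.util.count_tokens import prune_raw_prompt_from_top

async def complete(prompt: str, **args) -> str:
    # Cut the prompt from the top so it fits the model's window,
    # keeping the tail (the most recent context and instructions).
    pruned = prune_raw_prompt_from_top(args["model"], prompt)
    resp = await openai.Completion.acreate(prompt=pruned, **args)
    return resp.choices[0].text
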
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 6038b68d..addafcff 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -28,6 +28,16 @@ def count_tokens(model: str, text: str | None):
     return len(encoding.encode(text, disallowed_special=()))
 
 
+def prune_raw_prompt_from_top(model: str, prompt: str):
+    max_tokens = MAX_TOKENS_FOR_MODEL.get(model, DEFAULT_MAX_TOKENS)
+    encoding = encoding_for_model(model)
+    tokens = encoding.encode(prompt, disallowed_special=())
+    if len(tokens) <= max_tokens:
+        return prompt
+    else:
+        return encoding.decode(tokens[-max_tokens:])
+
+
 def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
     total_tokens = tokens_for_completion + \
         sum(count_tokens(model, message.content)
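
The new helper drops tokens from the top and keeps the tail, which for a completion prompt is the part carrying the instructions and the most recent code. A quick standalone demonstration, assuming tiktoken is installed; the max_tokens=5 budget is a toy value, the real limits come from MAX_TOKENS_FOR_MODEL:

# Toy budget for illustration; real limits come from MAX_TOKENS_FOR_MODEL.
import tiktoken

def prune_from_top(model: str, prompt: str, max_tokens: int) -> str:
    encoding = tiktoken.encoding_for_model(model)
    tokens = encoding.encode(prompt, disallowed_special=())
    if len(tokens) <= max_tokens:
        return prompt
    return encoding.decode(tokens[-max_tokens:])  # keep only the tail

print(prune_from_top("gpt-3.5-turbo", "alpha beta gamma delta epsilon zeta", 5))
# The leading tokens are gone; the exact split depends on the tokenizer.
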
@@ -43,13 +53,13 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
         i += 1
 
     # 2. Remove entire messages until the last 5
-    while len(chat_history) > 5 and total_tokens > max_tokens:
+    while len(chat_history) > 5 and total_tokens > max_tokens and len(chat_history) > 0:
         message = chat_history.pop(0)
         total_tokens -= count_tokens(model, message.content)
 
     # 3. Truncate message in the last 5
     i = 0
-    while total_tokens > max_tokens:
+    while total_tokens > max_tokens and len(chat_history) > 0:
         message = chat_history[0]
         total_tokens -= count_tokens(model, message.content)
         total_tokens += count_tokens(model, message.summary)
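
The two added `len(chat_history) > 0` checks cover the case where tokens_for_completion alone exceeds the budget: the loops can then drain the history completely, and an unguarded stage-3 loop would index into an empty list. A toy reproduction; Msg and the whitespace tokenizer are sketch stand-ins for ChatMessage and count_tokens:

from dataclasses import dataclass

@dataclass
class Msg:
    content: str
    summary: str

def tokens(text: str) -> int:
    return len(text.split())  # whitespace stand-in for count_tokens

chat_history = [Msg("a rather long message body here", "short")]
max_tokens, total = 2, 12  # the completion alone is already over budget

while total > max_tokens and len(chat_history) > 0:  # the new guard
    msg = chat_history.pop(0)
    total = total - tokens(msg.content) + tokens(msg.summary)

# Stops at total == 7 with the history drained; without the guard the
# loop would pop from an empty list (IndexError) while still over budget.
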
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index f146c94a..24f00d36 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -211,26 +211,16 @@ class DefaultModelEditCodeStep(Step):
             return cur_start_line, cur_end_line
 
-        if model_to_use.name == "gpt-4":
-
-            total_tokens = model_to_use.count_tokens(
-                full_file_contents + self._prompt)
-            cur_start_line, cur_end_line = cut_context(
-                model_to_use, total_tokens, cur_start_line, cur_end_line)
-
-        elif model_to_use.name == "gpt-3.5-turbo" or model_to_use.name == "gpt-3.5-turbo-16k":
-
+        model_to_use = sdk.models.default
+        if model_to_use.name == "gpt-3.5-turbo":
             if sdk.models.gpt35.count_tokens(full_file_contents) > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
-
                 model_to_use = sdk.models.gpt3516k
-                total_tokens = model_to_use.count_tokens(
-                    full_file_contents + self._prompt)
-                cur_start_line, cur_end_line = cut_context(
-                    model_to_use, total_tokens, cur_start_line, cur_end_line)
-        else:
-            raise Exception("Unknown default model")
+        total_tokens = model_to_use.count_tokens(
+            full_file_contents + self._prompt + self.user_input)
+        cur_start_line, cur_end_line = cut_context(
+            model_to_use, total_tokens, cur_start_line, cur_end_line)
 
         code_before = "\n".join(
             full_file_contents_lst[cur_start_line:max_start_line])
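
Net effect of this hunk: the per-model branches collapse into a single path where only a 4k-context gpt-3.5-turbo default is upgraded to the 16k variant, every other default (gpt-4 included) falls through, and the shared token count now also includes self.user_input. A stubbed sketch of the selection logic; FakeModel, FakeModels, and the chars-per-token heuristic are assumptions, not the repo's classes:

from dataclasses import dataclass, field

MAX_TOKENS_FOR_MODEL = {"gpt-3.5-turbo": 4096, "gpt-3.5-turbo-16k": 16384}

@dataclass
class FakeModel:
    name: str
    def count_tokens(self, text: str) -> int:
        return len(text) // 4  # rough chars-per-token stand-in

@dataclass
class FakeModels:
    default: FakeModel = field(default_factory=lambda: FakeModel("gpt-3.5-turbo"))
    gpt35: FakeModel = field(default_factory=lambda: FakeModel("gpt-3.5-turbo"))
    gpt3516k: FakeModel = field(default_factory=lambda: FakeModel("gpt-3.5-turbo-16k"))

def pick_model(models: FakeModels, full_file_contents: str) -> FakeModel:
    model_to_use = models.default
    if model_to_use.name == "gpt-3.5-turbo":
        # Upgrade to the 16k-context variant when the file alone overflows
        # the 4k window; other defaults pass straight through to cut_context.
        if models.gpt35.count_tokens(full_file_contents) > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
            model_to_use = models.gpt3516k
    return model_to_use

print(pick_model(FakeModels(), "x" * 20000).name)  # -> gpt-3.5-turbo-16k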