Diffstat (limited to 'continuedev')
-rw-r--r--   continuedev/src/continuedev/libs/llm/openai.py           |  4
-rw-r--r--   continuedev/src/continuedev/libs/util/count_tokens.py    | 14
-rw-r--r--   continuedev/src/continuedev/steps/core/core.py           | 22
3 files changed, 20 insertions, 20 deletions
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index f0b2e6d8..3024ae61 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -4,7 +4,7 @@ from typing import Any, Coroutine, Dict, Generator, List, Union
 from ...core.main import ChatMessage
 import openai
 from ..llm import LLM
-from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens
+from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
 
 
 class OpenAI(LLM):
@@ -72,7 +72,7 @@ class OpenAI(LLM):
             )).choices[0].message.content
         else:
             resp = (await openai.Completion.acreate(
-                prompt=prompt,
+                prompt=prune_raw_prompt_from_top(args["model"], prompt),
                 **args,
             )).choices[0].text
 
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 6038b68d..addafcff 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -28,6 +28,16 @@ def count_tokens(model: str, text: str | None):
     return len(encoding.encode(text, disallowed_special=()))
 
 
+def prune_raw_prompt_from_top(model: str, prompt: str):
+    max_tokens = MAX_TOKENS_FOR_MODEL.get(model, DEFAULT_MAX_TOKENS)
+    encoding = encoding_for_model(model)
+    tokens = encoding.encode(prompt, disallowed_special=())
+    if len(tokens) <= max_tokens:
+        return prompt
+    else:
+        return encoding.decode(tokens[-max_tokens:])
+
+
 def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
     total_tokens = tokens_for_completion + \
         sum(count_tokens(model, message.content)
@@ -43,13 +53,13 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
         i += 1
 
     # 2. Remove entire messages until the last 5
-    while len(chat_history) > 5 and total_tokens > max_tokens:
+    while len(chat_history) > 5 and total_tokens > max_tokens and len(chat_history) > 0:
         message = chat_history.pop(0)
         total_tokens -= count_tokens(model, message.content)
 
     # 3. Truncate message in the last 5
     i = 0
-    while total_tokens > max_tokens:
+    while total_tokens > max_tokens and len(chat_history) > 0:
         message = chat_history[0]
         total_tokens -= count_tokens(model, message.content)
         total_tokens += count_tokens(model, message.summary)
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index f146c94a..24f00d36 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -211,26 +211,16 @@ class DefaultModelEditCodeStep(Step):
                 return cur_start_line, cur_end_line
 
 
-            if model_to_use.name == "gpt-4":
-
-                total_tokens = model_to_use.count_tokens(
-                    full_file_contents + self._prompt)
-                cur_start_line, cur_end_line = cut_context(
-                    model_to_use, total_tokens, cur_start_line, cur_end_line)
-
-            elif model_to_use.name == "gpt-3.5-turbo" or model_to_use.name == "gpt-3.5-turbo-16k":
-
+            model_to_use = sdk.models.default
+            if model_to_use.name == "gpt-3.5-turbo":
                 if sdk.models.gpt35.count_tokens(full_file_contents) > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
-
                     model_to_use = sdk.models.gpt3516k
-                    total_tokens = model_to_use.count_tokens(
-                        full_file_contents + self._prompt)
-                    cur_start_line, cur_end_line = cut_context(
-                        model_to_use, total_tokens, cur_start_line, cur_end_line)
-            else:
+            total_tokens = model_to_use.count_tokens(
+                full_file_contents + self._prompt + self.user_input)
 
-                raise Exception("Unknown default model")
+            cur_start_line, cur_end_line = cut_context(
+                model_to_use, total_tokens, cur_start_line, cur_end_line)
 
             code_before = "\n".join(
                 full_file_contents_lst[cur_start_line:max_start_line])
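
The centerpiece of this commit is prune_raw_prompt_from_top, which trims a raw completion prompt to the model's context window by dropping tokens from the top so the most recent context is kept. Below is a minimal standalone sketch of that technique, assuming tiktoken is installed; the MAX_TOKENS_FOR_MODEL and DEFAULT_MAX_TOKENS values here are illustrative placeholders, not the repository's actual table.

    # Minimal sketch of top-of-prompt pruning; assumes tiktoken is installed.
    # MAX_TOKENS_FOR_MODEL / DEFAULT_MAX_TOKENS are assumed values for illustration.
    from tiktoken import encoding_for_model

    MAX_TOKENS_FOR_MODEL = {"gpt-3.5-turbo": 4096, "gpt-3.5-turbo-16k": 16384}  # assumed
    DEFAULT_MAX_TOKENS = 2048  # assumed

    def prune_raw_prompt_from_top(model: str, prompt: str) -> str:
        # Keep only the last max_tokens tokens so the most recent context survives.
        max_tokens = MAX_TOKENS_FOR_MODEL.get(model, DEFAULT_MAX_TOKENS)
        encoding = encoding_for_model(model)
        tokens = encoding.encode(prompt, disallowed_special=())
        if len(tokens) <= max_tokens:
            return prompt
        return encoding.decode(tokens[-max_tokens:])

    # Usage: a prompt longer than the model's window is trimmed from the start.
    pruned = prune_raw_prompt_from_top("gpt-3.5-turbo", "def foo():\n    pass\n" * 5000)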
