Diffstat (limited to 'continuedev')
 continuedev/src/continuedev/core/main.py       |  4
 continuedev/src/continuedev/core/sdk.py        | 19
 continuedev/src/continuedev/libs/llm/openai.py | 37
3 files changed, 45 insertions(+), 15 deletions(-)
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index 97ef9793..0c7ec67f 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -11,6 +11,8 @@ ChatMessageRole = Literal["assistant", "user", "system"]
 class ChatMessage(ContinueBaseModel):
     role: ChatMessageRole
     content: str
+    # A summary for pruning chat context to fit context window. Often the Step name.
+    summary: str
 
 
 class HistoryNode(ContinueBaseModel):
@@ -23,7 +25,7 @@ class HistoryNode(ContinueBaseModel):
     def to_chat_messages(self) -> List[ChatMessage]:
         if self.step.description is None:
             return self.step.chat_context
-        return self.step.chat_context + [ChatMessage(role="assistant", content=self.step.description)]
+        return self.step.chat_context + [ChatMessage(role="assistant", content=self.step.description, summary=self.step.name)]
 
 
 class History(ContinueBaseModel):
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 76f72d01..8aea6b7f 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -165,26 +165,29 @@ class ContinueSDK(AbstractContinueSDK):
     def raise_exception(self, message: str, title: str, with_step: Union[Step, None] = None):
         raise ContinueCustomException(message, title, with_step)
 
-    def add_chat_context(self, content: str, role: ChatMessageRole = "assistent"):
+    def add_chat_context(self, content: str, summary: Union[str, None] = None, role: ChatMessageRole = "assistent"):
         self.history.timeline[self.history.current_index].step.chat_context.append(
-            ChatMessage(content=content, role=role))
+            ChatMessage(content=content, role=role, summary=summary))
 
     async def get_chat_context(self) -> List[ChatMessage]:
         history_context = self.history.to_chat_history()
         highlighted_code = await self.ide.getHighlightedCode()
+
+        preface = "The following code is highlighted"
+
         if len(highlighted_code) == 0:
+            preface = "The following file is open"
             # Get the full contents of all open files
             files = await self.ide.getOpenFiles()
-            contents = {}
-            for file in files:
-                contents[file] = await self.ide.readFile(file)
+            if len(files) > 0:
+                content = await self.ide.readFile(files[0])
+                highlighted_code = [
+                    RangeInFile.from_entire_file(files[0], content)]
 
-            highlighted_code = [RangeInFile.from_entire_file(
-                filepath, content) for filepath, content in contents.items()]
         for rif in highlighted_code:
             code = await self.ide.readRangeInFile(rif)
             history_context.append(ChatMessage(
-                content=f"The following code is highlighted:\n```\n{code}\n```", role="user"))
+                content=f"{preface} ({rif.filepath}):\n```\n{code}\n```", role="user", summary=f"{preface}: {rif.filepath}"))
         return history_context
 
     async def update_ui(self):
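The two hunks above introduce the summary field end to end: main.py declares it on ChatMessage, and sdk.py populates it wherever messages are built. Two details worth flagging: summary is declared as a required str, while add_chat_context defaults it to None, which pydantic validation will reject; and the pre-existing "assistent" default is not a valid ChatMessageRole literal. Below is a minimal sketch of the new message shape, using pydantic's BaseModel as a stand-in for ContinueBaseModel and an Optional summary to match the call site (both assumptions, not the committed code):

    # Sketch only: BaseModel stands in for ContinueBaseModel, and summary is
    # made Optional here to match add_chat_context's None default, which the
    # committed required-str field would not accept.
    from typing import Literal, Optional
    from pydantic import BaseModel

    ChatMessageRole = Literal["assistant", "user", "system"]

    class ChatMessage(BaseModel):
        role: ChatMessageRole
        content: str
        # A summary for pruning chat context to fit the context window.
        # Often the Step name.
        summary: Optional[str] = None

    # The kind of message get_chat_context now builds for an open file:
    msg = ChatMessage(
        role="user",
        content="The following file is open (main.py):\n<file contents>",
        summary="The following file is open: main.py",
    )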
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index bc108129..d457451f 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -42,12 +42,37 @@ class OpenAI(LLM):
         return len(self.__encoding_for_model.encode(text, disallowed_special=()))
 
     def __prune_chat_history(self, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
-        tokens = tokens_for_completion
-        for i in range(len(chat_history) - 1, -1, -1):
-            message = chat_history[i]
-            tokens += self.count_tokens(message.content)
-            if tokens > max_tokens:
-                return chat_history[i + 1:]
+        total_tokens = tokens_for_completion + \
+            sum(self.count_tokens(message.content) for message in chat_history)
+
+        # 1. Replace beyond last 5 messages with summary
+        i = 0
+        while total_tokens > max_tokens and i < len(chat_history) - 5:
+            message = chat_history[0]
+            total_tokens -= self.count_tokens(message.content)
+            total_tokens += self.count_tokens(message.summary)
+            message.content = message.summary
+            i += 1
+
+        # 2. Remove entire messages until the last 5
+        while len(chat_history) > 5 and total_tokens > max_tokens:
+            message = chat_history.pop(0)
+            total_tokens -= self.count_tokens(message.content)
+
+        # 3. Truncate message in the last 5
+        i = 0
+        while total_tokens > max_tokens:
+            message = chat_history[0]
+            total_tokens -= self.count_tokens(message.content)
+            total_tokens += self.count_tokens(message.summary)
+            message.content = message.summary
+            i += 1
+
+        # 4. Remove entire messages in the last 5
+        while total_tokens > max_tokens and len(chat_history) > 0:
+            message = chat_history.pop(0)
+            total_tokens -= self.count_tokens(message.content)
+
         return chat_history
 
     def with_system_message(self, system_message: Union[str, None]):
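This rewrite replaces the old behavior (walk backwards and keep only the suffix of messages that fits) with a four-stage strategy: summarize messages older than the last 5, then drop them, then summarize the last 5, then drop those too. As committed, stages 1 and 3 always read chat_history[0] even though i is incremented, so the same message is re-summarized, and stage 3 has no index bound, so it spins forever if the budget is still exceeded once that message's content equals its summary. The following is a standalone sketch of the apparent intent with the indexing fixed, reusing the ChatMessage sketch above and taking count_tokens as a parameter in place of the class's encoding-based method:

    # Sketch only: count_tokens is injected so this runs without tiktoken;
    # m.summary is guarded with `or ""` because the sketch allows None.
    from typing import Callable, List

    def prune_chat_history(chat_history: List[ChatMessage], max_tokens: int,
                           tokens_for_completion: int,
                           count_tokens: Callable[[str], int]) -> List[ChatMessage]:
        total_tokens = tokens_for_completion + \
            sum(count_tokens(m.content) for m in chat_history)

        # 1. Replace messages beyond the last 5 with their summaries.
        i = 0
        while total_tokens > max_tokens and i < len(chat_history) - 5:
            m = chat_history[i]  # the committed code reads chat_history[0] here
            total_tokens += count_tokens(m.summary or "") - count_tokens(m.content)
            m.content = m.summary or ""
            i += 1

        # 2. Remove whole messages until only the last 5 remain.
        while len(chat_history) > 5 and total_tokens > max_tokens:
            total_tokens -= count_tokens(chat_history.pop(0).content)

        # 3. Summarize within the last 5, oldest first; bounding i avoids
        # the infinite loop possible in the committed version.
        i = 0
        while total_tokens > max_tokens and i < len(chat_history):
            m = chat_history[i]
            total_tokens += count_tokens(m.summary or "") - count_tokens(m.content)
            m.content = m.summary or ""
            i += 1

        # 4. Last resort: remove messages from the last 5 as well.
        while total_tokens > max_tokens and len(chat_history) > 0:
            total_tokens -= count_tokens(chat_history.pop(0).content)

        return chat_history

A rough tokenizer such as count_tokens=lambda s: len(s) // 4 is enough to exercise the stages; the real OpenAI class counts with its per-model encoding.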
