summaryrefslogtreecommitdiff
path: root/continuedev/src
diff options
context:
space:
mode:
author    Nate Sesti <sestinj@gmail.com>  2023-06-15 21:26:17 -0700
committer Nate Sesti <sestinj@gmail.com>  2023-06-15 21:26:17 -0700
commit d4f416b794935f4f5c3607c8569a0c77891c1eb6 (patch)
tree fa4b26778aff290af9f919b3631740c5679f17d1 /continuedev/src
parent ccca93eb18ceac9769ebac380bca47f21a691d99 (diff)
download sncontinue-d4f416b794935f4f5c3607c8569a0c77891c1eb6.tar.gz
sncontinue-d4f416b794935f4f5c3607c8569a0c77891c1eb6.tar.bz2
sncontinue-d4f416b794935f4f5c3607c8569a0c77891c1eb6.zip
better algorithm for pruning chat context
Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/core/main.py        4
-rw-r--r--  continuedev/src/continuedev/core/sdk.py         19
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py  37
3 files changed, 45 insertions, 15 deletions
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index 97ef9793..0c7ec67f 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -11,6 +11,8 @@ ChatMessageRole = Literal["assistant", "user", "system"]
class ChatMessage(ContinueBaseModel):
role: ChatMessageRole
content: str
+ # A summary for pruning chat context to fit context window. Often the Step name.
+ summary: str
class HistoryNode(ContinueBaseModel):
@@ -23,7 +25,7 @@ class HistoryNode(ContinueBaseModel):
def to_chat_messages(self) -> List[ChatMessage]:
if self.step.description is None:
return self.step.chat_context
- return self.step.chat_context + [ChatMessage(role="assistant", content=self.step.description)]
+ return self.step.chat_context + [ChatMessage(role="assistant", content=self.step.description, summary=self.step.name)]
class History(ContinueBaseModel):
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 76f72d01..8aea6b7f 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -165,26 +165,29 @@ class ContinueSDK(AbstractContinueSDK):
def raise_exception(self, message: str, title: str, with_step: Union[Step, None] = None):
raise ContinueCustomException(message, title, with_step)
- def add_chat_context(self, content: str, role: ChatMessageRole = "assistent"):
+ def add_chat_context(self, content: str, summary: Union[str, None] = None, role: ChatMessageRole = "assistent"):
self.history.timeline[self.history.current_index].step.chat_context.append(
- ChatMessage(content=content, role=role))
+ ChatMessage(content=content, role=role, summary=summary))
async def get_chat_context(self) -> List[ChatMessage]:
history_context = self.history.to_chat_history()
highlighted_code = await self.ide.getHighlightedCode()
+
+ preface = "The following code is highlighted"
+
if len(highlighted_code) == 0:
+ preface = "The following file is open"
# Get the full contents of all open files
files = await self.ide.getOpenFiles()
- contents = {}
- for file in files:
- contents[file] = await self.ide.readFile(file)
+ if len(files) > 0:
+ content = await self.ide.readFile(files[0])
+ highlighted_code = [
+ RangeInFile.from_entire_file(files[0], content)]
- highlighted_code = [RangeInFile.from_entire_file(
- filepath, content) for filepath, content in contents.items()]
for rif in highlighted_code:
code = await self.ide.readRangeInFile(rif)
history_context.append(ChatMessage(
- content=f"The following code is highlighted:\n```\n{code}\n```", role="user"))
+ content=f"{preface} ({rif.filepath}):\n```\n{code}\n```", role="user", summary=f"{preface}: {rif.filepath}"))
return history_context
async def update_ui(self):
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index bc108129..d457451f 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -42,12 +42,37 @@ class OpenAI(LLM):
return len(self.__encoding_for_model.encode(text, disallowed_special=()))
def __prune_chat_history(self, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
- tokens = tokens_for_completion
- for i in range(len(chat_history) - 1, -1, -1):
- message = chat_history[i]
- tokens += self.count_tokens(message.content)
- if tokens > max_tokens:
- return chat_history[i + 1:]
+ total_tokens = tokens_for_completion + \
+ sum(self.count_tokens(message.content) for message in chat_history)
+
+ # 1. Replace beyond last 5 messages with summary
+ i = 0
+ while total_tokens > max_tokens and i < len(chat_history) - 5:
+ message = chat_history[0]
+ total_tokens -= self.count_tokens(message.content)
+ total_tokens += self.count_tokens(message.summary)
+ message.content = message.summary
+ i += 1
+
+ # 2. Remove entire messages until the last 5
+ while len(chat_history) > 5 and total_tokens > max_tokens:
+ message = chat_history.pop(0)
+ total_tokens -= self.count_tokens(message.content)
+
+ # 3. Truncate message in the last 5
+ i = 0
+ while total_tokens > max_tokens:
+ message = chat_history[0]
+ total_tokens -= self.count_tokens(message.content)
+ total_tokens += self.count_tokens(message.summary)
+ message.content = message.summary
+ i += 1
+
+ # 4. Remove entire messages in the last 5
+ while total_tokens > max_tokens and len(chat_history) > 0:
+ message = chat_history.pop(0)
+ total_tokens -= self.count_tokens(message.content)
+
return chat_history
def with_system_message(self, system_message: Union[str, None]):