author    Nate Sesti <sestinj@gmail.com>    2023-06-13 13:56:51 -0700
committer Nate Sesti <sestinj@gmail.com>    2023-06-13 13:56:51 -0700
commit    fe4306b0c15761bbe0cace92050bc3fd216c7faf (patch)
tree      ffc1a274a62fca7386ad003ceef731305fb212f7 /continuedev
parent    8842e201eec31435bd0c28cea25c90ede56aa90e (diff)
prune chat context to fit in token limit
Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/poetry.lock                                                    2
-rw-r--r--  continuedev/pyproject.toml                                                 1
-rw-r--r--  continuedev/src/continuedev/core/config.py                                 3
-rw-r--r--  continuedev/src/continuedev/core/sdk.py                                    4
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py                           4
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py                            30
-rw-r--r--  continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py                  2
-rw-r--r--  continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py   7
-rw-r--r--  continuedev/src/continuedev/steps/chat.py                                  2
9 files changed, 45 insertions, 10 deletions
diff --git a/continuedev/poetry.lock b/continuedev/poetry.lock
index 4aedce87..93aaf82b 100644
--- a/continuedev/poetry.lock
+++ b/continuedev/poetry.lock
@@ -1737,4 +1737,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
-content-hash = "0f5f759bac0e44a1fbcc9babeccdea8688ea2226a4bae7a13858542ae03a3228"
+content-hash = "17910714e3ad780ae7222b62c98539489d198aea67e5c7e4a9fc7672207f500f"
diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml
index 7315e79d..af6ff045 100644
--- a/continuedev/pyproject.toml
+++ b/continuedev/pyproject.toml
@@ -20,6 +20,7 @@ websockets = "^11.0.2"
urllib3 = "1.26.15"
gpt-index = "^0.6.8"
posthog = "^3.0.1"
+tiktoken = "^0.4.0"
[tool.poetry.scripts]
typegen = "src.continuedev.models.generate_json_schema:main"
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index ebd1fdb8..a54a823b 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -19,7 +19,8 @@ class ContinueConfig(BaseModel):
steps_on_startup: Optional[Dict[str, Dict]] = {}
server_url: Optional[str] = None
allow_anonymous_telemetry: Optional[bool] = True
- default_model: Literal["gpt-3.5-turbo", "gpt-4", "starcoder"]
+ default_model: Literal["gpt-3.5-turbo",
+ "gpt-4", "starcoder"] = 'gpt-3.5-turbo'
slash_commands: Optional[List[SlashCommand]] = [
# SlashCommand(
# name="pytest",
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index a54048b0..1da190ff 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -52,9 +52,9 @@ class Models:
def __model_from_name(self, model_name: str):
if model_name == "starcoder":
return self.starcoder
- elif model_name == "gpt35":
+ elif model_name == "gpt-3.5-turbo":
return self.gpt35
- elif model_name == "gpt4":
+ elif model_name == "gpt-4":
return self.gpt4
else:
raise Exception(f"Unknown model {model_name}")
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 24fd34be..4889c01e 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -22,3 +22,7 @@ class LLM(BaseModel):
def with_system_message(self, system_message: Union[str, None]):
"""Return a new model with the given system message."""
raise NotImplementedError
+
+ def count_tokens(self, text: str):
+ """Return the number of tokens in the given text."""
+ raise NotImplementedError
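
Note: this hunk adds count_tokens as an abstract hook on the LLM base class; the OpenAI diff below supplies the concrete tiktoken-backed implementation. As a minimal standalone sketch of the contract (the TokenCountingLLM name and bare pydantic base are illustrative, not part of the commit):

import tiktoken
from pydantic import BaseModel

class TokenCountingLLM(BaseModel):
    # Stand-in for a concrete LLM subclass; the real class also carries
    # an api_key and the rest of the LLM interface.
    default_model: str = "gpt-3.5-turbo"

    def count_tokens(self, text: str) -> int:
        # encoding_for_model selects the tokenizer that matches the model
        # name, so counts agree with what the OpenAI API enforces.
        encoding = tiktoken.encoding_for_model(self.default_model)
        return len(encoding.encode(text))

llm = TokenCountingLLM()
print(llm.count_tokens("prune chat context to fit in token limit"))
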
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 9b8d3447..39c0b69f 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,4 +1,5 @@
import asyncio
+from functools import cached_property
import time
from typing import Any, Dict, Generator, List, Union
from ...core.main import ChatMessage
@@ -6,20 +7,44 @@ import openai
import aiohttp
from ..llm import LLM
from pydantic import BaseModel, validator
+import tiktoken
+MAX_TOKENS_FOR_MODEL = {
+ "gpt-3.5-turbo": 4097,
+ "gpt-4": 4097,
+}
DEFAULT_MAX_TOKENS = 2048
class OpenAI(LLM):
api_key: str
completion_count: int = 0
- default_model: str = "text-davinci-003"
+ default_model: str = "gpt-3.5-turbo"
@validator("api_key", pre=True, always=True)
def validate_api_key(cls, v):
openai.api_key = v
return v
+ @cached_property
+ def __encoding_for_model(self):
+ aliases = {
+ "gpt-3.5-turbo": "gpt3"
+ }
+ return tiktoken.encoding_for_model(self.default_model)
+
+ def count_tokens(self, text: str):
+ return len(self.__encoding_for_model.encode(text))
+
+ def __prune_chat_history(self, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
+ tokens = tokens_for_completion
+ for i in range(len(chat_history) - 1, -1, -1):
+ message = chat_history[i]
+ tokens += self.count_tokens(message.content)
+ if tokens > max_tokens:
+ return chat_history[i + 1:]
+ return chat_history
+
def with_system_message(self, system_message: Union[str, None]):
return OpenAI(api_key=self.api_key, system_message=system_message)
@@ -40,6 +65,8 @@ class OpenAI(LLM):
continue
def compile_chat_messages(self, msgs: List[ChatMessage], prompt: str) -> List[Dict]:
+ msgs = self.__prune_chat_history(msgs, MAX_TOKENS_FOR_MODEL[self.default_model], self.count_tokens(
+ prompt) + 1000 + self.count_tokens(self.system_message or ""))
history = []
if self.system_message:
history.append({
@@ -51,6 +78,7 @@ class OpenAI(LLM):
"role": "user",
"content": prompt
})
+
return history
def stream_complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
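
Note: the pruning walks the chat history from newest to oldest, starting the running total at a reserve for the prompt, the system message, and roughly 1000 completion tokens; once the total exceeds the model's context limit, every older message is dropped. A standalone sketch of the same loop (the function name and the crude stand-in tokenizer are illustrative, not from the commit):

from typing import Callable, List

def prune_chat_history(messages: List[str], max_tokens: int,
                       tokens_for_completion: int,
                       count_tokens: Callable[[str], int]) -> List[str]:
    # Accumulate from the newest message backwards, on top of the reserve.
    tokens = tokens_for_completion
    for i in range(len(messages) - 1, -1, -1):
        tokens += count_tokens(messages[i])
        if tokens > max_tokens:
            # Message i and everything older no longer fit; keep the rest.
            return messages[i + 1:]
    return messages

# compile_chat_messages uses max_tokens = 4097 for gpt-3.5-turbo and a
# reserve of count_tokens(prompt) + 1000 + count_tokens(system_message).
kept = prune_chat_history(["oldest", "middle", "newest"], 4097, 4090,
                          lambda s: len(s))  # crude 1-char-per-token stand-in
print(kept)  # ['newest'] -- the two older messages no longer fit
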
diff --git a/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py b/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py
index 6426df87..6db9fd4b 100644
--- a/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py
+++ b/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py
@@ -82,7 +82,7 @@ class LoadDataStep(Step):
docs = f.read()
output = "Traceback" + output.split("Traceback")[-1]
- suggestion = sdk.models.gpt4.complete(dedent(f"""\
+ suggestion = sdk.models.default.complete(dedent(f"""\
When trying to load data into BigQuery, the following error occurred:
```ascii
diff --git a/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py b/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py
index 2f346b49..d9bdbc0a 100644
--- a/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py
+++ b/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py
@@ -65,9 +65,10 @@ class DeployAirflowStep(Step):
async def run(self, sdk: ContinueSDK):
# Run dlt command to deploy pipeline to Airflow
- await sdk.run([
- f'dlt --non-interactive deploy {self.source_name}_pipeline.py airflow-composer',
- ], description="Running `dlt deploy airflow` to deploy the dlt pipeline to Airflow", name="Deploy dlt pipeline to Airflow")
+ await sdk.run(
+ ['git init',
+ f'dlt --non-interactive deploy {self.source_name}_pipeline.py airflow-composer'],
+ description="Running `dlt deploy airflow` to deploy the dlt pipeline to Airflow", name="Deploy dlt pipeline to Airflow")
# Get filepaths, open the DAG file
directory = await sdk.ide.getWorkspaceDirectory()
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index aca0ab2a..7cfe7e0c 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -10,7 +10,7 @@ class SimpleChatStep(Step):
name: str = "Chat"
async def run(self, sdk: ContinueSDK):
- self.description = ""
+ self.description = f"## {self.user_input}\n\n"
for chunk in sdk.models.default.stream_chat(self.user_input, with_history=await sdk.get_chat_context()):
self.description += chunk
await sdk.update_ui()