Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/poetry.lock                                                   |  2
-rw-r--r-- | continuedev/pyproject.toml                                                |  1
-rw-r--r-- | continuedev/src/continuedev/core/config.py                                |  3
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py                                   |  4
-rw-r--r-- | continuedev/src/continuedev/libs/llm/__init__.py                          |  4
-rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py                            | 30
-rw-r--r-- | continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py                 |  2
-rw-r--r-- | continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py  |  7
-rw-r--r-- | continuedev/src/continuedev/steps/chat.py                                 |  2
9 files changed, 45 insertions, 10 deletions
diff --git a/continuedev/poetry.lock b/continuedev/poetry.lock
index 4aedce87..93aaf82b 100644
--- a/continuedev/poetry.lock
+++ b/continuedev/poetry.lock
@@ -1737,4 +1737,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "0f5f759bac0e44a1fbcc9babeccdea8688ea2226a4bae7a13858542ae03a3228"
+content-hash = "17910714e3ad780ae7222b62c98539489d198aea67e5c7e4a9fc7672207f500f"
diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml
index 7315e79d..af6ff045 100644
--- a/continuedev/pyproject.toml
+++ b/continuedev/pyproject.toml
@@ -20,6 +20,7 @@ websockets = "^11.0.2"
 urllib3 = "1.26.15"
 gpt-index = "^0.6.8"
 posthog = "^3.0.1"
+tiktoken = "^0.4.0"
 
 [tool.poetry.scripts]
 typegen = "src.continuedev.models.generate_json_schema:main"
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index ebd1fdb8..a54a823b 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -19,7 +19,8 @@ class ContinueConfig(BaseModel):
     steps_on_startup: Optional[Dict[str, Dict]] = {}
     server_url: Optional[str] = None
     allow_anonymous_telemetry: Optional[bool] = True
-    default_model: Literal["gpt-3.5-turbo", "gpt-4", "starcoder"]
+    default_model: Literal["gpt-3.5-turbo",
+                           "gpt-4", "starcoder"] = 'gpt-3.5-turbo'
     slash_commands: Optional[List[SlashCommand]] = [
         # SlashCommand(
         #     name="pytest",
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index a54048b0..1da190ff 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -52,9 +52,9 @@ class Models:
     def __model_from_name(self, model_name: str):
         if model_name == "starcoder":
             return self.starcoder
-        elif model_name == "gpt35":
+        elif model_name == "gpt-3.5-turbo":
             return self.gpt35
-        elif model_name == "gpt4":
+        elif model_name == "gpt-4":
             return self.gpt4
         else:
             raise Exception(f"Unknown model {model_name}")
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 24fd34be..4889c01e 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -22,3 +22,7 @@ class LLM(BaseModel):
     def with_system_message(self, system_message: Union[str, None]):
         """Return a new model with the given system message."""
         raise NotImplementedError
+
+    def count_tokens(self, text: str):
+        """Return the number of tokens in the given text."""
+        raise NotImplementedError
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 9b8d3447..39c0b69f 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,4 +1,5 @@
 import asyncio
+from functools import cached_property
 import time
 from typing import Any, Dict, Generator, List, Union
 from ...core.main import ChatMessage
@@ -6,20 +7,44 @@ import openai
 import aiohttp
 from ..llm import LLM
 from pydantic import BaseModel, validator
+import tiktoken
 
+MAX_TOKENS_FOR_MODEL = {
+    "gpt-3.5-turbo": 4097,
+    "gpt-4": 4097,
+}
 DEFAULT_MAX_TOKENS = 2048
 
 
 class OpenAI(LLM):
     api_key: str
     completion_count: int = 0
-    default_model: str = "text-davinci-003"
+    default_model: str = "gpt-3.5-turbo"
 
     @validator("api_key", pre=True, always=True)
     def validate_api_key(cls, v):
         openai.api_key = v
         return v
 
+    @cached_property
+    def __encoding_for_model(self):
+        aliases = {
+            "gpt-3.5-turbo": "gpt3"
+        }
+        return tiktoken.encoding_for_model(self.default_model)
+
+    def count_tokens(self, text: str):
+        return len(self.__encoding_for_model.encode(text))
+
+    def __prune_chat_history(self, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
+        tokens = tokens_for_completion
+        for i in range(len(chat_history) - 1, -1, -1):
+            message = chat_history[i]
+            tokens += self.count_tokens(message.content)
+            if tokens > max_tokens:
+                return chat_history[i + 1:]
+        return chat_history
+
     def with_system_message(self, system_message: Union[str, None]):
         return OpenAI(api_key=self.api_key, system_message=system_message)
 
@@ -40,6 +65,8 @@ class OpenAI(LLM):
             continue
 
     def compile_chat_messages(self, msgs: List[ChatMessage], prompt: str) -> List[Dict]:
+        msgs = self.__prune_chat_history(msgs, MAX_TOKENS_FOR_MODEL[self.default_model], self.count_tokens(
+            prompt) + 1000 + self.count_tokens(self.system_message or ""))
         history = []
         if self.system_message:
             history.append({
@@ -51,6 +78,7 @@ class OpenAI(LLM):
             "role": "user",
             "content": prompt
         })
+        return history
 
     def stream_complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
diff --git a/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py b/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py
index 6426df87..6db9fd4b 100644
--- a/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py
+++ b/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py
@@ -82,7 +82,7 @@ class LoadDataStep(Step):
             docs = f.read()
 
         output = "Traceback" + output.split("Traceback")[-1]
-        suggestion = sdk.models.gpt4.complete(dedent(f"""\
+        suggestion = sdk.models.default.complete(dedent(f"""\
             When trying to load data into BigQuery, the following error occurred:
 
             ```ascii
diff --git a/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py b/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py
index 2f346b49..d9bdbc0a 100644
--- a/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py
+++ b/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py
@@ -65,9 +65,10 @@ class DeployAirflowStep(Step):
     async def run(self, sdk: ContinueSDK):
 
         # Run dlt command to deploy pipeline to Airflow
-        await sdk.run([
-            f'dlt --non-interactive deploy {self.source_name}_pipeline.py airflow-composer',
-        ], description="Running `dlt deploy airflow` to deploy the dlt pipeline to Airflow", name="Deploy dlt pipeline to Airflow")
+        await sdk.run(
+            ['git init',
+             f'dlt --non-interactive deploy {self.source_name}_pipeline.py airflow-composer'],
+            description="Running `dlt deploy airflow` to deploy the dlt pipeline to Airflow", name="Deploy dlt pipeline to Airflow")
 
         # Get filepaths, open the DAG file
         directory = await sdk.ide.getWorkspaceDirectory()
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index aca0ab2a..7cfe7e0c 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -10,7 +10,7 @@ class SimpleChatStep(Step):
     name: str = "Chat"
 
     async def run(self, sdk: ContinueSDK):
-        self.description = ""
+        self.description = f"## {self.user_input}\n\n"
         for chunk in sdk.models.default.stream_chat(self.user_input, with_history=await sdk.get_chat_context()):
             self.description += chunk
             await sdk.update_ui()
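The openai.py hunks are the substance of this commit: chat messages were previously compiled with no regard for the model's context window, and the new count_tokens / __prune_chat_history pair trims the oldest messages until the prompt, remaining history, system message, and a reserved completion budget fit under MAX_TOKENS_FOR_MODEL. Below is a minimal standalone sketch of that pruning logic, assuming tiktoken's gpt-3.5-turbo encoding and a simplified stand-in for the repo's ChatMessage model; the example messages are hypothetical, while the 4097-token limit and the ~1000-token completion reserve mirror the values hard-coded in the diff.

```python
from dataclasses import dataclass
from typing import List

import tiktoken  # added to pyproject.toml by this commit


@dataclass
class ChatMessage:  # simplified stand-in for continuedev's ChatMessage
    role: str
    content: str


# One encoder per model name; "gpt-3.5-turbo" resolves to cl100k_base.
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")


def count_tokens(text: str) -> int:
    return len(encoding.encode(text))


def prune_chat_history(chat_history: List[ChatMessage],
                       max_tokens: int,
                       tokens_for_completion: int) -> List[ChatMessage]:
    # Walk from the newest message backwards, accumulating token counts;
    # once the budget is exceeded, keep only the messages after that point.
    tokens = tokens_for_completion
    for i in range(len(chat_history) - 1, -1, -1):
        tokens += count_tokens(chat_history[i].content)
        if tokens > max_tokens:
            return chat_history[i + 1:]
    return chat_history


history = [
    ChatMessage("user", "hello " * 5000),   # far too large to keep
    ChatMessage("assistant", "Hi, how can I help?"),
    ChatMessage("user", "What does this code do?"),
]
# 4097 is the gpt-3.5-turbo limit from MAX_TOKENS_FOR_MODEL; reserve
# roughly 1000 tokens for the reply, as compile_chat_messages does.
pruned = prune_chat_history(history, max_tokens=4097, tokens_for_completion=1000)
print(len(pruned))  # 2 -- the oversized oldest message is dropped
```

One wrinkle worth noting when reusing the pattern: because pruning keeps only a suffix of the history, the prompt and system message must be budgeted separately from the messages being pruned, which is why the diff passes count_tokens(prompt) + 1000 + count_tokens(system_message or "") as tokens_for_completion.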