diff options
Diffstat (limited to 'continuedev/src')
7 files changed, 43 insertions, 9 deletions
| diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index ebd1fdb8..a54a823b 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -19,7 +19,8 @@ class ContinueConfig(BaseModel):      steps_on_startup: Optional[Dict[str, Dict]] = {}      server_url: Optional[str] = None      allow_anonymous_telemetry: Optional[bool] = True -    default_model: Literal["gpt-3.5-turbo", "gpt-4", "starcoder"] +    default_model: Literal["gpt-3.5-turbo", +                           "gpt-4", "starcoder"] = 'gpt-3.5-turbo'      slash_commands: Optional[List[SlashCommand]] = [          # SlashCommand(          #     name="pytest", diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index a54048b0..1da190ff 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -52,9 +52,9 @@ class Models:      def __model_from_name(self, model_name: str):          if model_name == "starcoder":              return self.starcoder -        elif model_name == "gpt35": +        elif model_name == "gpt-3.5-turbo":              return self.gpt35 -        elif model_name == "gpt4": +        elif model_name == "gpt-4":              return self.gpt4          else:              raise Exception(f"Unknown model {model_name}") diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 24fd34be..4889c01e 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -22,3 +22,7 @@ class LLM(BaseModel):      def with_system_message(self, system_message: Union[str, None]):          """Return a new model with the given system message."""          raise NotImplementedError + +    def count_tokens(self, text: str): +        """Return the number of tokens in the given text.""" +        raise NotImplementedError diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 9b8d3447..39c0b69f 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -1,4 +1,5 @@  import asyncio +from functools import cached_property  import time  from typing import Any, Dict, Generator, List, Union  from ...core.main import ChatMessage @@ -6,20 +7,44 @@ import openai  import aiohttp  from ..llm import LLM  from pydantic import BaseModel, validator +import tiktoken +MAX_TOKENS_FOR_MODEL = { +    "gpt-3.5-turbo": 4097, +    "gpt-4": 4097, +}  DEFAULT_MAX_TOKENS = 2048  class OpenAI(LLM):      api_key: str      completion_count: int = 0 -    default_model: str = "text-davinci-003" +    default_model: str = "gpt-3.5-turbo"      @validator("api_key", pre=True, always=True)      def validate_api_key(cls, v):          openai.api_key = v          return v +    @cached_property +    def __encoding_for_model(self): +        aliases = { +            "gpt-3.5-turbo": "gpt3" +        } +        return tiktoken.encoding_for_model(self.default_model) + +    def count_tokens(self, text: str): +        return len(self.__encoding_for_model.encode(text)) + +    def __prune_chat_history(self, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int): +        tokens = tokens_for_completion +        for i in range(len(chat_history) - 1, -1, -1): +            message = chat_history[i] +            tokens += self.count_tokens(message.content) +            if tokens > max_tokens: +                return chat_history[i + 1:] +        return chat_history +      def with_system_message(self, system_message: Union[str, None]):          return OpenAI(api_key=self.api_key, system_message=system_message) @@ -40,6 +65,8 @@ class OpenAI(LLM):                  continue      def compile_chat_messages(self, msgs: List[ChatMessage], prompt: str) -> List[Dict]: +        msgs = self.__prune_chat_history(msgs, MAX_TOKENS_FOR_MODEL[self.default_model], self.count_tokens( +            prompt) + 1000 + self.count_tokens(self.system_message or ""))          history = []          if self.system_message:              history.append({ @@ -51,6 +78,7 @@ class OpenAI(LLM):              "role": "user",              "content": prompt          }) +          return history      def stream_complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: diff --git a/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py b/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py index 6426df87..6db9fd4b 100644 --- a/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py +++ b/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py @@ -82,7 +82,7 @@ class LoadDataStep(Step):                  docs = f.read()              output = "Traceback" + output.split("Traceback")[-1] -            suggestion = sdk.models.gpt4.complete(dedent(f"""\ +            suggestion = sdk.models.default.complete(dedent(f"""\                  When trying to load data into BigQuery, the following error occurred:                  ```ascii diff --git a/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py b/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py index 2f346b49..d9bdbc0a 100644 --- a/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py +++ b/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py @@ -65,9 +65,10 @@ class DeployAirflowStep(Step):      async def run(self, sdk: ContinueSDK):          # Run dlt command to deploy pipeline to Airflow -        await sdk.run([ -            f'dlt --non-interactive deploy {self.source_name}_pipeline.py airflow-composer', -        ], description="Running `dlt deploy airflow` to deploy the dlt pipeline to Airflow", name="Deploy dlt pipeline to Airflow") +        await sdk.run( +            ['git init', +                f'dlt --non-interactive deploy {self.source_name}_pipeline.py airflow-composer'], +            description="Running `dlt deploy airflow` to deploy the dlt pipeline to Airflow", name="Deploy dlt pipeline to Airflow")          # Get filepaths, open the DAG file          directory = await sdk.ide.getWorkspaceDirectory() diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py index aca0ab2a..7cfe7e0c 100644 --- a/continuedev/src/continuedev/steps/chat.py +++ b/continuedev/src/continuedev/steps/chat.py @@ -10,7 +10,7 @@ class SimpleChatStep(Step):      name: str = "Chat"      async def run(self, sdk: ContinueSDK): -        self.description = "" +        self.description = f"## {self.user_input}\n\n"          for chunk in sdk.models.default.stream_chat(self.user_input, with_history=await sdk.get_chat_context()):              self.description += chunk              await sdk.update_ui() | 
