diff options
| author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-07-31 11:36:19 -0700 | 
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-07-31 11:36:19 -0700 | 
| commit | 457c9940ec6bdabd89de84a23abbf246aaf662c4 (patch) | |
| tree | bfc8bad5f8f9c2d933dbe815b80ae25e779d40e6 /continuedev/src | |
| parent | 269a28a311c38e60c720dfcfc2889a2f6f0f85bb (diff) | |
| parent | 43080aad995dcfb4b1742627fc03af7027cdbf8a (diff) | |
| download | sncontinue-457c9940ec6bdabd89de84a23abbf246aaf662c4.tar.gz sncontinue-457c9940ec6bdabd89de84a23abbf246aaf662c4.tar.bz2 sncontinue-457c9940ec6bdabd89de84a23abbf246aaf662c4.zip | |
Merge pull request #318 from lun-4/llm-object-config
make Config receive LLM objects
Diffstat (limited to 'continuedev/src')
29 files changed, 507 insertions, 426 deletions
| diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py index 94d7be10..e048f877 100644 --- a/continuedev/src/continuedev/core/abstract_sdk.py +++ b/continuedev/src/continuedev/core/abstract_sdk.py @@ -73,7 +73,7 @@ class AbstractContinueSDK(ABC):          pass      @abstractmethod -    async def get_user_secret(self, env_var: str, prompt: str) -> str: +    async def get_user_secret(self, env_var: str) -> str:          pass      config: ContinueConfig diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 5ab5f8ae..de95a259 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -96,7 +96,6 @@ class Autopilot(ContinueBaseModel):              history=self.history,              active=self._active,              user_input_queue=self._main_user_input_queue, -            default_model=self.continue_sdk.config.default_model,              slash_commands=self.get_available_slash_commands(),              adding_highlighted_code=self.context_manager.context_providers[                  "code"].adding_highlighted_code, diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index fe0946cd..84b6b10b 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -1,3 +1,9 @@ +import json +import os +from .main import Step +from .context import ContextProvider +from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI +from .models import Models  from pydantic import BaseModel, validator  from typing import List, Literal, Optional, Dict, Type @@ -23,13 +29,6 @@ class OnTracebackSteps(BaseModel):      params: Optional[Dict] = {} -class OpenAIServerInfo(BaseModel): -    api_base: Optional[str] = None -    engine: Optional[str] = None -    api_version: Optional[str] = None -    api_type: Literal["azure", "openai"] = "openai" - -  class ContinueConfig(BaseModel):      """      A pydantic class for the continue config file. @@ -37,8 +36,9 @@ class ContinueConfig(BaseModel):      steps_on_startup: List[Step] = []      disallowed_steps: Optional[List[str]] = []      allow_anonymous_telemetry: Optional[bool] = True -    default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k", -                           "gpt-4", "claude-2", "ggml"] = 'gpt-4' +    models: Models = Models( +        default=MaybeProxyOpenAI(model="gpt-4"), +    )      temperature: Optional[float] = 0.5      custom_commands: Optional[List[CustomCommand]] = [CustomCommand(          name="test", @@ -48,7 +48,6 @@ class ContinueConfig(BaseModel):      slash_commands: Optional[List[SlashCommand]] = []      on_traceback: Optional[List[OnTracebackSteps]] = []      system_message: Optional[str] = None -    openai_server_info: Optional[OpenAIServerInfo] = None      policy_override: Optional[Policy] = None      context_providers: List[ContextProvider] = [] diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index 86522ce1..e968c35c 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -178,12 +178,6 @@ class ContextManager:                  except Exception as e:                      logger.debug(f"Error loading meilisearch index: {e}") -    # def compile_chat_messages(self, max_tokens: int) -> List[Dict]: -    #     """ -    #     Compiles the chat prompt into a single string. -    #     """ -    #     return compile_chat_messages(self.model, self.chat_history, max_tokens, self.prompt, self.functions, self.system_message) -      async def select_context_item(self, id: str, query: str):          """          Selects the ContextItem with the given id. diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index df9b98ef..2553850f 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -258,7 +258,6 @@ class FullState(ContinueBaseModel):      history: History      active: bool      user_input_queue: List[str] -    default_model: str      slash_commands: List[SlashCommandDescription]      adding_highlighted_code: bool      selected_context_items: List[ContextItem] diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py new file mode 100644 index 00000000..900762b6 --- /dev/null +++ b/continuedev/src/continuedev/core/models.py @@ -0,0 +1,65 @@ +from typing import Optional, Any +from pydantic import BaseModel, validator +from ..libs.llm import LLM + + +class Models(BaseModel): +    """Main class that holds the current model configuration""" +    default: LLM +    small: Optional[LLM] = None +    medium: Optional[LLM] = None +    large: Optional[LLM] = None + +    # TODO namespace these away to not confuse readers, +    # or split Models into ModelsConfig, which gets turned into Models +    sdk: "ContinueSDK" = None +    system_message: Any = None + +    """ +    Better to have sdk.llm.stream_chat(messages, model="claude-2"). +    Then you also don't care that it' async. +    And it's easier to add more models. +    And intermediate shared code is easier to add. +    And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo" +    PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model. +    Easy to reason about, can place anywhere. +    And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model. +    This can all happen inside of Models? + +    class Prompt: +        def __init__(self, ...info): +            '''take whatever info is needed to describe the prompt''' + +        def to_string(self, model: str) -> str: +            '''depending on the model, return the single prompt string''' +    """ + +    async def start(self, sdk: "ContinueSDK"): +        """Start each of the LLMs, or fall back to default""" +        self.sdk = sdk +        self.system_message = self.sdk.config.system_message +        await sdk.start_model(self.default) +        if self.small: +            await sdk.start_model(self.small) +        else: +            self.small = self.default + +        if self.medium: +            await sdk.start_model(self.medium) +        else: +            self.medium = self.default + +        if self.large: +            await sdk.start_model(self.large) +        else: +            self.large = self.default + +    async def stop(self, sdk: "ContinueSDK"): +        """Stop each LLM (if it's not the default, which is shared)""" +        await self.default.stop() +        if self.small is not self.default: +            await self.small.stop() +        if self.medium is not self.default: +            await self.medium.stop() +        if self.large is not self.default: +            await self.large.stop() diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 50b89a8b..bf22d696 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -9,17 +9,14 @@ from .abstract_sdk import AbstractContinueSDK  from .config import ContinueConfig  from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory  from ..models.filesystem import RangeInFile -from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI -from ..libs.llm.openai import OpenAI -from ..libs.llm.anthropic import AnthropicLLM -from ..libs.llm.ggml import GGML +from ..libs.llm import LLM  from .observation import Observation  from ..server.ide_protocol import AbstractIdeProtocolServer  from .main import Context, ContinueCustomException, History, HistoryNode, Step, ChatMessage  from ..plugins.steps.core.core import * -from ..libs.llm.proxy_server import ProxyServer  from ..libs.util.telemetry import posthog_logger  from ..libs.util.paths import getConfigFilePath +from .models import Models  from ..libs.util.logging import logger @@ -27,121 +24,6 @@ class Autopilot:      pass -ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"] -MODEL_PROVIDER_TO_ENV_VAR = { -    "openai": "OPENAI_API_KEY", -    "hf_inference_api": "HUGGING_FACE_TOKEN", -    "anthropic": "ANTHROPIC_API_KEY", -} - - -class Models: -    provider_keys: Dict[ModelProvider, str] = {} -    model_providers: List[ModelProvider] -    system_message: str - -    """ -    Better to have sdk.llm.stream_chat(messages, model="claude-2"). -    Then you also don't care that it' async. -    And it's easier to add more models. -    And intermediate shared code is easier to add. -    And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo" -    PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model. -    Easy to reason about, can place anywhere. -    And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model. -    This can all happen inside of Models? - -    class Prompt: -        def __init__(self, ...info): -            '''take whatever info is needed to describe the prompt''' - -        def to_string(self, model: str) -> str: -            '''depending on the model, return the single prompt string''' -    """ - -    def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]): -        self.sdk = sdk -        self.model_providers = model_providers -        self.system_message = sdk.config.system_message - -    @classmethod -    async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models": -        if sdk.config.default_model == "claude-2": -            with_providers.append("anthropic") - -        models = Models(sdk, with_providers) -        for provider in with_providers: -            if provider in MODEL_PROVIDER_TO_ENV_VAR: -                env_var = MODEL_PROVIDER_TO_ENV_VAR[provider] -                models.provider_keys[provider] = await sdk.get_user_secret( -                    env_var, f'Please add your {env_var} to the .env file') - -        return models - -    def __load_openai_model(self, model: str) -> OpenAI: -        api_key = self.provider_keys["openai"] -        if api_key == "": -            return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log) -        return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, openai_server_info=self.sdk.config.openai_server_info, write_log=self.sdk.write_log) - -    def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI: -        api_key = self.provider_keys["hf_inference_api"] -        return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message) - -    def __load_anthropic_model(self, model: str) -> AnthropicLLM: -        api_key = self.provider_keys["anthropic"] -        return AnthropicLLM(api_key, model, self.system_message) - -    @cached_property -    def claude2(self): -        return self.__load_anthropic_model("claude-2") - -    @cached_property -    def starcoder(self): -        return self.__load_hf_inference_api_model("bigcode/starcoder") - -    @cached_property -    def gpt35(self): -        return self.__load_openai_model("gpt-3.5-turbo") - -    @cached_property -    def gpt350613(self): -        return self.__load_openai_model("gpt-3.5-turbo-0613") - -    @cached_property -    def gpt3516k(self): -        return self.__load_openai_model("gpt-3.5-turbo-16k") - -    @cached_property -    def gpt4(self): -        return self.__load_openai_model("gpt-4") - -    @cached_property -    def ggml(self): -        return GGML(system_message=self.system_message) - -    def __model_from_name(self, model_name: str): -        if model_name == "starcoder": -            return self.starcoder -        elif model_name == "gpt-3.5-turbo": -            return self.gpt35 -        elif model_name == "gpt-3.5-turbo-16k": -            return self.gpt3516k -        elif model_name == "gpt-4": -            return self.gpt4 -        elif model_name == "claude-2": -            return self.claude2 -        elif model_name == "ggml": -            return self.ggml -        else: -            raise Exception(f"Unknown model {model_name}") - -    @property -    def default(self): -        default_model = self.sdk.config.default_model -        return self.__model_from_name(default_model) if default_model is not None else self.gpt4 - -  class ContinueSDK(AbstractContinueSDK):      """The SDK provided as parameters to a step"""      ide: AbstractIdeProtocolServer @@ -179,11 +61,13 @@ class ContinueSDK(AbstractContinueSDK):                  active=False              )) +        sdk.models = sdk.config.models +        await sdk.models.start(sdk) +          # When the config is loaded, setup posthog logger          posthog_logger.setup(              sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry) -        sdk.models = await Models.create(sdk)          return sdk      @property @@ -193,6 +77,16 @@ class ContinueSDK(AbstractContinueSDK):      def write_log(self, message: str):          self.history.timeline[self.history.current_index].logs.append(message) +    async def start_model(self, llm: LLM): +        kwargs = {} +        if llm.requires_api_key: +            kwargs["api_key"] = await self.get_user_secret(llm.requires_api_key) +        if llm.requires_unique_id: +            kwargs["unique_id"] = self.ide.unique_id +        if llm.requires_write_log: +            kwargs["write_log"] = self.write_log +        await llm.start(**kwargs) +      async def _ensure_absolute_path(self, path: str) -> str:          if os.path.isabs(path):              return path @@ -262,7 +156,8 @@ class ContinueSDK(AbstractContinueSDK):          path = await self._ensure_absolute_path(path)          return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path))) -    async def get_user_secret(self, env_var: str, prompt: str) -> str: +    async def get_user_secret(self, env_var: str) -> str: +        # TODO support error prompt dynamically set on env_var          return await self.ide.getUserSecret(env_var)      _last_valid_config: ContinueConfig = None diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt index 69fd357b..cf8b0324 100644 --- a/continuedev/src/continuedev/libs/constants/default_config.py.txt +++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt @@ -9,9 +9,11 @@ import subprocess  from continuedev.core.main import Step  from continuedev.core.sdk import ContinueSDK +from continuedev.core.models import Models  from continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig  from continuedev.plugins.context_providers.github import GitHubIssuesContextProvider  from continuedev.plugins.context_providers.google import GoogleContextProvider +from continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI  from continuedev.plugins.policies.default import DefaultPolicy  from continuedev.plugins.steps.open_config import OpenConfigStep @@ -36,9 +38,9 @@ class CommitMessageStep(Step):          diff = subprocess.check_output(              ["git", "diff"], cwd=dir).decode("utf-8") -        # Ask gpt-3.5-16k to write a commit message, +        # Ask the LLM to write a commit message,          # and set it as the description of this step -        self.description = await sdk.models.gpt3516k.complete( +        self.description = await sdk.models.default.complete(              f"{diff}\n\nWrite a short, specific (less than 50 chars) commit message about the above changes:") @@ -48,9 +50,10 @@ config = ContinueConfig(      # See here to learn what anonymous data we collect: https://continue.dev/docs/telemetry      allow_anonymous_telemetry=True, -    # GPT-4 is recommended for best results -    # See options here: https://continue.dev/docs/customization#change-the-default-llm -    default_model="gpt-4", +    models=Models( +        default=MaybeProxyOpenAI(model="gpt-4"), +        medium=MaybeProxyOpenAI(model="gpt-3.5-turbo") +    ),      # Set a system message with information that the LLM should always keep in mind      # E.g. "Please give concise answers. Always respond in Spanish." diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 2766db4b..50577993 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -1,14 +1,30 @@ -from abc import ABC -from typing import Any, Coroutine, Dict, Generator, List, Union +from abc import ABC, abstractproperty +from typing import Any, Coroutine, Dict, Generator, List, Union, Optional  from ...core.main import ChatMessage -from ...models.main import AbstractModel -from pydantic import BaseModel +from ...models.main import ContinueBaseModel -class LLM(ABC): +class LLM(ContinueBaseModel, ABC): +    requires_api_key: Optional[str] = None +    requires_unique_id: bool = False +    requires_write_log: bool = False +      system_message: Union[str, None] = None +    @abstractproperty +    def name(self): +        """Return the name of the LLM.""" +        raise NotImplementedError + +    async def start(self, *, api_key: Optional[str] = None, **kwargs): +        """Start the connection to the LLM.""" +        raise NotImplementedError + +    async def stop(self): +        """Stop the connection to the LLM.""" +        raise NotImplementedError +      async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:          """Return the completion of the text with the given temperature."""          raise NotImplementedError @@ -24,3 +40,8 @@ class LLM(ABC):      def count_tokens(self, text: str):          """Return the number of tokens in the given text."""          raise NotImplementedError + +    @abstractproperty +    def context_length(self) -> int: +        """Return the context length of the LLM in tokens, as counted by count_tokens.""" +        raise NotImplementedError diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index 625d4e57..ec1b7e40 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -1,32 +1,39 @@  from functools import cached_property  import time -from typing import Any, Coroutine, Dict, Generator, List, Union +from typing import Any, Coroutine, Dict, Generator, List, Optional, Union  from ...core.main import ChatMessage  from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic  from ..llm import LLM -from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top +from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens  class AnthropicLLM(LLM): -    api_key: str -    default_model: str -    async_client: AsyncAnthropic +    model: str = "claude-2" -    def __init__(self, api_key: str, default_model: str, system_message: str = None): -        self.api_key = api_key -        self.default_model = default_model +    requires_api_key: str = "ANTHROPIC_API_KEY" +    _async_client: AsyncAnthropic = None + +    class Config: +        arbitrary_types_allowed = True + +    def __init__(self, model: str, system_message: str = None): +        self.model = model          self.system_message = system_message -        self.async_client = AsyncAnthropic(api_key=api_key) +    async def start(self, *, api_key: Optional[str] = None, **kwargs): +        self._async_client = AsyncAnthropic(api_key=api_key) + +    async def stop(self): +        pass      @cached_property      def name(self): -        return self.default_model +        return self.model      @property      def default_args(self): -        return {**DEFAULT_ARGS, "model": self.default_model} +        return {**DEFAULT_ARGS, "model": self.model}      def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:          args = args.copy() @@ -40,7 +47,13 @@ class AnthropicLLM(LLM):          return args      def count_tokens(self, text: str): -        return count_tokens(self.default_model, text) +        return count_tokens(self.model, text) + +    @property +    def context_length(self): +        if self.model == "claude-2": +            return 100000 +        raise Exception(f"Unknown Anthropic model {self.model}")      def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:          prompt = "" @@ -60,7 +73,7 @@ class AnthropicLLM(LLM):          args["stream"] = True          args = self._transform_args(args) -        async for chunk in await self.async_client.completions.create( +        async for chunk in await self._async_client.completions.create(              prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",              **args          ): @@ -73,8 +86,8 @@ class AnthropicLLM(LLM):          args = self._transform_args(args)          messages = compile_chat_messages( -            args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message) -        async for chunk in await self.async_client.completions.create( +            args["model"], messages, self.context_length, self.context_length, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message) +        async for chunk in await self._async_client.completions.create(              prompt=self.__messages_to_prompt(messages),              **args          ): @@ -88,8 +101,8 @@ class AnthropicLLM(LLM):          args = self._transform_args(args)          messages = compile_chat_messages( -            args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message) -        resp = (await self.async_client.completions.create( +            args["model"], with_history, self.context_length, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message) +        resp = (await self._async_client.completions.create(              prompt=self.__messages_to_prompt(messages),              **args          )).completion diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 21374359..7742e8c3 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -1,6 +1,7 @@  from functools import cached_property  import json  from typing import Any, Coroutine, Dict, Generator, List, Union +from pydantic import ConfigDict  import aiohttp  from ...core.main import ChatMessage @@ -11,15 +12,29 @@ SERVER_URL = "http://localhost:8000"  class GGML(LLM): +    # this is model-specific +    max_context_length: int = 2048 -    def __init__(self, system_message: str = None): -        self.system_message = system_message +    _client_session: aiohttp.ClientSession = None -    @cached_property +    class Config: +        arbitrary_types_allowed = True + +    async def start(self, **kwargs): +        self._client_session = aiohttp.ClientSession() + +    async def stop(self): +        await self._client_session.close() + +    @property      def name(self):          return "ggml"      @property +    def context_length(self): +        return self.max_context_length + +    @property      def default_args(self):          return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} @@ -33,54 +48,51 @@ class GGML(LLM):          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) - -        async with aiohttp.ClientSession() as session: -            async with session.post(f"{SERVER_URL}/v1/completions", json={ -                "messages": messages, -                **args -            }) as resp: -                async for line in resp.content.iter_any(): -                    if line: -                        try: -                            yield line.decode("utf-8") -                        except: -                            raise Exception(str(line)) +            self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) + +        async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={ +            "messages": messages, +            **args +        }) as resp: +            async for line in resp.content.iter_any(): +                if line: +                    try: +                        yield line.decode("utf-8") +                    except: +                        raise Exception(str(line))      async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) +            self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)          args["stream"] = True -        async with aiohttp.ClientSession() as session: -            async with session.post(f"{SERVER_URL}/v1/chat/completions", json={ -                "messages": messages, -                **args -            }) as resp: -                # This is streaming application/json instaed of text/event-stream -                async for line in resp.content.iter_chunks(): -                    if line[1]: -                        try: -                            json_chunk = line[0].decode("utf-8") -                            if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): -                                continue  # hehe -                            chunks = json_chunk.split("\n") -                            for chunk in chunks: -                                if chunk.strip() != "": -                                    yield json.loads(chunk[6:])["choices"][0]["delta"] -                        except: -                            raise Exception(str(line[0])) +        async with self._client_session.post(f"{SERVER_URL}/v1/chat/completions", json={ +            "messages": messages, +            **args +        }) as resp: +            # This is streaming application/json instaed of text/event-stream +            async for line in resp.content.iter_chunks(): +                if line[1]: +                    try: +                        json_chunk = line[0].decode("utf-8") +                        if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): +                            continue +                        chunks = json_chunk.split("\n") +                        for chunk in chunks: +                            if chunk.strip() != "": +                                yield json.loads(chunk[6:])["choices"][0]["delta"] +                    except: +                        raise Exception(str(line[0]))      async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:          args = {**self.default_args, **kwargs} -        async with aiohttp.ClientSession() as session: -            async with session.post(f"{SERVER_URL}/v1/completions", json={ -                "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message), -                **args -            }) as resp: -                try: -                    return await resp.text() -                except: -                    raise Exception(await resp.text()) +        async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={ +            "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message), +            **args +        }) as resp: +            try: +                return await resp.text() +            except: +                raise Exception(await resp.text()) diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 36f03270..49f593d8 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional  from ...core.main import ChatMessage  from ..llm import LLM  import requests @@ -8,14 +8,18 @@ DEFAULT_MAX_TIME = 120.  class HuggingFaceInferenceAPI(LLM): -    api_key: str      model: str -    def __init__(self, api_key: str, model: str, system_message: str = None): -        self.api_key = api_key +    requires_api_key: str = "HUGGING_FACE_TOKEN" +    api_key: str = None + +    def __init__(self, model: str, system_message: str = None):          self.model = model          self.system_message = system_message  # TODO: Nothing being done with this +    async def start(self, *, api_key: Optional[str] = None, **kwargs): +        self.api_key = api_key +      def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):          """Return the completion of the text with the given temperature."""          API_URL = f"https://api-inference.huggingface.co/models/{self.model}" diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py new file mode 100644 index 00000000..edf58fd7 --- /dev/null +++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py @@ -0,0 +1,53 @@ +from typing import Any, Coroutine, Dict, Generator, List, Union, Optional, Callable + +from ...core.main import ChatMessage +from . import LLM +from .proxy_server import ProxyServer +from .openai import OpenAI + + +class MaybeProxyOpenAI(LLM): +    model: str + +    requires_api_key: Optional[str] = "OPENAI_API_KEY" +    requires_write_log: bool = True +    requires_unique_id: bool = True +    system_message: Union[str, None] = None + +    llm: Optional[LLM] = None + +    @property +    def name(self): +        return self.llm.name + +    @property +    def context_length(self): +        return self.llm.context_length + +    async def start(self, *, api_key: Optional[str] = None, unique_id: str, write_log: Callable[[str], None]): +        if api_key is None or api_key.strip() == "": +            self.llm = ProxyServer(model=self.model) +        else: +            self.llm = OpenAI(model=self.model) + +        await self.llm.start(api_key=api_key, write_log=write_log, unique_id=unique_id) + +    async def stop(self): +        await self.llm.stop() + +    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: +        return await self.llm.complete(prompt, with_history=with_history, **kwargs) + +    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +        resp = self.llm.stream_complete( +            prompt, with_history=with_history, **kwargs) +        async for item in resp: +            yield item + +    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +        resp = self.llm.stream_chat(messages=messages, **kwargs) +        async for item in resp: +            yield item + +    def count_tokens(self, text: str): +        return self.llm.count_tokens(text) diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index d1ca4ef9..fce6e8ab 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -1,48 +1,83 @@  from functools import cached_property  import json -from typing import Any, Callable, Coroutine, Dict, Generator, List, Union +from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional -from ...core.main import ChatMessage +from pydantic import BaseModel  import openai + +from ...core.main import ChatMessage +from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top  from ..llm import LLM -from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top -from ...core.config import OpenAIServerInfo + + +class OpenAIServerInfo(BaseModel): +    api_base: Optional[str] = None +    engine: Optional[str] = None +    api_version: Optional[str] = None +    api_type: Literal["azure", "openai"] = "openai" + + +CHAT_MODELS = { +    "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613" +} +MAX_TOKENS_FOR_MODEL = { +    "gpt-3.5-turbo": 4096, +    "gpt-3.5-turbo-0613": 4096, +    "gpt-3.5-turbo-16k": 16384, +    "gpt-4": 8192, +} + + +class AzureInfo(BaseModel): +    endpoint: str +    engine: str +    api_version: str  class OpenAI(LLM): -    api_key: str -    default_model: str +    model: str +    openai_server_info: Optional[OpenAIServerInfo] = None -    def __init__(self, api_key: str, default_model: str, system_message: str = None, openai_server_info: OpenAIServerInfo = None, write_log: Callable[[str], None] = None): -        self.api_key = api_key -        self.default_model = default_model -        self.system_message = system_message -        self.openai_server_info = openai_server_info +    requires_api_key = "OPENAI_API_KEY" +    requires_write_log = True + +    system_message: Optional[str] = None +    azure_info: Optional[AzureInfo] = None +    write_log: Optional[Callable[[str], None]] = None +    api_key: str = None + +    async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], **kwargs):          self.write_log = write_log +        self.api_key = api_key +        openai.api_key = self.api_key -        openai.api_key = api_key +        if self.openai_server_info is not None: +            openai.api_type = self.openai_server_info.api_type +            if self.openai_server_info.api_base is not None: +                openai.api_base = self.openai_server_info.api_base +            if self.openai_server_info.api_version is not None: +                openai.api_version = self.openai_server_info.api_version -        # Using an Azure OpenAI deployment -        if openai_server_info is not None: -            openai.api_type = openai_server_info.api_type -            if openai_server_info.api_base is not None: -                openai.api_base = openai_server_info.api_base -            if openai_server_info.api_version is not None: -                openai.api_version = openai_server_info.api_version +    async def stop(self): +        pass -    @cached_property +    @property      def name(self): -        return self.default_model +        return self.model + +    @property +    def context_length(self): +        return MAX_TOKENS_FOR_MODEL[self.model]      @property      def default_args(self): -        args = {**DEFAULT_ARGS, "model": self.default_model} +        args = {**DEFAULT_ARGS, "model": self.model}          if self.openai_server_info is not None:              args["engine"] = self.openai_server_info.engine          return args      def count_tokens(self, text: str): -        return count_tokens(self.default_model, text) +        return count_tokens(self.model, text)      async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:          args = self.default_args.copy() @@ -51,7 +86,7 @@ class OpenAI(LLM):          if args["model"] in CHAT_MODELS:              messages = compile_chat_messages( -                args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) +                args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)              self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")              completion = ""              async for chunk in await openai.ChatCompletion.acreate( @@ -78,12 +113,13 @@ class OpenAI(LLM):          args = self.default_args.copy()          args.update(kwargs)          args["stream"] = True -        args["model"] = self.default_model if self.default_model in CHAT_MODELS else "gpt-3.5-turbo-0613" +        # TODO what to do here? why should we change to gpt-3.5-turbo-0613 if the user didn't ask for it? +        args["model"] = self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613"          if not args["model"].endswith("0613") and "functions" in args:              del args["functions"]          messages = compile_chat_messages( -            args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) +            args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")          completion = ""          async for chunk in await openai.ChatCompletion.acreate( @@ -100,7 +136,7 @@ class OpenAI(LLM):          if args["model"] in CHAT_MODELS:              messages = compile_chat_messages( -                args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) +                args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)              self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")              resp = (await openai.ChatCompletion.acreate(                  messages=messages, @@ -109,7 +145,7 @@ class OpenAI(LLM):              self.write_log(f"Completion: \n\n{resp}")          else:              prompt = prune_raw_prompt_from_top( -                args["model"], prompt, args["max_tokens"]) +                args["model"], self.context_length, prompt, args["max_tokens"])              self.write_log(f"Prompt:\n\n{prompt}")              resp = (await openai.Completion.acreate(                  prompt=prompt, diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index f9e3fa01..1a48f213 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -1,7 +1,6 @@ -  import json  import traceback -from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union +from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional  import aiohttp  from ...core.main import ChatMessage  from ..llm import LLM @@ -16,26 +15,51 @@ ssl_context = ssl.create_default_context(cafile=ca_bundle_path)  # SERVER_URL = "http://127.0.0.1:8080"  SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app" +MAX_TOKENS_FOR_MODEL = { +    "gpt-3.5-turbo": 4096, +    "gpt-3.5-turbo-0613": 4096, +    "gpt-3.5-turbo-16k": 16384, +    "gpt-4": 8192, +} +  class ProxyServer(LLM): -    unique_id: str -    name: str -    default_model: Literal["gpt-3.5-turbo", "gpt-4"] -    write_log: Callable[[str], None] +    model: str +    system_message: Optional[str] -    def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None, write_log: Callable[[str], None] = None): -        self.unique_id = unique_id -        self.default_model = default_model -        self.system_message = system_message -        self.name = default_model +    unique_id: str = None +    write_log: Callable[[str], None] = None +    _client_session: aiohttp.ClientSession + +    requires_unique_id = True +    requires_write_log = True + +    class Config: +        arbitrary_types_allowed = True + +    async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], unique_id: str, **kwargs): +        self._client_session = aiohttp.ClientSession( +            connector=aiohttp.TCPConnector(ssl_context=ssl_context))          self.write_log = write_log +        self.unique_id = unique_id + +    async def stop(self): +        await self._client_session.close() + +    @property +    def name(self): +        return self.model + +    @property +    def context_length(self): +        return MAX_TOKENS_FOR_MODEL[self.model]      @property      def default_args(self): -        return {**DEFAULT_ARGS, "model": self.default_model} +        return {**DEFAULT_ARGS, "model": self.model}      def count_tokens(self, text: str): -        return count_tokens(self.default_model, text) +        return count_tokens(self.model, text)      def get_headers(self):          # headers with unique id @@ -45,75 +69,72 @@ class ProxyServer(LLM):          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) +            args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") -        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: -            async with session.post(f"{SERVER_URL}/complete", json={ -                "messages": messages, -                **args -            }, headers=self.get_headers()) as resp: -                if resp.status != 200: -                    raise Exception(await resp.text()) - -                response_text = await resp.text() -                self.write_log(f"Completion: \n\n{response_text}") -                return response_text +        async with self._client_session.post(f"{SERVER_URL}/complete", json={ +            "messages": messages, +            **args +        }, headers=self.get_headers()) as resp: +            if resp.status != 200: +                raise Exception(await resp.text()) + +            response_text = await resp.text() +            self.write_log(f"Completion: \n\n{response_text}") +            return response_text      async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) +            args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") -        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: -            async with session.post(f"{SERVER_URL}/stream_chat", json={ -                "messages": messages, -                **args -            }, headers=self.get_headers()) as resp: -                # This is streaming application/json instaed of text/event-stream -                completion = "" -                if resp.status != 200: -                    raise Exception(await resp.text()) -                async for line in resp.content.iter_chunks(): -                    if line[1]: -                        try: -                            json_chunk = line[0].decode("utf-8") -                            json_chunk = "{}" if json_chunk == "" else json_chunk -                            chunks = json_chunk.split("\n") -                            for chunk in chunks: -                                if chunk.strip() != "": -                                    loaded_chunk = json.loads(chunk) -                                    yield loaded_chunk -                                    if "content" in loaded_chunk: -                                        completion += loaded_chunk["content"] -                        except Exception as e: -                            posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", { -                                "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))}) -                    else: -                        break - -                self.write_log(f"Completion: \n\n{completion}") +        async with self._client_session.post(f"{SERVER_URL}/stream_chat", json={ +            "messages": messages, +            **args +        }, headers=self.get_headers()) as resp: +            # This is streaming application/json instaed of text/event-stream +            completion = "" +            if resp.status != 200: +                raise Exception(await resp.text()) +            async for line in resp.content.iter_chunks(): +                if line[1]: +                    try: +                        json_chunk = line[0].decode("utf-8") +                        json_chunk = "{}" if json_chunk == "" else json_chunk +                        chunks = json_chunk.split("\n") +                        for chunk in chunks: +                            if chunk.strip() != "": +                                loaded_chunk = json.loads(chunk) +                                yield loaded_chunk +                                if "content" in loaded_chunk: +                                    completion += loaded_chunk["content"] +                    except Exception as e: +                        posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", { +                            "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))}) +                else: +                    break + +            self.write_log(f"Completion: \n\n{completion}")      async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) +            self.model, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") -        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: -            async with session.post(f"{SERVER_URL}/stream_complete", json={ -                "messages": messages, -                **args -            }, headers=self.get_headers()) as resp: -                completion = "" -                if resp.status != 200: -                    raise Exception(await resp.text()) -                async for line in resp.content.iter_any(): -                    if line: -                        try: -                            decoded_line = line.decode("utf-8") -                            yield decoded_line -                            completion += decoded_line -                        except: -                            raise Exception(str(line)) -                self.write_log(f"Completion: \n\n{completion}") +        async with self._client_session.post(f"{SERVER_URL}/stream_complete", json={ +            "messages": messages, +            **args +        }, headers=self.get_headers()) as resp: +            completion = "" +            if resp.status != 200: +                raise Exception(await resp.text()) +            async for line in resp.content.iter_any(): +                if line: +                    try: +                        decoded_line = line.decode("utf-8") +                        yield decoded_line +                        completion += decoded_line +                    except: +                        raise Exception(str(line)) +            self.write_log(f"Completion: \n\n{completion}") diff --git a/continuedev/src/continuedev/libs/llm/utils.py b/continuedev/src/continuedev/libs/llm/utils.py deleted file mode 100644 index 76240d4e..00000000 --- a/continuedev/src/continuedev/libs/llm/utils.py +++ /dev/null @@ -1,34 +0,0 @@ -from transformers import AutoTokenizer, AutoModelForCausalLM -from transformers import GPT2TokenizerFast - -gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") -def count_tokens(text: str) -> int: -    return len(gpt2_tokenizer.encode(text)) - -prices = { -    # All prices are per 1k tokens -    "fine-tune-train": { -        "davinci": 0.03, -        "curie": 0.03, -        "babbage": 0.0006, -        "ada": 0.0004, -    }, -    "completion": { -        "davinci": 0.02, -        "curie": 0.002, -        "babbage": 0.0005, -        "ada": 0.0004, -    }, -    "fine-tune-completion": { -        "davinci": 0.12, -        "curie": 0.012, -        "babbage": 0.0024, -        "ada": 0.0016, -    }, -    "embedding": { -        "ada": 0.0004 -    } -} - -def get_price(text: str, model: str="davinci", task: str="completion") -> float: -    return count_tokens(text) * prices[task][model] / 1000
\ No newline at end of file diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index c58ae499..6add7b1a 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -2,43 +2,47 @@ import json  from typing import Dict, List, Union  from ...core.main import ChatMessage  from .templating import render_templated_string +from ...libs.llm import LLM  import tiktoken +# TODO move many of these into specific LLM.properties() function that +# contains max tokens, if its a chat model or not, default args (not all models +# want to be run at 0.5 temp). also lets custom models made for long contexts +# exist here (likg LLongMA)  aliases = {      "ggml": "gpt-3.5-turbo",      "claude-2": "gpt-3.5-turbo",  }  DEFAULT_MAX_TOKENS = 2048 -MAX_TOKENS_FOR_MODEL = { -    "gpt-3.5-turbo": 4096, -    "gpt-3.5-turbo-0613": 4096, -    "gpt-3.5-turbo-16k": 16384, -    "gpt-4": 8192, -    "ggml": 2048, -    "claude-2": 100000 -} -CHAT_MODELS = { -    "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613" -}  DEFAULT_ARGS = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1,                  "frequency_penalty": 0, "presence_penalty": 0} -def encoding_for_model(model: str): -    return tiktoken.encoding_for_model(aliases.get(model, model)) +def encoding_for_model(model_name: str): +    try: +        return tiktoken.encoding_for_model(aliases.get(model_name, model_name)) +    except: +        return tiktoken.encoding_for_model("gpt-3.5-turbo") -def count_tokens(model: str, text: Union[str, None]): +def count_tokens(model_name: str, text: Union[str, None]):      if text is None:          return 0 -    encoding = encoding_for_model(model) +    encoding = encoding_for_model(model_name)      return len(encoding.encode(text, disallowed_special=())) -def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: int): -    max_tokens = MAX_TOKENS_FOR_MODEL.get( -        model, DEFAULT_MAX_TOKENS) - tokens_for_completion -    encoding = encoding_for_model(model) +def count_chat_message_tokens(model_name: str, chat_message: ChatMessage) -> int: +    # Doing simpler, safer version of what is here: +    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb +    # every message follows <|start|>{role/name}\n{content}<|end|>\n +    TOKENS_PER_MESSAGE = 4 +    return count_tokens(model_name, chat_message.content) + TOKENS_PER_MESSAGE + + +def prune_raw_prompt_from_top(model_name: str, context_length: int, prompt: str, tokens_for_completion: int): +    max_tokens = context_length - tokens_for_completion +    encoding = encoding_for_model(model_name)      tokens = encoding.encode(prompt, disallowed_special=())      if len(tokens) <= max_tokens:          return prompt @@ -46,53 +50,45 @@ def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: in          return encoding.decode(tokens[-max_tokens:]) -def count_chat_message_tokens(model: str, chat_message: ChatMessage) -> int: -    # Doing simpler, safer version of what is here: -    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb -    # every message follows <|start|>{role/name}\n{content}<|end|>\n -    TOKENS_PER_MESSAGE = 4 -    return count_tokens(model, chat_message.content) + TOKENS_PER_MESSAGE - - -def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int): +def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context_length: int, tokens_for_completion: int):      total_tokens = tokens_for_completion + \ -        sum(count_chat_message_tokens(model, message) +        sum(count_chat_message_tokens(model_name, message)              for message in chat_history)      # 1. Replace beyond last 5 messages with summary      i = 0 -    while total_tokens > max_tokens and i < len(chat_history) - 5: +    while total_tokens > context_length and i < len(chat_history) - 5:          message = chat_history[0] -        total_tokens -= count_tokens(model, message.content) -        total_tokens += count_tokens(model, message.summary) +        total_tokens -= count_tokens(model_name, message.content) +        total_tokens += count_tokens(model_name, message.summary)          message.content = message.summary          i += 1      # 2. Remove entire messages until the last 5 -    while len(chat_history) > 5 and total_tokens > max_tokens and len(chat_history) > 0: +    while len(chat_history) > 5 and total_tokens > context_length and len(chat_history) > 0:          message = chat_history.pop(0) -        total_tokens -= count_tokens(model, message.content) +        total_tokens -= count_tokens(model_name, message.content)      # 3. Truncate message in the last 5, except last 1      i = 0 -    while total_tokens > max_tokens and len(chat_history) > 0 and i < len(chat_history) - 1: +    while total_tokens > context_length and len(chat_history) > 0 and i < len(chat_history) - 1:          message = chat_history[i] -        total_tokens -= count_tokens(model, message.content) -        total_tokens += count_tokens(model, message.summary) +        total_tokens -= count_tokens(model_name, message.content) +        total_tokens += count_tokens(model_name, message.summary)          message.content = message.summary          i += 1      # 4. Remove entire messages in the last 5, except last 1 -    while total_tokens > max_tokens and len(chat_history) > 1: +    while total_tokens > context_length and len(chat_history) > 1:          message = chat_history.pop(0) -        total_tokens -= count_tokens(model, message.content) +        total_tokens -= count_tokens(model_name, message.content)      # 5. Truncate last message -    if total_tokens > max_tokens and len(chat_history) > 0: +    if total_tokens > context_length and len(chat_history) > 0:          message = chat_history[0]          message.content = prune_raw_prompt_from_top( -            model, message.content, tokens_for_completion) -        total_tokens = max_tokens +            model_name, context_length, message.content, tokens_for_completion) +        total_tokens = context_length      return chat_history @@ -101,7 +97,7 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:  TOKEN_BUFFER_FOR_SAFETY = 100 -def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]: +def compile_chat_messages(model_name: str, msgs: Union[List[ChatMessage], None], context_length: int, max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:      """      The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message      """ @@ -125,10 +121,10 @@ def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_      function_tokens = 0      if functions is not None:          for function in functions: -            function_tokens += count_tokens(model, json.dumps(function)) +            function_tokens += count_tokens(model_name, json.dumps(function))      msgs_copy = prune_chat_history( -        model, msgs_copy, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY) +        model_name, msgs_copy, context_length, function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)      history = [msg.to_dict(with_functions=functions is not None)                 for msg in msgs_copy] diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py index 433e309e..872f8d62 100644 --- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py @@ -27,7 +27,7 @@ class SetupPipelineStep(Step):      async def run(self, sdk: ContinueSDK):          sdk.context.set("api_description", self.api_description) -        source_name = (await sdk.models.gpt35.complete( +        source_name = (await sdk.models.medium.complete(              f"Write a snake_case name for the data source described by {self.api_description}: ")).strip()          filename = f'{source_name}.py' @@ -89,7 +89,7 @@ class ValidatePipelineStep(Step):          if "Traceback" in output or "SyntaxError" in output:              output = "Traceback" + output.split("Traceback")[-1]              file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename)) -            suggestion = await sdk.models.gpt35.complete(dedent(f"""\ +            suggestion = await sdk.models.medium.complete(dedent(f"""\                  ```python                  {file_content}                  ``` @@ -101,7 +101,7 @@ class ValidatePipelineStep(Step):                  This is a brief summary of the error followed by a suggestion on how it can be fixed by editing the resource function:""")) -            api_documentation_url = await sdk.models.gpt35.complete(dedent(f"""\ +            api_documentation_url = await sdk.models.medium.complete(dedent(f"""\                  The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this:                  ```python                         {file_content} @@ -151,7 +151,7 @@ class RunQueryStep(Step):          output = await sdk.run('.env/bin/python3 query.py', name="Run test query", description="Running `.env/bin/python3 query.py` to test that the data was loaded into DuckDB as expected", handle_error=False)          if "Traceback" in output or "SyntaxError" in output: -            suggestion = await sdk.models.gpt35.complete(dedent(f"""\ +            suggestion = await sdk.models.medium.complete(dedent(f"""\                  ```python                  {await sdk.ide.readFile(os.path.join(sdk.ide.workspace_directory, "query.py"))}                  ``` diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py index 6ef5ffd6..c66cd629 100644 --- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py @@ -42,7 +42,7 @@ class WritePytestsRecipe(Step):              "{self.user_input}"              Here is a complete set of pytest unit tests:""") -        tests = await sdk.models.gpt35.complete(prompt) +        tests = await sdk.models.medium.complete(prompt)          await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests)) diff --git a/continuedev/src/continuedev/plugins/steps/README.md b/continuedev/src/continuedev/plugins/steps/README.md index 12073835..3f2f804c 100644 --- a/continuedev/src/continuedev/plugins/steps/README.md +++ b/continuedev/src/continuedev/plugins/steps/README.md @@ -33,7 +33,7 @@ If you'd like to override the default description of your step, which is just th  - Return a static string  - Store state in a class attribute (prepend with a double underscore, which signifies (through Pydantic) that this is not a parameter for the Step, just internal state) during the run method, and then grab this in the describe method. -- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.gpt35.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`. +- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.  Here's an example: diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py index f72a8ec0..455d5a13 100644 --- a/continuedev/src/continuedev/plugins/steps/chat.py +++ b/continuedev/src/continuedev/plugins/steps/chat.py @@ -9,6 +9,7 @@ from .core.core import DisplayErrorStep, MessageStep  from ...core.main import FunctionCall, Models  from ...core.main import ChatMessage, Step, step_to_json_schema  from ...core.sdk import ContinueSDK +from ...libs.llm.openai import OpenAI  import openai  import os  from dotenv import load_dotenv @@ -41,7 +42,7 @@ class SimpleChatStep(Step):                  self.description += chunk["content"]                  await sdk.update_ui() -        self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete( +        self.name = remove_quotes_and_escapes(await sdk.models.medium.complete(              f"Write a short title for the following chat message: {self.description}"))          self.chat_context.append(ChatMessage( @@ -166,7 +167,10 @@ class ChatWithFunctions(Step):              msg_content = ""              msg_step = None -            async for msg_chunk in sdk.models.gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions): +            gpt350613 = OpenAI(model="gpt-3.5-turbo-0613") +            await sdk.start_model(gpt350613) + +            async for msg_chunk in gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions):                  if sdk.current_step_was_deleted():                      return diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py index dbe8363e..658cc7f3 100644 --- a/continuedev/src/continuedev/plugins/steps/chroma.py +++ b/continuedev/src/continuedev/plugins/steps/chroma.py @@ -56,7 +56,7 @@ class AnswerQuestionChroma(Step):              Here is the answer:""") -        answer = await sdk.models.gpt35.complete(prompt) +        answer = await sdk.models.medium.complete(prompt)          # Make paths relative to the workspace directory          answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "") diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index de7cf3ac..4476c7ae 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -11,11 +11,12 @@ from pydantic import validator  from ....libs.llm.ggml import GGML  from ....models.main import Range +from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI  from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit  from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents -from ....core.observation import Observation, TextObservation, UserInputObservation -from ....core.main import ChatMessage, ContinueCustomException, Step -from ....libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS +from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation +from ....core.main import ChatMessage, ContinueCustomException, Step, SequentialStep +from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS  from ....libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes @@ -97,7 +98,7 @@ class ShellCommandsStep(Step):              return f"Error when running shell commands:\n```\n{self._err_text}\n```"          cmds_str = "\n".join(self.cmds) -        return await models.gpt35.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:") +        return await models.medium.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:")      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:          cwd = await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd @@ -105,7 +106,7 @@ class ShellCommandsStep(Step):          for cmd in self.cmds:              output = await sdk.ide.runCommand(cmd)              if self.handle_error and output is not None and output_contains_error(output): -                suggestion = await sdk.models.gpt35.complete(dedent(f"""\ +                suggestion = await sdk.models.medium.complete(dedent(f"""\                      While running the command `{cmd}`, the following error occurred:                      ```ascii @@ -185,7 +186,7 @@ class DefaultModelEditCodeStep(Step):          else:              changes = '\n'.join(difflib.ndiff(                  self._previous_contents.splitlines(), self._new_contents.splitlines())) -            description = await models.gpt3516k.complete(dedent(f"""\ +            description = await models.medium.complete(dedent(f"""\                  Diff summary: "{self.user_input}"                  ```diff @@ -193,7 +194,7 @@ class DefaultModelEditCodeStep(Step):                  ```                  Please give brief a description of the changes made above using markdown bullet points. Be concise:""")) -        name = await models.gpt3516k.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:") +        name = await models.medium.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:")          self.name = remove_quotes_and_escapes(name)          return f"{remove_quotes_and_escapes(description)}" @@ -203,8 +204,7 @@ class DefaultModelEditCodeStep(Step):          # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.          # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.          model_to_use = sdk.models.default -        max_tokens = int(MAX_TOKENS_FOR_MODEL.get( -            model_to_use.name, DEFAULT_MAX_TOKENS) / 2) +        max_tokens = int(model_to_use.context_length / 2)          TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200          if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE: @@ -222,8 +222,9 @@ class DefaultModelEditCodeStep(Step):          # If using 3.5 and overflows, upgrade to 3.5.16k          if model_to_use.name == "gpt-3.5-turbo": -            if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]: -                model_to_use = sdk.models.gpt3516k +            if total_tokens > model_to_use.context_length: +                model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613") +                await sdk.start_model(model_to_use)          # Remove tokens from the end first, and then the start to clear space          # This part finds the start and end lines @@ -233,20 +234,20 @@ class DefaultModelEditCodeStep(Step):          cur_start_line = 0          cur_end_line = len(full_file_contents_lst) - 1 -        if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]: +        if total_tokens > model_to_use.context_length:              while cur_end_line > min_end_line:                  total_tokens -= model_to_use.count_tokens(                      full_file_contents_lst[cur_end_line])                  cur_end_line -= 1 -                if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]: +                if total_tokens < model_to_use.context_length:                      break -        if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]: +        if total_tokens > model_to_use.context_length:              while cur_start_line < max_start_line:                  cur_start_line += 1                  total_tokens -= model_to_use.count_tokens(                      full_file_contents_lst[cur_start_line]) -                if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]: +                if total_tokens < model_to_use.context_length:                      break          # Now use the found start/end lines to get the prefix and suffix strings diff --git a/continuedev/src/continuedev/plugins/steps/draft/migration.py b/continuedev/src/continuedev/plugins/steps/draft/migration.py index a76d491b..c38f54dc 100644 --- a/continuedev/src/continuedev/plugins/steps/draft/migration.py +++ b/continuedev/src/continuedev/plugins/steps/draft/migration.py @@ -13,7 +13,7 @@ class MigrationStep(Step):          recent_edits = await sdk.ide.get_recent_edits(self.edited_file)          recent_edits_string = "\n\n".join(              map(lambda x: x.to_string(), recent_edits)) -        description = await sdk.models.gpt35.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n") +        description = await sdk.models.medium.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")          await sdk.run([              "cd libs",              "poetry run alembic revision --autogenerate -m " + description, diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py index 6997a547..ec670999 100644 --- a/continuedev/src/continuedev/plugins/steps/help.py +++ b/continuedev/src/continuedev/plugins/steps/help.py @@ -56,7 +56,7 @@ class HelpStep(Step):                  summary="Help"              ))              messages = await sdk.get_chat_context() -            generator = sdk.models.gpt4.stream_chat(messages) +            generator = sdk.models.default.stream_chat(messages)              async for chunk in generator:                  if "content" in chunk:                      self.description += chunk["content"] diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py index b54d394a..3d8d96fb 100644 --- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py +++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py @@ -23,6 +23,6 @@ class NLMultiselectStep(Step):          if first_try is not None:              return first_try -        gpt_parsed = await sdk.models.gpt35.complete( +        gpt_parsed = await sdk.models.default.complete(              f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:")          return extract_option(gpt_parsed) or self.options[0] diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py index f28d9660..d2d6f4dd 100644 --- a/continuedev/src/continuedev/plugins/steps/main.py +++ b/continuedev/src/continuedev/plugins/steps/main.py @@ -101,7 +101,7 @@ class FasterEditHighlightedCodeStep(Step):          for rif in range_in_files:              rif_dict[rif.filepath] = rif.contents -        completion = await sdk.models.gpt35.complete(prompt) +        completion = await sdk.models.medium.complete(prompt)          # Temporarily doing this to generate description.          self._prompt = prompt @@ -169,7 +169,7 @@ class StarCoderEditHighlightedCodeStep(Step):      _prompt_and_completion: str = ""      async def describe(self, models: Models) -> Coroutine[str, None, None]: -        return await models.gpt35.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:") +        return await models.medium.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:")      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:          range_in_files = await sdk.get_code_context(only_editing=True) diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py index 8b2e7c2e..da6acdbf 100644 --- a/continuedev/src/continuedev/plugins/steps/react.py +++ b/continuedev/src/continuedev/plugins/steps/react.py @@ -27,7 +27,7 @@ class NLDecisionStep(Step):              Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""") -        resp = (await sdk.models.gpt35.complete(prompt)).lower() +        resp = (await sdk.models.medium.complete(prompt)).lower()          step_to_run = None          for step in self.steps: diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py index 522a84a3..456dba84 100644 --- a/continuedev/src/continuedev/plugins/steps/search_directory.py +++ b/continuedev/src/continuedev/plugins/steps/search_directory.py @@ -42,7 +42,7 @@ class WriteRegexPatternStep(Step):      async def run(self, sdk: ContinueSDK):          # Ask the user for a regex pattern -        pattern = await sdk.models.gpt35.complete(dedent(f"""\ +        pattern = await sdk.models.medium.complete(dedent(f"""\              This is the user request:              {self.user_request} | 
