diff options
Diffstat (limited to 'continuedev/src')
42 files changed, 669 insertions, 490 deletions
| diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py index 94d7be10..e048f877 100644 --- a/continuedev/src/continuedev/core/abstract_sdk.py +++ b/continuedev/src/continuedev/core/abstract_sdk.py @@ -73,7 +73,7 @@ class AbstractContinueSDK(ABC):          pass      @abstractmethod -    async def get_user_secret(self, env_var: str, prompt: str) -> str: +    async def get_user_secret(self, env_var: str) -> str:          pass      config: ContinueConfig diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 57e39d5c..de95a259 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -9,6 +9,7 @@ from ..models.filesystem import RangeInFileWithContents  from ..models.filesystem_edit import FileEditWithFullContents  from .observation import Observation, InternalErrorObservation  from .context import ContextManager +from ..plugins.policies.default import DefaultPolicy  from ..plugins.context_providers.file import FileContextProvider  from ..plugins.context_providers.highlighted_code import HighlightedCodeContextProvider  from ..server.ide_protocol import AbstractIdeProtocolServer @@ -47,8 +48,9 @@ def get_error_title(e: Exception) -> str:  class Autopilot(ContinueBaseModel): -    policy: Policy      ide: AbstractIdeProtocolServer + +    policy: Policy = DefaultPolicy()      history: History = History.from_empty()      context: Context = Context()      full_state: Union[FullState, None] = None @@ -64,20 +66,19 @@ class Autopilot(ContinueBaseModel):      _user_input_queue = AsyncSubscriptionQueue()      _retry_queue = AsyncSubscriptionQueue() -    @classmethod -    async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot": -        autopilot = cls(ide=ide, policy=policy) -        autopilot.continue_sdk = await ContinueSDK.create(autopilot) +    async def start(self): +        self.continue_sdk = await ContinueSDK.create(self) +        if override_policy := self.continue_sdk.config.policy_override: +            self.policy = override_policy          # Load documents into the search index -        autopilot.context_manager = await ContextManager.create( -            autopilot.continue_sdk.config.context_providers + [ -                HighlightedCodeContextProvider(ide=ide), -                FileContextProvider(workspace_dir=ide.workspace_directory) +        self.context_manager = await ContextManager.create( +            self.continue_sdk.config.context_providers + [ +                HighlightedCodeContextProvider(ide=self.ide), +                FileContextProvider(workspace_dir=self.ide.workspace_directory)              ]) -        await autopilot.context_manager.load_index(ide.workspace_directory) -        return autopilot +        await self.context_manager.load_index(self.ide.workspace_directory)      class Config:          arbitrary_types_allowed = True @@ -95,7 +96,6 @@ class Autopilot(ContinueBaseModel):              history=self.history,              active=self._active,              user_input_queue=self._main_user_input_queue, -            default_model=self.continue_sdk.config.default_model,              slash_commands=self.get_available_slash_commands(),              adding_highlighted_code=self.context_manager.context_providers[                  "code"].adding_highlighted_code, diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 9fbda824..84b6b10b 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -2,9 +2,13 @@ import json  import os  from .main import Step  from .context import ContextProvider +from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI +from .models import Models  from pydantic import BaseModel, validator -from typing import List, Literal, Optional, Dict, Type, Union -import yaml +from typing import List, Literal, Optional, Dict, Type + +from .main import Policy, Step +from .context import ContextProvider  class SlashCommand(BaseModel): @@ -25,13 +29,6 @@ class OnTracebackSteps(BaseModel):      params: Optional[Dict] = {} -class OpenAIServerInfo(BaseModel): -    api_base: Optional[str] = None -    engine: Optional[str] = None -    api_version: Optional[str] = None -    api_type: Literal["azure", "openai"] = "openai" - -  class ContinueConfig(BaseModel):      """      A pydantic class for the continue config file. @@ -39,8 +36,9 @@ class ContinueConfig(BaseModel):      steps_on_startup: List[Step] = []      disallowed_steps: Optional[List[str]] = []      allow_anonymous_telemetry: Optional[bool] = True -    default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k", -                           "gpt-4", "claude-2", "ggml"] = 'gpt-4' +    models: Models = Models( +        default=MaybeProxyOpenAI(model="gpt-4"), +    )      temperature: Optional[float] = 0.5      custom_commands: Optional[List[CustomCommand]] = [CustomCommand(          name="test", @@ -50,7 +48,7 @@ class ContinueConfig(BaseModel):      slash_commands: Optional[List[SlashCommand]] = []      on_traceback: Optional[List[OnTracebackSteps]] = []      system_message: Optional[str] = None -    openai_server_info: Optional[OpenAIServerInfo] = None +    policy_override: Optional[Policy] = None      context_providers: List[ContextProvider] = [] diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index 86522ce1..e968c35c 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -178,12 +178,6 @@ class ContextManager:                  except Exception as e:                      logger.debug(f"Error loading meilisearch index: {e}") -    # def compile_chat_messages(self, max_tokens: int) -> List[Dict]: -    #     """ -    #     Compiles the chat prompt into a single string. -    #     """ -    #     return compile_chat_messages(self.model, self.chat_history, max_tokens, self.prompt, self.functions, self.system_message) -      async def select_context_item(self, id: str, query: str):          """          Selects the ContextItem with the given id. diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index df9b98ef..2553850f 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -258,7 +258,6 @@ class FullState(ContinueBaseModel):      history: History      active: bool      user_input_queue: List[str] -    default_model: str      slash_commands: List[SlashCommandDescription]      adding_highlighted_code: bool      selected_context_items: List[ContextItem] diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py new file mode 100644 index 00000000..900762b6 --- /dev/null +++ b/continuedev/src/continuedev/core/models.py @@ -0,0 +1,65 @@ +from typing import Optional, Any +from pydantic import BaseModel, validator +from ..libs.llm import LLM + + +class Models(BaseModel): +    """Main class that holds the current model configuration""" +    default: LLM +    small: Optional[LLM] = None +    medium: Optional[LLM] = None +    large: Optional[LLM] = None + +    # TODO namespace these away to not confuse readers, +    # or split Models into ModelsConfig, which gets turned into Models +    sdk: "ContinueSDK" = None +    system_message: Any = None + +    """ +    Better to have sdk.llm.stream_chat(messages, model="claude-2"). +    Then you also don't care that it' async. +    And it's easier to add more models. +    And intermediate shared code is easier to add. +    And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo" +    PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model. +    Easy to reason about, can place anywhere. +    And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model. +    This can all happen inside of Models? + +    class Prompt: +        def __init__(self, ...info): +            '''take whatever info is needed to describe the prompt''' + +        def to_string(self, model: str) -> str: +            '''depending on the model, return the single prompt string''' +    """ + +    async def start(self, sdk: "ContinueSDK"): +        """Start each of the LLMs, or fall back to default""" +        self.sdk = sdk +        self.system_message = self.sdk.config.system_message +        await sdk.start_model(self.default) +        if self.small: +            await sdk.start_model(self.small) +        else: +            self.small = self.default + +        if self.medium: +            await sdk.start_model(self.medium) +        else: +            self.medium = self.default + +        if self.large: +            await sdk.start_model(self.large) +        else: +            self.large = self.default + +    async def stop(self, sdk: "ContinueSDK"): +        """Stop each LLM (if it's not the default, which is shared)""" +        await self.default.stop() +        if self.small is not self.default: +            await self.small.stop() +        if self.medium is not self.default: +            await self.medium.stop() +        if self.large is not self.default: +            await self.large.stop() diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 4b76a121..bf22d696 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -9,17 +9,14 @@ from .abstract_sdk import AbstractContinueSDK  from .config import ContinueConfig  from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory  from ..models.filesystem import RangeInFile -from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI -from ..libs.llm.openai import OpenAI -from ..libs.llm.anthropic import AnthropicLLM -from ..libs.llm.ggml import GGML +from ..libs.llm import LLM  from .observation import Observation  from ..server.ide_protocol import AbstractIdeProtocolServer  from .main import Context, ContinueCustomException, History, HistoryNode, Step, ChatMessage  from ..plugins.steps.core.core import * -from ..libs.llm.proxy_server import ProxyServer  from ..libs.util.telemetry import posthog_logger  from ..libs.util.paths import getConfigFilePath +from .models import Models  from ..libs.util.logging import logger @@ -27,121 +24,6 @@ class Autopilot:      pass -ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"] -MODEL_PROVIDER_TO_ENV_VAR = { -    "openai": "OPENAI_API_KEY", -    "hf_inference_api": "HUGGING_FACE_TOKEN", -    "anthropic": "ANTHROPIC_API_KEY", -} - - -class Models: -    provider_keys: Dict[ModelProvider, str] = {} -    model_providers: List[ModelProvider] -    system_message: str - -    """ -    Better to have sdk.llm.stream_chat(messages, model="claude-2"). -    Then you also don't care that it' async. -    And it's easier to add more models. -    And intermediate shared code is easier to add. -    And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo" -    PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model. -    Easy to reason about, can place anywhere. -    And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model. -    This can all happen inside of Models? - -    class Prompt: -        def __init__(self, ...info): -            '''take whatever info is needed to describe the prompt''' - -        def to_string(self, model: str) -> str: -            '''depending on the model, return the single prompt string''' -    """ - -    def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]): -        self.sdk = sdk -        self.model_providers = model_providers -        self.system_message = sdk.config.system_message - -    @classmethod -    async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models": -        if sdk.config.default_model == "claude-2": -            with_providers.append("anthropic") - -        models = Models(sdk, with_providers) -        for provider in with_providers: -            if provider in MODEL_PROVIDER_TO_ENV_VAR: -                env_var = MODEL_PROVIDER_TO_ENV_VAR[provider] -                models.provider_keys[provider] = await sdk.get_user_secret( -                    env_var, f'Please add your {env_var} to the .env file') - -        return models - -    def __load_openai_model(self, model: str) -> OpenAI: -        api_key = self.provider_keys["openai"] -        if api_key == "": -            return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log) -        return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, openai_server_info=self.sdk.config.openai_server_info, write_log=self.sdk.write_log) - -    def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI: -        api_key = self.provider_keys["hf_inference_api"] -        return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message) - -    def __load_anthropic_model(self, model: str) -> AnthropicLLM: -        api_key = self.provider_keys["anthropic"] -        return AnthropicLLM(api_key, model, self.system_message) - -    @cached_property -    def claude2(self): -        return self.__load_anthropic_model("claude-2") - -    @cached_property -    def starcoder(self): -        return self.__load_hf_inference_api_model("bigcode/starcoder") - -    @cached_property -    def gpt35(self): -        return self.__load_openai_model("gpt-3.5-turbo") - -    @cached_property -    def gpt350613(self): -        return self.__load_openai_model("gpt-3.5-turbo-0613") - -    @cached_property -    def gpt3516k(self): -        return self.__load_openai_model("gpt-3.5-turbo-16k") - -    @cached_property -    def gpt4(self): -        return self.__load_openai_model("gpt-4") - -    @cached_property -    def ggml(self): -        return GGML(system_message=self.system_message) - -    def __model_from_name(self, model_name: str): -        if model_name == "starcoder": -            return self.starcoder -        elif model_name == "gpt-3.5-turbo": -            return self.gpt35 -        elif model_name == "gpt-3.5-turbo-16k": -            return self.gpt3516k -        elif model_name == "gpt-4": -            return self.gpt4 -        elif model_name == "claude-2": -            return self.claude2 -        elif model_name == "ggml": -            return self.ggml -        else: -            raise Exception(f"Unknown model {model_name}") - -    @property -    def default(self): -        default_model = self.sdk.config.default_model -        return self.__model_from_name(default_model) if default_model is not None else self.gpt4 - -  class ContinueSDK(AbstractContinueSDK):      """The SDK provided as parameters to a step"""      ide: AbstractIdeProtocolServer @@ -171,7 +53,7 @@ class ContinueSDK(AbstractContinueSDK):              formatted_err = '\n'.join(traceback.format_exception(e))              msg_step = MessageStep(                  name="Invalid Continue Config File", message=formatted_err) -            msg_step.description = f"Falling back to default config settings.\n```\n{formatted_err}\n```" +            msg_step.description = f"Falling back to default config settings.\n```\n{formatted_err}\n```\n\nIt's possible this error was caused by an update to the Continue config format. If you'd like to see the new recommended default `config.py`, check [here](https://github.com/continuedev/continue/blob/main/continuedev/src/continuedev/libs/constants/default_config.py.txt)."              sdk.history.add_node(HistoryNode(                  step=msg_step,                  observation=None, @@ -179,11 +61,13 @@ class ContinueSDK(AbstractContinueSDK):                  active=False              )) +        sdk.models = sdk.config.models +        await sdk.models.start(sdk) +          # When the config is loaded, setup posthog logger          posthog_logger.setup(              sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry) -        sdk.models = await Models.create(sdk)          return sdk      @property @@ -193,6 +77,16 @@ class ContinueSDK(AbstractContinueSDK):      def write_log(self, message: str):          self.history.timeline[self.history.current_index].logs.append(message) +    async def start_model(self, llm: LLM): +        kwargs = {} +        if llm.requires_api_key: +            kwargs["api_key"] = await self.get_user_secret(llm.requires_api_key) +        if llm.requires_unique_id: +            kwargs["unique_id"] = self.ide.unique_id +        if llm.requires_write_log: +            kwargs["write_log"] = self.write_log +        await llm.start(**kwargs) +      async def _ensure_absolute_path(self, path: str) -> str:          if os.path.isabs(path):              return path @@ -262,7 +156,8 @@ class ContinueSDK(AbstractContinueSDK):          path = await self._ensure_absolute_path(path)          return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path))) -    async def get_user_secret(self, env_var: str, prompt: str) -> str: +    async def get_user_secret(self, env_var: str) -> str: +        # TODO support error prompt dynamically set on env_var          return await self.ide.getUserSecret(env_var)      _last_valid_config: ContinueConfig = None diff --git a/continuedev/src/continuedev/libs/chroma/query.py b/continuedev/src/continuedev/libs/chroma/query.py index f09b813a..dba4874f 100644 --- a/continuedev/src/continuedev/libs/chroma/query.py +++ b/continuedev/src/continuedev/libs/chroma/query.py @@ -59,7 +59,7 @@ class ChromaIndexManager:              except:                  logger.warning(                      f"ERROR (probably found special token): {doc.text}") -                continue +                continue  # lol              filename = doc.extra_info["filename"]              chunks[filename] = len(text_chunks)              for i, text in enumerate(text_chunks): diff --git a/continuedev/src/continuedev/libs/chroma/update.py b/continuedev/src/continuedev/libs/chroma/update.py index 23ed950f..d5326a06 100644 --- a/continuedev/src/continuedev/libs/chroma/update.py +++ b/continuedev/src/continuedev/libs/chroma/update.py @@ -23,7 +23,7 @@ def filter_ignored_files(files: List[str], root_dir: str):      """Further filter files before indexing."""      for file in files:          if file.endswith(tuple(FILE_TYPES_TO_IGNORE)) or file.startswith('.git') or file.startswith('archive'): -            continue +            continue  # nice          yield root_dir + "/" + file diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt index 1a66c847..cf8b0324 100644 --- a/continuedev/src/continuedev/libs/constants/default_config.py.txt +++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt @@ -9,9 +9,12 @@ import subprocess  from continuedev.core.main import Step  from continuedev.core.sdk import ContinueSDK +from continuedev.core.models import Models  from continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig  from continuedev.plugins.context_providers.github import GitHubIssuesContextProvider  from continuedev.plugins.context_providers.google import GoogleContextProvider +from continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI +from continuedev.plugins.policies.default import DefaultPolicy  from continuedev.plugins.steps.open_config import OpenConfigStep  from continuedev.plugins.steps.clear_history import ClearHistoryStep @@ -35,9 +38,9 @@ class CommitMessageStep(Step):          diff = subprocess.check_output(              ["git", "diff"], cwd=dir).decode("utf-8") -        # Ask gpt-3.5-16k to write a commit message, +        # Ask the LLM to write a commit message,          # and set it as the description of this step -        self.description = await sdk.models.gpt3516k.complete( +        self.description = await sdk.models.default.complete(              f"{diff}\n\nWrite a short, specific (less than 50 chars) commit message about the above changes:") @@ -47,9 +50,10 @@ config = ContinueConfig(      # See here to learn what anonymous data we collect: https://continue.dev/docs/telemetry      allow_anonymous_telemetry=True, -    # GPT-4 is recommended for best results -    # See options here: https://continue.dev/docs/customization#change-the-default-llm -    default_model="gpt-4", +    models=Models( +        default=MaybeProxyOpenAI(model="gpt-4"), +        medium=MaybeProxyOpenAI(model="gpt-3.5-turbo") +    ),      # Set a system message with information that the LLM should always keep in mind      # E.g. "Please give concise answers. Always respond in Spanish." @@ -114,5 +118,9 @@ config = ContinueConfig(          # GoogleContextProvider(          #     serper_api_key="<your serper.dev api key>"          # ) -    ] +    ], + +    # Policies hold the main logic that decides which Step to take next +    # You can use them to design agents, or deeply customize Continue +    policy=DefaultPolicy()  ) diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 2766db4b..50577993 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -1,14 +1,30 @@ -from abc import ABC -from typing import Any, Coroutine, Dict, Generator, List, Union +from abc import ABC, abstractproperty +from typing import Any, Coroutine, Dict, Generator, List, Union, Optional  from ...core.main import ChatMessage -from ...models.main import AbstractModel -from pydantic import BaseModel +from ...models.main import ContinueBaseModel -class LLM(ABC): +class LLM(ContinueBaseModel, ABC): +    requires_api_key: Optional[str] = None +    requires_unique_id: bool = False +    requires_write_log: bool = False +      system_message: Union[str, None] = None +    @abstractproperty +    def name(self): +        """Return the name of the LLM.""" +        raise NotImplementedError + +    async def start(self, *, api_key: Optional[str] = None, **kwargs): +        """Start the connection to the LLM.""" +        raise NotImplementedError + +    async def stop(self): +        """Stop the connection to the LLM.""" +        raise NotImplementedError +      async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:          """Return the completion of the text with the given temperature."""          raise NotImplementedError @@ -24,3 +40,8 @@ class LLM(ABC):      def count_tokens(self, text: str):          """Return the number of tokens in the given text."""          raise NotImplementedError + +    @abstractproperty +    def context_length(self) -> int: +        """Return the context length of the LLM in tokens, as counted by count_tokens.""" +        raise NotImplementedError diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index 625d4e57..ec1b7e40 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -1,32 +1,39 @@  from functools import cached_property  import time -from typing import Any, Coroutine, Dict, Generator, List, Union +from typing import Any, Coroutine, Dict, Generator, List, Optional, Union  from ...core.main import ChatMessage  from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic  from ..llm import LLM -from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top +from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens  class AnthropicLLM(LLM): -    api_key: str -    default_model: str -    async_client: AsyncAnthropic +    model: str = "claude-2" -    def __init__(self, api_key: str, default_model: str, system_message: str = None): -        self.api_key = api_key -        self.default_model = default_model +    requires_api_key: str = "ANTHROPIC_API_KEY" +    _async_client: AsyncAnthropic = None + +    class Config: +        arbitrary_types_allowed = True + +    def __init__(self, model: str, system_message: str = None): +        self.model = model          self.system_message = system_message -        self.async_client = AsyncAnthropic(api_key=api_key) +    async def start(self, *, api_key: Optional[str] = None, **kwargs): +        self._async_client = AsyncAnthropic(api_key=api_key) + +    async def stop(self): +        pass      @cached_property      def name(self): -        return self.default_model +        return self.model      @property      def default_args(self): -        return {**DEFAULT_ARGS, "model": self.default_model} +        return {**DEFAULT_ARGS, "model": self.model}      def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:          args = args.copy() @@ -40,7 +47,13 @@ class AnthropicLLM(LLM):          return args      def count_tokens(self, text: str): -        return count_tokens(self.default_model, text) +        return count_tokens(self.model, text) + +    @property +    def context_length(self): +        if self.model == "claude-2": +            return 100000 +        raise Exception(f"Unknown Anthropic model {self.model}")      def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:          prompt = "" @@ -60,7 +73,7 @@ class AnthropicLLM(LLM):          args["stream"] = True          args = self._transform_args(args) -        async for chunk in await self.async_client.completions.create( +        async for chunk in await self._async_client.completions.create(              prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",              **args          ): @@ -73,8 +86,8 @@ class AnthropicLLM(LLM):          args = self._transform_args(args)          messages = compile_chat_messages( -            args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message) -        async for chunk in await self.async_client.completions.create( +            args["model"], messages, self.context_length, self.context_length, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message) +        async for chunk in await self._async_client.completions.create(              prompt=self.__messages_to_prompt(messages),              **args          ): @@ -88,8 +101,8 @@ class AnthropicLLM(LLM):          args = self._transform_args(args)          messages = compile_chat_messages( -            args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message) -        resp = (await self.async_client.completions.create( +            args["model"], with_history, self.context_length, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message) +        resp = (await self._async_client.completions.create(              prompt=self.__messages_to_prompt(messages),              **args          )).completion diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 4889a556..7742e8c3 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -1,6 +1,7 @@  from functools import cached_property  import json  from typing import Any, Coroutine, Dict, Generator, List, Union +from pydantic import ConfigDict  import aiohttp  from ...core.main import ChatMessage @@ -11,15 +12,29 @@ SERVER_URL = "http://localhost:8000"  class GGML(LLM): +    # this is model-specific +    max_context_length: int = 2048 -    def __init__(self, system_message: str = None): -        self.system_message = system_message +    _client_session: aiohttp.ClientSession = None -    @cached_property +    class Config: +        arbitrary_types_allowed = True + +    async def start(self, **kwargs): +        self._client_session = aiohttp.ClientSession() + +    async def stop(self): +        await self._client_session.close() + +    @property      def name(self):          return "ggml"      @property +    def context_length(self): +        return self.max_context_length + +    @property      def default_args(self):          return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} @@ -33,54 +48,51 @@ class GGML(LLM):          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) - -        async with aiohttp.ClientSession() as session: -            async with session.post(f"{SERVER_URL}/v1/completions", json={ -                "messages": messages, -                **args -            }) as resp: -                async for line in resp.content.iter_any(): -                    if line: -                        try: -                            yield line.decode("utf-8") -                        except: -                            raise Exception(str(line)) +            self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) + +        async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={ +            "messages": messages, +            **args +        }) as resp: +            async for line in resp.content.iter_any(): +                if line: +                    try: +                        yield line.decode("utf-8") +                    except: +                        raise Exception(str(line))      async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) +            self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)          args["stream"] = True -        async with aiohttp.ClientSession() as session: -            async with session.post(f"{SERVER_URL}/v1/chat/completions", json={ -                "messages": messages, -                **args -            }) as resp: -                # This is streaming application/json instaed of text/event-stream -                async for line in resp.content.iter_chunks(): -                    if line[1]: -                        try: -                            json_chunk = line[0].decode("utf-8") -                            if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): -                                continue -                            chunks = json_chunk.split("\n") -                            for chunk in chunks: -                                if chunk.strip() != "": -                                    yield json.loads(chunk[6:])["choices"][0]["delta"] -                        except: -                            raise Exception(str(line[0])) +        async with self._client_session.post(f"{SERVER_URL}/v1/chat/completions", json={ +            "messages": messages, +            **args +        }) as resp: +            # This is streaming application/json instaed of text/event-stream +            async for line in resp.content.iter_chunks(): +                if line[1]: +                    try: +                        json_chunk = line[0].decode("utf-8") +                        if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): +                            continue +                        chunks = json_chunk.split("\n") +                        for chunk in chunks: +                            if chunk.strip() != "": +                                yield json.loads(chunk[6:])["choices"][0]["delta"] +                    except: +                        raise Exception(str(line[0]))      async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:          args = {**self.default_args, **kwargs} -        async with aiohttp.ClientSession() as session: -            async with session.post(f"{SERVER_URL}/v1/completions", json={ -                "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message), -                **args -            }) as resp: -                try: -                    return await resp.text() -                except: -                    raise Exception(await resp.text()) +        async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={ +            "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message), +            **args +        }) as resp: +            try: +                return await resp.text() +            except: +                raise Exception(await resp.text()) diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 36f03270..49f593d8 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional  from ...core.main import ChatMessage  from ..llm import LLM  import requests @@ -8,14 +8,18 @@ DEFAULT_MAX_TIME = 120.  class HuggingFaceInferenceAPI(LLM): -    api_key: str      model: str -    def __init__(self, api_key: str, model: str, system_message: str = None): -        self.api_key = api_key +    requires_api_key: str = "HUGGING_FACE_TOKEN" +    api_key: str = None + +    def __init__(self, model: str, system_message: str = None):          self.model = model          self.system_message = system_message  # TODO: Nothing being done with this +    async def start(self, *, api_key: Optional[str] = None, **kwargs): +        self.api_key = api_key +      def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):          """Return the completion of the text with the given temperature."""          API_URL = f"https://api-inference.huggingface.co/models/{self.model}" diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py new file mode 100644 index 00000000..edf58fd7 --- /dev/null +++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py @@ -0,0 +1,53 @@ +from typing import Any, Coroutine, Dict, Generator, List, Union, Optional, Callable + +from ...core.main import ChatMessage +from . import LLM +from .proxy_server import ProxyServer +from .openai import OpenAI + + +class MaybeProxyOpenAI(LLM): +    model: str + +    requires_api_key: Optional[str] = "OPENAI_API_KEY" +    requires_write_log: bool = True +    requires_unique_id: bool = True +    system_message: Union[str, None] = None + +    llm: Optional[LLM] = None + +    @property +    def name(self): +        return self.llm.name + +    @property +    def context_length(self): +        return self.llm.context_length + +    async def start(self, *, api_key: Optional[str] = None, unique_id: str, write_log: Callable[[str], None]): +        if api_key is None or api_key.strip() == "": +            self.llm = ProxyServer(model=self.model) +        else: +            self.llm = OpenAI(model=self.model) + +        await self.llm.start(api_key=api_key, write_log=write_log, unique_id=unique_id) + +    async def stop(self): +        await self.llm.stop() + +    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: +        return await self.llm.complete(prompt, with_history=with_history, **kwargs) + +    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +        resp = self.llm.stream_complete( +            prompt, with_history=with_history, **kwargs) +        async for item in resp: +            yield item + +    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +        resp = self.llm.stream_chat(messages=messages, **kwargs) +        async for item in resp: +            yield item + +    def count_tokens(self, text: str): +        return self.llm.count_tokens(text) diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 654c7326..fce6e8ab 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -1,48 +1,83 @@  from functools import cached_property  import json -from typing import Any, Callable, Coroutine, Dict, Generator, List, Union +from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional -from ...core.main import ChatMessage +from pydantic import BaseModel  import openai + +from ...core.main import ChatMessage +from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top  from ..llm import LLM -from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top -from ...core.config import OpenAIServerInfo + + +class OpenAIServerInfo(BaseModel): +    api_base: Optional[str] = None +    engine: Optional[str] = None +    api_version: Optional[str] = None +    api_type: Literal["azure", "openai"] = "openai" + + +CHAT_MODELS = { +    "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613" +} +MAX_TOKENS_FOR_MODEL = { +    "gpt-3.5-turbo": 4096, +    "gpt-3.5-turbo-0613": 4096, +    "gpt-3.5-turbo-16k": 16384, +    "gpt-4": 8192, +} + + +class AzureInfo(BaseModel): +    endpoint: str +    engine: str +    api_version: str  class OpenAI(LLM): -    api_key: str -    default_model: str +    model: str +    openai_server_info: Optional[OpenAIServerInfo] = None -    def __init__(self, api_key: str, default_model: str, system_message: str = None, openai_server_info: OpenAIServerInfo = None, write_log: Callable[[str], None] = None): -        self.api_key = api_key -        self.default_model = default_model -        self.system_message = system_message -        self.openai_server_info = openai_server_info +    requires_api_key = "OPENAI_API_KEY" +    requires_write_log = True + +    system_message: Optional[str] = None +    azure_info: Optional[AzureInfo] = None +    write_log: Optional[Callable[[str], None]] = None +    api_key: str = None + +    async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], **kwargs):          self.write_log = write_log +        self.api_key = api_key +        openai.api_key = self.api_key -        openai.api_key = api_key +        if self.openai_server_info is not None: +            openai.api_type = self.openai_server_info.api_type +            if self.openai_server_info.api_base is not None: +                openai.api_base = self.openai_server_info.api_base +            if self.openai_server_info.api_version is not None: +                openai.api_version = self.openai_server_info.api_version -        # Using an Azure OpenAI deployment -        if openai_server_info is not None: -            openai.api_type = openai_server_info.api_type -            if openai_server_info.api_base is not None: -                openai.api_base = openai_server_info.api_base -            if openai_server_info.api_version is not None: -                openai.api_version = openai_server_info.api_version +    async def stop(self): +        pass -    @cached_property +    @property      def name(self): -        return self.default_model +        return self.model + +    @property +    def context_length(self): +        return MAX_TOKENS_FOR_MODEL[self.model]      @property      def default_args(self): -        args = {**DEFAULT_ARGS, "model": self.default_model} +        args = {**DEFAULT_ARGS, "model": self.model}          if self.openai_server_info is not None:              args["engine"] = self.openai_server_info.engine          return args      def count_tokens(self, text: str): -        return count_tokens(self.default_model, text) +        return count_tokens(self.model, text)      async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:          args = self.default_args.copy() @@ -51,7 +86,7 @@ class OpenAI(LLM):          if args["model"] in CHAT_MODELS:              messages = compile_chat_messages( -                args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) +                args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)              self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")              completion = ""              async for chunk in await openai.ChatCompletion.acreate( @@ -62,7 +97,7 @@ class OpenAI(LLM):                      yield chunk.choices[0].delta.content                      completion += chunk.choices[0].delta.content                  else: -                    continue +                    continue  # :)              self.write_log(f"Completion: \n\n{completion}")          else: @@ -78,12 +113,13 @@ class OpenAI(LLM):          args = self.default_args.copy()          args.update(kwargs)          args["stream"] = True -        args["model"] = self.default_model if self.default_model in CHAT_MODELS else "gpt-3.5-turbo-0613" +        # TODO what to do here? why should we change to gpt-3.5-turbo-0613 if the user didn't ask for it? +        args["model"] = self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613"          if not args["model"].endswith("0613") and "functions" in args:              del args["functions"]          messages = compile_chat_messages( -            args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) +            args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")          completion = ""          async for chunk in await openai.ChatCompletion.acreate( @@ -100,7 +136,7 @@ class OpenAI(LLM):          if args["model"] in CHAT_MODELS:              messages = compile_chat_messages( -                args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) +                args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)              self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")              resp = (await openai.ChatCompletion.acreate(                  messages=messages, @@ -109,7 +145,7 @@ class OpenAI(LLM):              self.write_log(f"Completion: \n\n{resp}")          else:              prompt = prune_raw_prompt_from_top( -                args["model"], prompt, args["max_tokens"]) +                args["model"], self.context_length, prompt, args["max_tokens"])              self.write_log(f"Prompt:\n\n{prompt}")              resp = (await openai.Completion.acreate(                  prompt=prompt, diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index f9e3fa01..1a48f213 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -1,7 +1,6 @@ -  import json  import traceback -from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union +from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional  import aiohttp  from ...core.main import ChatMessage  from ..llm import LLM @@ -16,26 +15,51 @@ ssl_context = ssl.create_default_context(cafile=ca_bundle_path)  # SERVER_URL = "http://127.0.0.1:8080"  SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app" +MAX_TOKENS_FOR_MODEL = { +    "gpt-3.5-turbo": 4096, +    "gpt-3.5-turbo-0613": 4096, +    "gpt-3.5-turbo-16k": 16384, +    "gpt-4": 8192, +} +  class ProxyServer(LLM): -    unique_id: str -    name: str -    default_model: Literal["gpt-3.5-turbo", "gpt-4"] -    write_log: Callable[[str], None] +    model: str +    system_message: Optional[str] -    def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None, write_log: Callable[[str], None] = None): -        self.unique_id = unique_id -        self.default_model = default_model -        self.system_message = system_message -        self.name = default_model +    unique_id: str = None +    write_log: Callable[[str], None] = None +    _client_session: aiohttp.ClientSession + +    requires_unique_id = True +    requires_write_log = True + +    class Config: +        arbitrary_types_allowed = True + +    async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], unique_id: str, **kwargs): +        self._client_session = aiohttp.ClientSession( +            connector=aiohttp.TCPConnector(ssl_context=ssl_context))          self.write_log = write_log +        self.unique_id = unique_id + +    async def stop(self): +        await self._client_session.close() + +    @property +    def name(self): +        return self.model + +    @property +    def context_length(self): +        return MAX_TOKENS_FOR_MODEL[self.model]      @property      def default_args(self): -        return {**DEFAULT_ARGS, "model": self.default_model} +        return {**DEFAULT_ARGS, "model": self.model}      def count_tokens(self, text: str): -        return count_tokens(self.default_model, text) +        return count_tokens(self.model, text)      def get_headers(self):          # headers with unique id @@ -45,75 +69,72 @@ class ProxyServer(LLM):          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) +            args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") -        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: -            async with session.post(f"{SERVER_URL}/complete", json={ -                "messages": messages, -                **args -            }, headers=self.get_headers()) as resp: -                if resp.status != 200: -                    raise Exception(await resp.text()) - -                response_text = await resp.text() -                self.write_log(f"Completion: \n\n{response_text}") -                return response_text +        async with self._client_session.post(f"{SERVER_URL}/complete", json={ +            "messages": messages, +            **args +        }, headers=self.get_headers()) as resp: +            if resp.status != 200: +                raise Exception(await resp.text()) + +            response_text = await resp.text() +            self.write_log(f"Completion: \n\n{response_text}") +            return response_text      async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) +            args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") -        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: -            async with session.post(f"{SERVER_URL}/stream_chat", json={ -                "messages": messages, -                **args -            }, headers=self.get_headers()) as resp: -                # This is streaming application/json instaed of text/event-stream -                completion = "" -                if resp.status != 200: -                    raise Exception(await resp.text()) -                async for line in resp.content.iter_chunks(): -                    if line[1]: -                        try: -                            json_chunk = line[0].decode("utf-8") -                            json_chunk = "{}" if json_chunk == "" else json_chunk -                            chunks = json_chunk.split("\n") -                            for chunk in chunks: -                                if chunk.strip() != "": -                                    loaded_chunk = json.loads(chunk) -                                    yield loaded_chunk -                                    if "content" in loaded_chunk: -                                        completion += loaded_chunk["content"] -                        except Exception as e: -                            posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", { -                                "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))}) -                    else: -                        break - -                self.write_log(f"Completion: \n\n{completion}") +        async with self._client_session.post(f"{SERVER_URL}/stream_chat", json={ +            "messages": messages, +            **args +        }, headers=self.get_headers()) as resp: +            # This is streaming application/json instaed of text/event-stream +            completion = "" +            if resp.status != 200: +                raise Exception(await resp.text()) +            async for line in resp.content.iter_chunks(): +                if line[1]: +                    try: +                        json_chunk = line[0].decode("utf-8") +                        json_chunk = "{}" if json_chunk == "" else json_chunk +                        chunks = json_chunk.split("\n") +                        for chunk in chunks: +                            if chunk.strip() != "": +                                loaded_chunk = json.loads(chunk) +                                yield loaded_chunk +                                if "content" in loaded_chunk: +                                    completion += loaded_chunk["content"] +                    except Exception as e: +                        posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", { +                            "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))}) +                else: +                    break + +            self.write_log(f"Completion: \n\n{completion}")      async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) +            self.model, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") -        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: -            async with session.post(f"{SERVER_URL}/stream_complete", json={ -                "messages": messages, -                **args -            }, headers=self.get_headers()) as resp: -                completion = "" -                if resp.status != 200: -                    raise Exception(await resp.text()) -                async for line in resp.content.iter_any(): -                    if line: -                        try: -                            decoded_line = line.decode("utf-8") -                            yield decoded_line -                            completion += decoded_line -                        except: -                            raise Exception(str(line)) -                self.write_log(f"Completion: \n\n{completion}") +        async with self._client_session.post(f"{SERVER_URL}/stream_complete", json={ +            "messages": messages, +            **args +        }, headers=self.get_headers()) as resp: +            completion = "" +            if resp.status != 200: +                raise Exception(await resp.text()) +            async for line in resp.content.iter_any(): +                if line: +                    try: +                        decoded_line = line.decode("utf-8") +                        yield decoded_line +                        completion += decoded_line +                    except: +                        raise Exception(str(line)) +            self.write_log(f"Completion: \n\n{completion}") diff --git a/continuedev/src/continuedev/libs/llm/utils.py b/continuedev/src/continuedev/libs/llm/utils.py deleted file mode 100644 index 76240d4e..00000000 --- a/continuedev/src/continuedev/libs/llm/utils.py +++ /dev/null @@ -1,34 +0,0 @@ -from transformers import AutoTokenizer, AutoModelForCausalLM -from transformers import GPT2TokenizerFast - -gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") -def count_tokens(text: str) -> int: -    return len(gpt2_tokenizer.encode(text)) - -prices = { -    # All prices are per 1k tokens -    "fine-tune-train": { -        "davinci": 0.03, -        "curie": 0.03, -        "babbage": 0.0006, -        "ada": 0.0004, -    }, -    "completion": { -        "davinci": 0.02, -        "curie": 0.002, -        "babbage": 0.0005, -        "ada": 0.0004, -    }, -    "fine-tune-completion": { -        "davinci": 0.12, -        "curie": 0.012, -        "babbage": 0.0024, -        "ada": 0.0016, -    }, -    "embedding": { -        "ada": 0.0004 -    } -} - -def get_price(text: str, model: str="davinci", task: str="completion") -> float: -    return count_tokens(text) * prices[task][model] / 1000
\ No newline at end of file diff --git a/continuedev/src/continuedev/libs/util/calculate_diff.py b/continuedev/src/continuedev/libs/util/calculate_diff.py index ff0a135f..3e82bab3 100644 --- a/continuedev/src/continuedev/libs/util/calculate_diff.py +++ b/continuedev/src/continuedev/libs/util/calculate_diff.py @@ -92,7 +92,7 @@ def calculate_diff2(filepath: str, original: str, updated: str) -> List[FileEdit              tag, i1, i2, j1, j2 = s.get_opcodes()[edit_index]              replacement = updated[j1:j2]              if tag == "equal": -                continue +                continue  # ;)              elif tag == "delete":                  edits.append(FileEdit.from_deletion(                      filepath, Range.from_indices(original, i1, i2))) diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index c58ae499..6add7b1a 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -2,43 +2,47 @@ import json  from typing import Dict, List, Union  from ...core.main import ChatMessage  from .templating import render_templated_string +from ...libs.llm import LLM  import tiktoken +# TODO move many of these into specific LLM.properties() function that +# contains max tokens, if its a chat model or not, default args (not all models +# want to be run at 0.5 temp). also lets custom models made for long contexts +# exist here (likg LLongMA)  aliases = {      "ggml": "gpt-3.5-turbo",      "claude-2": "gpt-3.5-turbo",  }  DEFAULT_MAX_TOKENS = 2048 -MAX_TOKENS_FOR_MODEL = { -    "gpt-3.5-turbo": 4096, -    "gpt-3.5-turbo-0613": 4096, -    "gpt-3.5-turbo-16k": 16384, -    "gpt-4": 8192, -    "ggml": 2048, -    "claude-2": 100000 -} -CHAT_MODELS = { -    "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613" -}  DEFAULT_ARGS = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1,                  "frequency_penalty": 0, "presence_penalty": 0} -def encoding_for_model(model: str): -    return tiktoken.encoding_for_model(aliases.get(model, model)) +def encoding_for_model(model_name: str): +    try: +        return tiktoken.encoding_for_model(aliases.get(model_name, model_name)) +    except: +        return tiktoken.encoding_for_model("gpt-3.5-turbo") -def count_tokens(model: str, text: Union[str, None]): +def count_tokens(model_name: str, text: Union[str, None]):      if text is None:          return 0 -    encoding = encoding_for_model(model) +    encoding = encoding_for_model(model_name)      return len(encoding.encode(text, disallowed_special=())) -def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: int): -    max_tokens = MAX_TOKENS_FOR_MODEL.get( -        model, DEFAULT_MAX_TOKENS) - tokens_for_completion -    encoding = encoding_for_model(model) +def count_chat_message_tokens(model_name: str, chat_message: ChatMessage) -> int: +    # Doing simpler, safer version of what is here: +    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb +    # every message follows <|start|>{role/name}\n{content}<|end|>\n +    TOKENS_PER_MESSAGE = 4 +    return count_tokens(model_name, chat_message.content) + TOKENS_PER_MESSAGE + + +def prune_raw_prompt_from_top(model_name: str, context_length: int, prompt: str, tokens_for_completion: int): +    max_tokens = context_length - tokens_for_completion +    encoding = encoding_for_model(model_name)      tokens = encoding.encode(prompt, disallowed_special=())      if len(tokens) <= max_tokens:          return prompt @@ -46,53 +50,45 @@ def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: in          return encoding.decode(tokens[-max_tokens:]) -def count_chat_message_tokens(model: str, chat_message: ChatMessage) -> int: -    # Doing simpler, safer version of what is here: -    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb -    # every message follows <|start|>{role/name}\n{content}<|end|>\n -    TOKENS_PER_MESSAGE = 4 -    return count_tokens(model, chat_message.content) + TOKENS_PER_MESSAGE - - -def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int): +def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context_length: int, tokens_for_completion: int):      total_tokens = tokens_for_completion + \ -        sum(count_chat_message_tokens(model, message) +        sum(count_chat_message_tokens(model_name, message)              for message in chat_history)      # 1. Replace beyond last 5 messages with summary      i = 0 -    while total_tokens > max_tokens and i < len(chat_history) - 5: +    while total_tokens > context_length and i < len(chat_history) - 5:          message = chat_history[0] -        total_tokens -= count_tokens(model, message.content) -        total_tokens += count_tokens(model, message.summary) +        total_tokens -= count_tokens(model_name, message.content) +        total_tokens += count_tokens(model_name, message.summary)          message.content = message.summary          i += 1      # 2. Remove entire messages until the last 5 -    while len(chat_history) > 5 and total_tokens > max_tokens and len(chat_history) > 0: +    while len(chat_history) > 5 and total_tokens > context_length and len(chat_history) > 0:          message = chat_history.pop(0) -        total_tokens -= count_tokens(model, message.content) +        total_tokens -= count_tokens(model_name, message.content)      # 3. Truncate message in the last 5, except last 1      i = 0 -    while total_tokens > max_tokens and len(chat_history) > 0 and i < len(chat_history) - 1: +    while total_tokens > context_length and len(chat_history) > 0 and i < len(chat_history) - 1:          message = chat_history[i] -        total_tokens -= count_tokens(model, message.content) -        total_tokens += count_tokens(model, message.summary) +        total_tokens -= count_tokens(model_name, message.content) +        total_tokens += count_tokens(model_name, message.summary)          message.content = message.summary          i += 1      # 4. Remove entire messages in the last 5, except last 1 -    while total_tokens > max_tokens and len(chat_history) > 1: +    while total_tokens > context_length and len(chat_history) > 1:          message = chat_history.pop(0) -        total_tokens -= count_tokens(model, message.content) +        total_tokens -= count_tokens(model_name, message.content)      # 5. Truncate last message -    if total_tokens > max_tokens and len(chat_history) > 0: +    if total_tokens > context_length and len(chat_history) > 0:          message = chat_history[0]          message.content = prune_raw_prompt_from_top( -            model, message.content, tokens_for_completion) -        total_tokens = max_tokens +            model_name, context_length, message.content, tokens_for_completion) +        total_tokens = context_length      return chat_history @@ -101,7 +97,7 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:  TOKEN_BUFFER_FOR_SAFETY = 100 -def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]: +def compile_chat_messages(model_name: str, msgs: Union[List[ChatMessage], None], context_length: int, max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:      """      The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message      """ @@ -125,10 +121,10 @@ def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_      function_tokens = 0      if functions is not None:          for function in functions: -            function_tokens += count_tokens(model, json.dumps(function)) +            function_tokens += count_tokens(model_name, json.dumps(function))      msgs_copy = prune_chat_history( -        model, msgs_copy, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY) +        model_name, msgs_copy, context_length, function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)      history = [msg.to_dict(with_functions=functions is not None)                 for msg in msgs_copy] diff --git a/continuedev/src/continuedev/libs/util/strings.py b/continuedev/src/continuedev/libs/util/strings.py index f1fb8d0b..285c1e47 100644 --- a/continuedev/src/continuedev/libs/util/strings.py +++ b/continuedev/src/continuedev/libs/util/strings.py @@ -12,7 +12,7 @@ def dedent_and_get_common_whitespace(s: str) -> Tuple[str, str]:      for i in range(1, len(lines)):          # Empty lines are wildcards          if lines[i].strip() == "": -            continue +            continue  # hey that's us!          # Iterate through the leading whitespace characters of the current line          for j in range(0, len(lcp)):              # If it doesn't have the same whitespace as lcp, then update lcp diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py index 51869fdd..2166bc37 100644 --- a/continuedev/src/continuedev/models/generate_json_schema.py +++ b/continuedev/src/continuedev/models/generate_json_schema.py @@ -39,7 +39,7 @@ def main():              json = schema_json_of(model, indent=2, title=title)          except Exception as e:              print(f"Failed to generate json schema for {title}: {e}") -            continue +            continue  # pun intended          with open(f"{SCHEMA_DIR}/{title}.json", "w") as f:              f.write(json) diff --git a/continuedev/src/continuedev/plugins/context_providers/embeddings.py b/continuedev/src/continuedev/plugins/context_providers/embeddings.py new file mode 100644 index 00000000..42d1f754 --- /dev/null +++ b/continuedev/src/continuedev/plugins/context_providers/embeddings.py @@ -0,0 +1,79 @@ +import os +from typing import List, Optional +import uuid +from pydantic import BaseModel + +from ...core.main import ContextItemId +from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from ...libs.chroma.query import ChromaIndexManager +from .util import remove_meilisearch_disallowed_chars + + +class EmbeddingResult(BaseModel): +    filename: str +    content: str + + +class EmbeddingsProvider(ContextProvider): +    title = "embed" + +    workspace_directory: str + +    EMBEDDINGS_CONTEXT_ITEM_ID = "embeddings" + +    index_manager: Optional[ChromaIndexManager] = None + +    class Config: +        arbitrary_types_allowed = True + +    @property +    def index(self): +        if self.index_manager is None: +            self.index_manager = ChromaIndexManager(self.workspace_directory) +        return self.index_manager + +    @property +    def BASE_CONTEXT_ITEM(self): +        return ContextItem( +            content="", +            description=ContextItemDescription( +                name="Embedding Search", +                description="Enter a query to embedding search codebase", +                id=ContextItemId( +                    provider_title=self.title, +                    item_id=self.EMBEDDINGS_CONTEXT_ITEM_ID +                ) +            ) +        ) + +    async def _get_query_results(self, query: str) -> str: +        results = self.index.query_codebase_index(query) + +        ret = [] +        for node in results.source_nodes: +            resource_name = list(node.node.relationships.values())[0] +            filepath = resource_name[:resource_name.index("::")] +            ret.append(EmbeddingResult( +                filename=filepath, content=node.node.text)) + +        return ret + +    async def provide_context_items(self) -> List[ContextItem]: +        self.index.create_codebase_index()  # TODO Synchronous here is not ideal + +        return [self.BASE_CONTEXT_ITEM] + +    async def add_context_item(self, id: ContextItemId, query: str): +        if not id.item_id == self.EMBEDDINGS_CONTEXT_ITEM_ID: +            raise Exception("Invalid item id") + +        results = await self._get_query_results(query) + +        for i in range(len(results)): +            result = results[i] +            ctx_item = self.BASE_CONTEXT_ITEM.copy() +            ctx_item.description.name = os.path.basename(result.filename) +            ctx_item.content = f"{result.filename}\n```\n{result.content}\n```" +            ctx_item.description.id.item_id = uuid.uuid4().hex +            self.selected_items.append(ctx_item) diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index 634774df..31aa5423 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -3,6 +3,7 @@ import re  from typing import List  from ...core.main import ContextItem, ContextItemDescription, ContextItemId  from ...core.context import ContextProvider +from .util import remove_meilisearch_disallowed_chars  from fnmatch import fnmatch @@ -79,7 +80,7 @@ class FileContextProvider(ContextProvider):                      description=file,                      id=ContextItemId(                          provider_title=self.title, -                        item_id=re.sub(r'[^0-9a-zA-Z_-]', '', file) +                        item_id=remove_meilisearch_disallowed_chars(file)                      )                  )              )) diff --git a/continuedev/src/continuedev/plugins/context_providers/google.py b/continuedev/src/continuedev/plugins/context_providers/google.py index fc76fe67..4b0a59ec 100644 --- a/continuedev/src/continuedev/plugins/context_providers/google.py +++ b/continuedev/src/continuedev/plugins/context_providers/google.py @@ -2,6 +2,7 @@ import json  from typing import List  import aiohttp +from .util import remove_meilisearch_disallowed_chars  from ...core.main import ContextItem, ContextItemDescription, ContextItemId  from ...core.context import ContextProvider @@ -60,5 +61,6 @@ class GoogleContextProvider(ContextProvider):          ctx_item = self.BASE_CONTEXT_ITEM.copy()          ctx_item.content = content -        ctx_item.description.id.item_id = query +        ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars( +            query)          return ctx_item diff --git a/continuedev/src/continuedev/plugins/context_providers/util.py b/continuedev/src/continuedev/plugins/context_providers/util.py new file mode 100644 index 00000000..da2e6b17 --- /dev/null +++ b/continuedev/src/continuedev/plugins/context_providers/util.py @@ -0,0 +1,5 @@ +import re + + +def remove_meilisearch_disallowed_chars(id: str) -> str: +    return re.sub(r'[^0-9a-zA-Z_-]', '', id) diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/plugins/policies/default.py index d90177b5..523c2cf4 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/plugins/policies/default.py @@ -1,15 +1,15 @@  from textwrap import dedent  from typing import Union -from ..plugins.steps.chat import SimpleChatStep -from ..plugins.steps.welcome import WelcomeStep -from .config import ContinueConfig -from ..plugins.steps.steps_on_startup import StepsOnStartupStep -from .main import Step, History, Policy -from .observation import UserInputObservation -from ..plugins.steps.core.core import MessageStep -from ..plugins.steps.custom_command import CustomCommandStep -from ..plugins.steps.main import EditHighlightedCodeStep +from ..steps.chat import SimpleChatStep +from ..steps.welcome import WelcomeStep +from ...core.config import ContinueConfig +from ..steps.steps_on_startup import StepsOnStartupStep +from ...core.main import Step, History, Policy +from ...core.observation import UserInputObservation +from ..steps.core.core import MessageStep +from ..steps.custom_command import CustomCommandStep +from ..steps.main import EditHighlightedCodeStep  def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]: @@ -45,7 +45,8 @@ def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]:  class DefaultPolicy(Policy): -    ran_code_last: bool = False + +    default_step: Step = SimpleChatStep()      def next(self, config: ContinueConfig, history: History) -> Step:          # At the very start, run initial Steps spcecified in the config @@ -56,7 +57,6 @@ class DefaultPolicy(Policy):                      - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue                      - Use `/help` to ask questions about how to use Continue""")) >>                  WelcomeStep() >> -                # CreateCodebaseIndexChroma() >>                  StepsOnStartupStep())          observation = history.get_current().observation @@ -75,6 +75,6 @@ class DefaultPolicy(Policy):              if user_input.startswith("/edit"):                  return EditHighlightedCodeStep(user_input=user_input[5:]) -            return SimpleChatStep() +            return self.default_step.copy()          return None diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py index 433e309e..872f8d62 100644 --- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py @@ -27,7 +27,7 @@ class SetupPipelineStep(Step):      async def run(self, sdk: ContinueSDK):          sdk.context.set("api_description", self.api_description) -        source_name = (await sdk.models.gpt35.complete( +        source_name = (await sdk.models.medium.complete(              f"Write a snake_case name for the data source described by {self.api_description}: ")).strip()          filename = f'{source_name}.py' @@ -89,7 +89,7 @@ class ValidatePipelineStep(Step):          if "Traceback" in output or "SyntaxError" in output:              output = "Traceback" + output.split("Traceback")[-1]              file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename)) -            suggestion = await sdk.models.gpt35.complete(dedent(f"""\ +            suggestion = await sdk.models.medium.complete(dedent(f"""\                  ```python                  {file_content}                  ``` @@ -101,7 +101,7 @@ class ValidatePipelineStep(Step):                  This is a brief summary of the error followed by a suggestion on how it can be fixed by editing the resource function:""")) -            api_documentation_url = await sdk.models.gpt35.complete(dedent(f"""\ +            api_documentation_url = await sdk.models.medium.complete(dedent(f"""\                  The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this:                  ```python                         {file_content} @@ -151,7 +151,7 @@ class RunQueryStep(Step):          output = await sdk.run('.env/bin/python3 query.py', name="Run test query", description="Running `.env/bin/python3 query.py` to test that the data was loaded into DuckDB as expected", handle_error=False)          if "Traceback" in output or "SyntaxError" in output: -            suggestion = await sdk.models.gpt35.complete(dedent(f"""\ +            suggestion = await sdk.models.medium.complete(dedent(f"""\                  ```python                  {await sdk.ide.readFile(os.path.join(sdk.ide.workspace_directory, "query.py"))}                  ``` diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py index 6ef5ffd6..c66cd629 100644 --- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py @@ -42,7 +42,7 @@ class WritePytestsRecipe(Step):              "{self.user_input}"              Here is a complete set of pytest unit tests:""") -        tests = await sdk.models.gpt35.complete(prompt) +        tests = await sdk.models.medium.complete(prompt)          await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests)) diff --git a/continuedev/src/continuedev/plugins/steps/README.md b/continuedev/src/continuedev/plugins/steps/README.md index 12073835..3f2f804c 100644 --- a/continuedev/src/continuedev/plugins/steps/README.md +++ b/continuedev/src/continuedev/plugins/steps/README.md @@ -33,7 +33,7 @@ If you'd like to override the default description of your step, which is just th  - Return a static string  - Store state in a class attribute (prepend with a double underscore, which signifies (through Pydantic) that this is not a parameter for the Step, just internal state) during the run method, and then grab this in the describe method. -- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.gpt35.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`. +- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.  Here's an example: diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py index f72a8ec0..455d5a13 100644 --- a/continuedev/src/continuedev/plugins/steps/chat.py +++ b/continuedev/src/continuedev/plugins/steps/chat.py @@ -9,6 +9,7 @@ from .core.core import DisplayErrorStep, MessageStep  from ...core.main import FunctionCall, Models  from ...core.main import ChatMessage, Step, step_to_json_schema  from ...core.sdk import ContinueSDK +from ...libs.llm.openai import OpenAI  import openai  import os  from dotenv import load_dotenv @@ -41,7 +42,7 @@ class SimpleChatStep(Step):                  self.description += chunk["content"]                  await sdk.update_ui() -        self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete( +        self.name = remove_quotes_and_escapes(await sdk.models.medium.complete(              f"Write a short title for the following chat message: {self.description}"))          self.chat_context.append(ChatMessage( @@ -166,7 +167,10 @@ class ChatWithFunctions(Step):              msg_content = ""              msg_step = None -            async for msg_chunk in sdk.models.gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions): +            gpt350613 = OpenAI(model="gpt-3.5-turbo-0613") +            await sdk.start_model(gpt350613) + +            async for msg_chunk in gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions):                  if sdk.current_step_was_deleted():                      return diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py index dbe8363e..658cc7f3 100644 --- a/continuedev/src/continuedev/plugins/steps/chroma.py +++ b/continuedev/src/continuedev/plugins/steps/chroma.py @@ -56,7 +56,7 @@ class AnswerQuestionChroma(Step):              Here is the answer:""") -        answer = await sdk.models.gpt35.complete(prompt) +        answer = await sdk.models.medium.complete(prompt)          # Make paths relative to the workspace directory          answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "") diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index c80cecc3..4476c7ae 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -11,11 +11,12 @@ from pydantic import validator  from ....libs.llm.ggml import GGML  from ....models.main import Range +from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI  from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit  from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents -from ....core.observation import Observation, TextObservation, UserInputObservation -from ....core.main import ChatMessage, ContinueCustomException, Step -from ....libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS +from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation +from ....core.main import ChatMessage, ContinueCustomException, Step, SequentialStep +from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS  from ....libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes @@ -97,7 +98,7 @@ class ShellCommandsStep(Step):              return f"Error when running shell commands:\n```\n{self._err_text}\n```"          cmds_str = "\n".join(self.cmds) -        return await models.gpt35.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:") +        return await models.medium.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:")      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:          cwd = await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd @@ -105,7 +106,7 @@ class ShellCommandsStep(Step):          for cmd in self.cmds:              output = await sdk.ide.runCommand(cmd)              if self.handle_error and output is not None and output_contains_error(output): -                suggestion = await sdk.models.gpt35.complete(dedent(f"""\ +                suggestion = await sdk.models.medium.complete(dedent(f"""\                      While running the command `{cmd}`, the following error occurred:                      ```ascii @@ -185,7 +186,7 @@ class DefaultModelEditCodeStep(Step):          else:              changes = '\n'.join(difflib.ndiff(                  self._previous_contents.splitlines(), self._new_contents.splitlines())) -            description = await models.gpt3516k.complete(dedent(f"""\ +            description = await models.medium.complete(dedent(f"""\                  Diff summary: "{self.user_input}"                  ```diff @@ -193,7 +194,7 @@ class DefaultModelEditCodeStep(Step):                  ```                  Please give brief a description of the changes made above using markdown bullet points. Be concise:""")) -        name = await models.gpt3516k.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:") +        name = await models.medium.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:")          self.name = remove_quotes_and_escapes(name)          return f"{remove_quotes_and_escapes(description)}" @@ -203,8 +204,7 @@ class DefaultModelEditCodeStep(Step):          # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.          # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.          model_to_use = sdk.models.default -        max_tokens = int(MAX_TOKENS_FOR_MODEL.get( -            model_to_use.name, DEFAULT_MAX_TOKENS) / 2) +        max_tokens = int(model_to_use.context_length / 2)          TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200          if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE: @@ -222,8 +222,9 @@ class DefaultModelEditCodeStep(Step):          # If using 3.5 and overflows, upgrade to 3.5.16k          if model_to_use.name == "gpt-3.5-turbo": -            if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]: -                model_to_use = sdk.models.gpt3516k +            if total_tokens > model_to_use.context_length: +                model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613") +                await sdk.start_model(model_to_use)          # Remove tokens from the end first, and then the start to clear space          # This part finds the start and end lines @@ -233,20 +234,20 @@ class DefaultModelEditCodeStep(Step):          cur_start_line = 0          cur_end_line = len(full_file_contents_lst) - 1 -        if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]: +        if total_tokens > model_to_use.context_length:              while cur_end_line > min_end_line:                  total_tokens -= model_to_use.count_tokens(                      full_file_contents_lst[cur_end_line])                  cur_end_line -= 1 -                if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]: +                if total_tokens < model_to_use.context_length:                      break -        if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]: +        if total_tokens > model_to_use.context_length:              while cur_start_line < max_start_line:                  cur_start_line += 1                  total_tokens -= model_to_use.count_tokens(                      full_file_contents_lst[cur_start_line]) -                if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]: +                if total_tokens < model_to_use.context_length:                      break          # Now use the found start/end lines to get the prefix and suffix strings @@ -525,7 +526,7 @@ Please output the code to be inserted at the cursor in order to fulfill the user                  # Accumulate lines                  if "content" not in chunk: -                    continue +                    continue  # ayo                  chunk = chunk["content"]                  chunk_lines = chunk.split("\n")                  chunk_lines[0] = unfinished_line + chunk_lines[0] @@ -546,12 +547,12 @@ Please output the code to be inserted at the cursor in order to fulfill the user                          break                      # Lines that should be ignored, like the <> tags                      elif self.line_to_be_ignored(chunk_lines[i], completion_lines_covered == 0): -                        continue +                        continue  # noice                      # Check if we are currently just copying the prefix                      elif (lines_of_prefix_copied > 0 or completion_lines_covered == 0) and lines_of_prefix_copied < len(file_prefix.splitlines()) and chunk_lines[i] == full_file_contents_lines[lines_of_prefix_copied]:                          # This is a sketchy way of stopping it from repeating the file_prefix. Is a bug if output happens to have a matching line                          lines_of_prefix_copied += 1 -                        continue +                        continue  # also nice                      # Because really short lines might be expected to be repeated, this is only a !heuristic!                      # Stop when it starts copying the file_suffix                      elif chunk_lines[i].strip() == line_below_highlighted_range.strip() and len(chunk_lines[i].strip()) > 4 and not (len(original_lines_below_previous_blocks) > 0 and chunk_lines[i].strip() == original_lines_below_previous_blocks[0].strip()): diff --git a/continuedev/src/continuedev/plugins/steps/draft/migration.py b/continuedev/src/continuedev/plugins/steps/draft/migration.py index a76d491b..c38f54dc 100644 --- a/continuedev/src/continuedev/plugins/steps/draft/migration.py +++ b/continuedev/src/continuedev/plugins/steps/draft/migration.py @@ -13,7 +13,7 @@ class MigrationStep(Step):          recent_edits = await sdk.ide.get_recent_edits(self.edited_file)          recent_edits_string = "\n\n".join(              map(lambda x: x.to_string(), recent_edits)) -        description = await sdk.models.gpt35.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n") +        description = await sdk.models.medium.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")          await sdk.run([              "cd libs",              "poetry run alembic revision --autogenerate -m " + description, diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py index 6997a547..ec670999 100644 --- a/continuedev/src/continuedev/plugins/steps/help.py +++ b/continuedev/src/continuedev/plugins/steps/help.py @@ -56,7 +56,7 @@ class HelpStep(Step):                  summary="Help"              ))              messages = await sdk.get_chat_context() -            generator = sdk.models.gpt4.stream_chat(messages) +            generator = sdk.models.default.stream_chat(messages)              async for chunk in generator:                  if "content" in chunk:                      self.description += chunk["content"] diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py index b54d394a..3d8d96fb 100644 --- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py +++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py @@ -23,6 +23,6 @@ class NLMultiselectStep(Step):          if first_try is not None:              return first_try -        gpt_parsed = await sdk.models.gpt35.complete( +        gpt_parsed = await sdk.models.default.complete(              f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:")          return extract_option(gpt_parsed) or self.options[0] diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py index f28d9660..d2d6f4dd 100644 --- a/continuedev/src/continuedev/plugins/steps/main.py +++ b/continuedev/src/continuedev/plugins/steps/main.py @@ -101,7 +101,7 @@ class FasterEditHighlightedCodeStep(Step):          for rif in range_in_files:              rif_dict[rif.filepath] = rif.contents -        completion = await sdk.models.gpt35.complete(prompt) +        completion = await sdk.models.medium.complete(prompt)          # Temporarily doing this to generate description.          self._prompt = prompt @@ -169,7 +169,7 @@ class StarCoderEditHighlightedCodeStep(Step):      _prompt_and_completion: str = ""      async def describe(self, models: Models) -> Coroutine[str, None, None]: -        return await models.gpt35.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:") +        return await models.medium.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:")      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:          range_in_files = await sdk.get_code_context(only_editing=True) diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py index 8b2e7c2e..da6acdbf 100644 --- a/continuedev/src/continuedev/plugins/steps/react.py +++ b/continuedev/src/continuedev/plugins/steps/react.py @@ -27,7 +27,7 @@ class NLDecisionStep(Step):              Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""") -        resp = (await sdk.models.gpt35.complete(prompt)).lower() +        resp = (await sdk.models.medium.complete(prompt)).lower()          step_to_run = None          for step in self.steps: diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py index 07b50473..456dba84 100644 --- a/continuedev/src/continuedev/plugins/steps/search_directory.py +++ b/continuedev/src/continuedev/plugins/steps/search_directory.py @@ -20,10 +20,10 @@ def find_all_matches_in_dir(pattern: str, dirpath: str) -> List[RangeInFile]:      for root, dirs, files in os.walk(dirpath):          dirname = os.path.basename(root)          if dirname.startswith(".") or dirname in IGNORE_DIRS: -            continue +            continue  # continue!          for file in files:              if file in IGNORE_FILES: -                continue +                continue  # pun intended              with open(os.path.join(root, file), "r") as f:                  # Find the index of all occurences of the pattern in the file. Use re.                  file_content = f.read() @@ -42,7 +42,7 @@ class WriteRegexPatternStep(Step):      async def run(self, sdk: ContinueSDK):          # Ask the user for a regex pattern -        pattern = await sdk.models.gpt35.complete(dedent(f"""\ +        pattern = await sdk.models.medium.complete(dedent(f"""\              This is the user request:              {self.user_request} diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 98a5aea0..cf18c56b 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -176,7 +176,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we                  message = json.loads(message)              if "messageType" not in message or "data" not in message: -                continue +                continue  # :o              message_type = message["messageType"]              data = message["data"] diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index e4c07029..6124f3bd 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -139,7 +139,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):              msg_string = await self.websocket.receive_text()              message = json.loads(msg_string)              if "messageType" not in message or "data" not in message: -                continue +                continue  # <-- hey that's the name of this repo!              message_type = message["messageType"]              data = message["data"]              logger.debug(f"Received message while initializing {message_type}") @@ -311,7 +311,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):      def onFileEdits(self, edits: List[FileEditWithFullContents]):          if autopilot := self.__get_autopilot(): -            autopilot.handle_manual_edits(edits) +            pass      def onDeleteAtIndex(self, index: int):          if autopilot := self.__get_autopilot(): diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index cf46028f..b5580fe8 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -1,4 +1,5 @@  import os +import traceback  from fastapi import WebSocket  from typing import Any, Dict, List, Union  from uuid import uuid4 @@ -6,12 +7,10 @@ import json  from fastapi.websockets import WebSocketState -from ..plugins.steps.core.core import DisplayErrorStep +from ..plugins.steps.core.core import DisplayErrorStep, MessageStep  from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath  from ..models.filesystem_edit import FileEditWithFullContents -from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER -from ..core.policy import DefaultPolicy -from ..core.main import FullState +from ..core.main import FullState, HistoryNode  from ..core.autopilot import Autopilot  from .ide_protocol import AbstractIdeProtocolServer  from ..libs.util.create_async_task import create_async_task @@ -31,19 +30,6 @@ class Session:          self.ws = None -class DemoAutopilot(Autopilot): -    first_seen: bool = False -    cumulative_edit_string = "" - -    def handle_manual_edits(self, edits: List[FileEditWithFullContents]): -        return -        for edit in edits: -            self.cumulative_edit_string += edit.fileEdit.replacement -            self._manual_edits_buffer.append(edit) -            # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge. -            # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit) - -  class SessionManager:      sessions: Dict[str, Session] = {}      # Mapping of session_id to IDE, where the IDE is still alive @@ -65,27 +51,47 @@ class SessionManager:      async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:          logger.debug(f"New session: {session_id}") +        # Load the persisted state (not being used right now)          full_state = None          if session_id is not None and os.path.exists(getSessionFilePath(session_id)):              with open(getSessionFilePath(session_id), "r") as f:                  full_state = FullState(**json.load(f)) -        autopilot = await DemoAutopilot.create( -            policy=DefaultPolicy(), ide=ide, full_state=full_state) +        # Register the session and ide (do this first so that the autopilot can access the session) +        autopilot = Autopilot(ide=ide)          session_id = session_id or str(uuid4())          ide.session_id = session_id          session = Session(session_id=session_id, autopilot=autopilot)          self.sessions[session_id] = session          self.registered_ides[session_id] = ide +        # Set up the autopilot to update the GUI          async def on_update(state: FullState):              await session_manager.send_ws_data(session_id, "state_update", {                  "state": state.dict()              })          autopilot.on_update(on_update) -        create_async_task(autopilot.run_policy( -        ), lambda e: autopilot.continue_sdk.run_step(DisplayErrorStep(e=e))) + +        # Start the autopilot (must be after session is added to sessions) and the policy +        try: +            await autopilot.start() +        except Exception as e: +            # Have to manually add to history because autopilot isn't started +            formatted_err = '\n'.join(traceback.format_exception(e)) +            msg_step = MessageStep( +                name="Error loading context manager", message=formatted_err) +            msg_step.description = f"```\n{formatted_err}\n```" +            autopilot.history.add_node(HistoryNode( +                step=msg_step, +                observation=None, +                depth=0, +                active=False +            )) +            logger.warning(f"Error loading context manager: {e}") + +        create_async_task(autopilot.run_policy(), lambda e: autopilot.continue_sdk.run_step( +            DisplayErrorStep(e=e)))          return session      async def remove_session(self, session_id: str): | 
