path: root/continuedev/src
author     Nate Sesti <33237525+sestinj@users.noreply.github.com>    2023-07-31 11:36:19 -0700
committer  GitHub <noreply@github.com>    2023-07-31 11:36:19 -0700
commit     457c9940ec6bdabd89de84a23abbf246aaf662c4 (patch)
tree       bfc8bad5f8f9c2d933dbe815b80ae25e779d40e6 /continuedev/src
parent     269a28a311c38e60c720dfcfc2889a2f6f0f85bb (diff)
parent     43080aad995dcfb4b1742627fc03af7027cdbf8a (diff)
Merge pull request #318 from lun-4/llm-object-config
make Config receive LLM objects
Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/core/abstract_sdk.py                          |   2
-rw-r--r--  continuedev/src/continuedev/core/autopilot.py                             |   1
-rw-r--r--  continuedev/src/continuedev/core/config.py                                |  19
-rw-r--r--  continuedev/src/continuedev/core/context.py                               |   6
-rw-r--r--  continuedev/src/continuedev/core/main.py                                  |   1
-rw-r--r--  continuedev/src/continuedev/core/models.py                                |  65
-rw-r--r--  continuedev/src/continuedev/core/sdk.py                                   | 139
-rw-r--r--  continuedev/src/continuedev/libs/constants/default_config.py.txt          |  13
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py                          |  31
-rw-r--r--  continuedev/src/continuedev/libs/llm/anthropic.py                         |  47
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py                              | 100
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_inference_api.py                  |  12
-rw-r--r--  continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py                |  53
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py                            |  92
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py                      | 165
-rw-r--r--  continuedev/src/continuedev/libs/llm/utils.py                             |  34
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py                     |  86
-rw-r--r--  continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py |   8
-rw-r--r--  continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py    |   2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/README.md                       |   2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/chat.py                         |   8
-rw-r--r--  continuedev/src/continuedev/plugins/steps/chroma.py                       |   2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/core/core.py                    |  31
-rw-r--r--  continuedev/src/continuedev/plugins/steps/draft/migration.py              |   2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/help.py                         |   2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py         |   2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/main.py                         |   4
-rw-r--r--  continuedev/src/continuedev/plugins/steps/react.py                        |   2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/search_directory.py             |   2
29 files changed, 507 insertions, 426 deletions
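
The heart of this merge: ContinueConfig no longer selects a model with a default_model string (plus a separate OpenAIServerInfo), it receives LLM objects through a Models container. A minimal sketch of a config in the new style, mirroring the default_config.py.txt changes further down (only the two MaybeProxyOpenAI entries come from this diff; everything else is left at its defaults):

```python
from continuedev.core.config import ContinueConfig
from continuedev.core.models import Models
from continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI

config = ContinueConfig(
    # Previously: default_model="gpt-4"
    models=Models(
        default=MaybeProxyOpenAI(model="gpt-4"),
        medium=MaybeProxyOpenAI(model="gpt-3.5-turbo"),
    ),
)
```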
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py
index 94d7be10..e048f877 100644
--- a/continuedev/src/continuedev/core/abstract_sdk.py
+++ b/continuedev/src/continuedev/core/abstract_sdk.py
@@ -73,7 +73,7 @@ class AbstractContinueSDK(ABC):
pass
@abstractmethod
- async def get_user_secret(self, env_var: str, prompt: str) -> str:
+ async def get_user_secret(self, env_var: str) -> str:
pass
config: ContinueConfig
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 5ab5f8ae..de95a259 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -96,7 +96,6 @@ class Autopilot(ContinueBaseModel):
history=self.history,
active=self._active,
user_input_queue=self._main_user_input_queue,
- default_model=self.continue_sdk.config.default_model,
slash_commands=self.get_available_slash_commands(),
adding_highlighted_code=self.context_manager.context_providers[
"code"].adding_highlighted_code,
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index fe0946cd..84b6b10b 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -1,3 +1,9 @@
+import json
+import os
+from .main import Step
+from .context import ContextProvider
+from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from .models import Models
from pydantic import BaseModel, validator
from typing import List, Literal, Optional, Dict, Type
@@ -23,13 +29,6 @@ class OnTracebackSteps(BaseModel):
params: Optional[Dict] = {}
-class OpenAIServerInfo(BaseModel):
- api_base: Optional[str] = None
- engine: Optional[str] = None
- api_version: Optional[str] = None
- api_type: Literal["azure", "openai"] = "openai"
-
-
class ContinueConfig(BaseModel):
"""
A pydantic class for the continue config file.
@@ -37,8 +36,9 @@ class ContinueConfig(BaseModel):
steps_on_startup: List[Step] = []
disallowed_steps: Optional[List[str]] = []
allow_anonymous_telemetry: Optional[bool] = True
- default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
- "gpt-4", "claude-2", "ggml"] = 'gpt-4'
+ models: Models = Models(
+ default=MaybeProxyOpenAI(model="gpt-4"),
+ )
temperature: Optional[float] = 0.5
custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
name="test",
@@ -48,7 +48,6 @@ class ContinueConfig(BaseModel):
slash_commands: Optional[List[SlashCommand]] = []
on_traceback: Optional[List[OnTracebackSteps]] = []
system_message: Optional[str] = None
- openai_server_info: Optional[OpenAIServerInfo] = None
policy_override: Optional[Policy] = None
context_providers: List[ContextProvider] = []
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index 86522ce1..e968c35c 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -178,12 +178,6 @@ class ContextManager:
except Exception as e:
logger.debug(f"Error loading meilisearch index: {e}")
- # def compile_chat_messages(self, max_tokens: int) -> List[Dict]:
- # """
- # Compiles the chat prompt into a single string.
- # """
- # return compile_chat_messages(self.model, self.chat_history, max_tokens, self.prompt, self.functions, self.system_message)
-
async def select_context_item(self, id: str, query: str):
"""
Selects the ContextItem with the given id.
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index df9b98ef..2553850f 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -258,7 +258,6 @@ class FullState(ContinueBaseModel):
history: History
active: bool
user_input_queue: List[str]
- default_model: str
slash_commands: List[SlashCommandDescription]
adding_highlighted_code: bool
selected_context_items: List[ContextItem]
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
new file mode 100644
index 00000000..900762b6
--- /dev/null
+++ b/continuedev/src/continuedev/core/models.py
@@ -0,0 +1,65 @@
+from typing import Optional, Any
+from pydantic import BaseModel, validator
+from ..libs.llm import LLM
+
+
+class Models(BaseModel):
+ """Main class that holds the current model configuration"""
+ default: LLM
+ small: Optional[LLM] = None
+ medium: Optional[LLM] = None
+ large: Optional[LLM] = None
+
+ # TODO namespace these away to not confuse readers,
+ # or split Models into ModelsConfig, which gets turned into Models
+ sdk: "ContinueSDK" = None
+ system_message: Any = None
+
+ """
+ Better to have sdk.llm.stream_chat(messages, model="claude-2").
+ Then you also don't care that it's async.
+ And it's easier to add more models.
+ And intermediate shared code is easier to add.
+ And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo"
+ PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model.
+ Easy to reason about, can place anywhere.
+ And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model.
+ This can all happen inside of Models?
+
+ class Prompt:
+ def __init__(self, ...info):
+ '''take whatever info is needed to describe the prompt'''
+
+ def to_string(self, model: str) -> str:
+ '''depending on the model, return the single prompt string'''
+ """
+
+ async def start(self, sdk: "ContinueSDK"):
+ """Start each of the LLMs, or fall back to default"""
+ self.sdk = sdk
+ self.system_message = self.sdk.config.system_message
+ await sdk.start_model(self.default)
+ if self.small:
+ await sdk.start_model(self.small)
+ else:
+ self.small = self.default
+
+ if self.medium:
+ await sdk.start_model(self.medium)
+ else:
+ self.medium = self.default
+
+ if self.large:
+ await sdk.start_model(self.large)
+ else:
+ self.large = self.default
+
+ async def stop(self, sdk: "ContinueSDK"):
+ """Stop each LLM (if it's not the default, which is shared)"""
+ await self.default.stop()
+ if self.small is not self.default:
+ await self.small.stop()
+ if self.medium is not self.default:
+ await self.medium.stop()
+ if self.large is not self.default:
+ await self.large.stop()
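
The fallback in Models.start() above means any size alias left unset (small, medium, large) resolves to the shared default, and stop() shuts each distinct LLM down only once. A hedged usage sketch (the sdk argument stands for the ContinueSDK instance that calls this in practice):

```python
from continuedev.core.models import Models
from continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI

async def demo(sdk):
    models = Models(default=MaybeProxyOpenAI(model="gpt-4"))
    await models.start(sdk)

    # Unset sizes alias the default LLM rather than starting extra models.
    assert models.small is models.default
    assert models.medium is models.default
    assert models.large is models.default

    await models.stop(sdk)  # the shared default is stopped exactly once
```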
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 50b89a8b..bf22d696 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -9,17 +9,14 @@ from .abstract_sdk import AbstractContinueSDK
from .config import ContinueConfig
from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory
from ..models.filesystem import RangeInFile
-from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
-from ..libs.llm.openai import OpenAI
-from ..libs.llm.anthropic import AnthropicLLM
-from ..libs.llm.ggml import GGML
+from ..libs.llm import LLM
from .observation import Observation
from ..server.ide_protocol import AbstractIdeProtocolServer
from .main import Context, ContinueCustomException, History, HistoryNode, Step, ChatMessage
from ..plugins.steps.core.core import *
-from ..libs.llm.proxy_server import ProxyServer
from ..libs.util.telemetry import posthog_logger
from ..libs.util.paths import getConfigFilePath
+from .models import Models
from ..libs.util.logging import logger
@@ -27,121 +24,6 @@ class Autopilot:
pass
-ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"]
-MODEL_PROVIDER_TO_ENV_VAR = {
- "openai": "OPENAI_API_KEY",
- "hf_inference_api": "HUGGING_FACE_TOKEN",
- "anthropic": "ANTHROPIC_API_KEY",
-}
-
-
-class Models:
- provider_keys: Dict[ModelProvider, str] = {}
- model_providers: List[ModelProvider]
- system_message: str
-
- """
- Better to have sdk.llm.stream_chat(messages, model="claude-2").
- Then you also don't care that it' async.
- And it's easier to add more models.
- And intermediate shared code is easier to add.
- And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo"
- PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model.
- Easy to reason about, can place anywhere.
- And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model.
- This can all happen inside of Models?
-
- class Prompt:
- def __init__(self, ...info):
- '''take whatever info is needed to describe the prompt'''
-
- def to_string(self, model: str) -> str:
- '''depending on the model, return the single prompt string'''
- """
-
- def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
- self.sdk = sdk
- self.model_providers = model_providers
- self.system_message = sdk.config.system_message
-
- @classmethod
- async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
- if sdk.config.default_model == "claude-2":
- with_providers.append("anthropic")
-
- models = Models(sdk, with_providers)
- for provider in with_providers:
- if provider in MODEL_PROVIDER_TO_ENV_VAR:
- env_var = MODEL_PROVIDER_TO_ENV_VAR[provider]
- models.provider_keys[provider] = await sdk.get_user_secret(
- env_var, f'Please add your {env_var} to the .env file')
-
- return models
-
- def __load_openai_model(self, model: str) -> OpenAI:
- api_key = self.provider_keys["openai"]
- if api_key == "":
- return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log)
- return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, openai_server_info=self.sdk.config.openai_server_info, write_log=self.sdk.write_log)
-
- def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
- api_key = self.provider_keys["hf_inference_api"]
- return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)
-
- def __load_anthropic_model(self, model: str) -> AnthropicLLM:
- api_key = self.provider_keys["anthropic"]
- return AnthropicLLM(api_key, model, self.system_message)
-
- @cached_property
- def claude2(self):
- return self.__load_anthropic_model("claude-2")
-
- @cached_property
- def starcoder(self):
- return self.__load_hf_inference_api_model("bigcode/starcoder")
-
- @cached_property
- def gpt35(self):
- return self.__load_openai_model("gpt-3.5-turbo")
-
- @cached_property
- def gpt350613(self):
- return self.__load_openai_model("gpt-3.5-turbo-0613")
-
- @cached_property
- def gpt3516k(self):
- return self.__load_openai_model("gpt-3.5-turbo-16k")
-
- @cached_property
- def gpt4(self):
- return self.__load_openai_model("gpt-4")
-
- @cached_property
- def ggml(self):
- return GGML(system_message=self.system_message)
-
- def __model_from_name(self, model_name: str):
- if model_name == "starcoder":
- return self.starcoder
- elif model_name == "gpt-3.5-turbo":
- return self.gpt35
- elif model_name == "gpt-3.5-turbo-16k":
- return self.gpt3516k
- elif model_name == "gpt-4":
- return self.gpt4
- elif model_name == "claude-2":
- return self.claude2
- elif model_name == "ggml":
- return self.ggml
- else:
- raise Exception(f"Unknown model {model_name}")
-
- @property
- def default(self):
- default_model = self.sdk.config.default_model
- return self.__model_from_name(default_model) if default_model is not None else self.gpt4
-
-
class ContinueSDK(AbstractContinueSDK):
"""The SDK provided as parameters to a step"""
ide: AbstractIdeProtocolServer
@@ -179,11 +61,13 @@ class ContinueSDK(AbstractContinueSDK):
active=False
))
+ sdk.models = sdk.config.models
+ await sdk.models.start(sdk)
+
# When the config is loaded, setup posthog logger
posthog_logger.setup(
sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry)
- sdk.models = await Models.create(sdk)
return sdk
@property
@@ -193,6 +77,16 @@ class ContinueSDK(AbstractContinueSDK):
def write_log(self, message: str):
self.history.timeline[self.history.current_index].logs.append(message)
+ async def start_model(self, llm: LLM):
+ kwargs = {}
+ if llm.requires_api_key:
+ kwargs["api_key"] = await self.get_user_secret(llm.requires_api_key)
+ if llm.requires_unique_id:
+ kwargs["unique_id"] = self.ide.unique_id
+ if llm.requires_write_log:
+ kwargs["write_log"] = self.write_log
+ await llm.start(**kwargs)
+
async def _ensure_absolute_path(self, path: str) -> str:
if os.path.isabs(path):
return path
@@ -262,7 +156,8 @@ class ContinueSDK(AbstractContinueSDK):
path = await self._ensure_absolute_path(path)
return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path)))
- async def get_user_secret(self, env_var: str, prompt: str) -> str:
+ async def get_user_secret(self, env_var: str) -> str:
+ # TODO support error prompt dynamically set on env_var
return await self.ide.getUserSecret(env_var)
_last_valid_config: ContinueConfig = None
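
With start_model in place, a step can construct an extra LLM on demand and let the SDK inject only what that model declares it needs through its requires_* flags (API key, unique id, write_log), which is what the chat.py change near the end of this diff does. A hedged sketch of that pattern inside a step's run method:

```python
from continuedev.libs.llm.openai import OpenAI

async def run(self, sdk):
    # sdk.start_model reads requires_api_key / requires_unique_id /
    # requires_write_log on the LLM and supplies matching kwargs to start().
    gpt350613 = OpenAI(model="gpt-3.5-turbo-0613")
    await sdk.start_model(gpt350613)

    async for chunk in gpt350613.stream_chat(await sdk.get_chat_context()):
        ...  # consume streamed chunks
```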
diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt
index 69fd357b..cf8b0324 100644
--- a/continuedev/src/continuedev/libs/constants/default_config.py.txt
+++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt
@@ -9,9 +9,11 @@ import subprocess
from continuedev.core.main import Step
from continuedev.core.sdk import ContinueSDK
+from continuedev.core.models import Models
from continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig
from continuedev.plugins.context_providers.github import GitHubIssuesContextProvider
from continuedev.plugins.context_providers.google import GoogleContextProvider
+from continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
from continuedev.plugins.policies.default import DefaultPolicy
from continuedev.plugins.steps.open_config import OpenConfigStep
@@ -36,9 +38,9 @@ class CommitMessageStep(Step):
diff = subprocess.check_output(
["git", "diff"], cwd=dir).decode("utf-8")
- # Ask gpt-3.5-16k to write a commit message,
+ # Ask the LLM to write a commit message,
# and set it as the description of this step
- self.description = await sdk.models.gpt3516k.complete(
+ self.description = await sdk.models.default.complete(
f"{diff}\n\nWrite a short, specific (less than 50 chars) commit message about the above changes:")
@@ -48,9 +50,10 @@ config = ContinueConfig(
# See here to learn what anonymous data we collect: https://continue.dev/docs/telemetry
allow_anonymous_telemetry=True,
- # GPT-4 is recommended for best results
- # See options here: https://continue.dev/docs/customization#change-the-default-llm
- default_model="gpt-4",
+ models=Models(
+ default=MaybeProxyOpenAI(model="gpt-4"),
+ medium=MaybeProxyOpenAI(model="gpt-3.5-turbo")
+ ),
# Set a system message with information that the LLM should always keep in mind
# E.g. "Please give concise answers. Always respond in Spanish."
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 2766db4b..50577993 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,14 +1,30 @@
-from abc import ABC
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from abc import ABC, abstractproperty
+from typing import Any, Coroutine, Dict, Generator, List, Union, Optional
from ...core.main import ChatMessage
-from ...models.main import AbstractModel
-from pydantic import BaseModel
+from ...models.main import ContinueBaseModel
-class LLM(ABC):
+class LLM(ContinueBaseModel, ABC):
+ requires_api_key: Optional[str] = None
+ requires_unique_id: bool = False
+ requires_write_log: bool = False
+
system_message: Union[str, None] = None
+ @abstractproperty
+ def name(self):
+ """Return the name of the LLM."""
+ raise NotImplementedError
+
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ """Start the connection to the LLM."""
+ raise NotImplementedError
+
+ async def stop(self):
+ """Stop the connection to the LLM."""
+ raise NotImplementedError
+
async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
"""Return the completion of the text with the given temperature."""
raise NotImplementedError
@@ -24,3 +40,8 @@ class LLM(ABC):
def count_tokens(self, text: str):
"""Return the number of tokens in the given text."""
raise NotImplementedError
+
+ @abstractproperty
+ def context_length(self) -> int:
+ """Return the context length of the LLM in tokens, as counted by count_tokens."""
+ raise NotImplementedError
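
The enlarged LLM base class is now both a pydantic model and the contract every provider fills in: the requires_* flags, name, context_length, start/stop, the completion methods, and count_tokens. A minimal hypothetical subclass sketched against that interface (the class, its window size, and its token counter are illustrative, not part of this diff):

```python
from typing import List, Optional

from continuedev.core.main import ChatMessage
from continuedev.libs.llm import LLM


class StaticLLM(LLM):
    """Toy provider that returns a canned string; for illustration only."""
    model: str = "static-model"
    requires_api_key: Optional[str] = None  # nothing for the SDK to inject

    @property
    def name(self):
        return self.model

    @property
    def context_length(self) -> int:
        return 2048  # assumed window for the toy model

    async def start(self, **kwargs):
        pass

    async def stop(self):
        pass

    def count_tokens(self, text: str):
        return len(text.split())  # crude stand-in for a real tokenizer

    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> str:
        return "static completion"
```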
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 625d4e57..ec1b7e40 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -1,32 +1,39 @@
from functools import cached_property
import time
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
from ...core.main import ChatMessage
from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
class AnthropicLLM(LLM):
- api_key: str
- default_model: str
- async_client: AsyncAnthropic
+ model: str = "claude-2"
- def __init__(self, api_key: str, default_model: str, system_message: str = None):
- self.api_key = api_key
- self.default_model = default_model
+ requires_api_key: str = "ANTHROPIC_API_KEY"
+ _async_client: AsyncAnthropic = None
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ def __init__(self, model: str, system_message: str = None):
+ self.model = model
self.system_message = system_message
- self.async_client = AsyncAnthropic(api_key=api_key)
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ self._async_client = AsyncAnthropic(api_key=api_key)
+
+ async def stop(self):
+ pass
@cached_property
def name(self):
- return self.default_model
+ return self.model
@property
def default_args(self):
- return {**DEFAULT_ARGS, "model": self.default_model}
+ return {**DEFAULT_ARGS, "model": self.model}
def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
args = args.copy()
@@ -40,7 +47,13 @@ class AnthropicLLM(LLM):
return args
def count_tokens(self, text: str):
- return count_tokens(self.default_model, text)
+ return count_tokens(self.model, text)
+
+ @property
+ def context_length(self):
+ if self.model == "claude-2":
+ return 100000
+ raise Exception(f"Unknown Anthropic model {self.model}")
def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
prompt = ""
@@ -60,7 +73,7 @@ class AnthropicLLM(LLM):
args["stream"] = True
args = self._transform_args(args)
- async for chunk in await self.async_client.completions.create(
+ async for chunk in await self._async_client.completions.create(
prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
**args
):
@@ -73,8 +86,8 @@ class AnthropicLLM(LLM):
args = self._transform_args(args)
messages = compile_chat_messages(
- args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message)
- async for chunk in await self.async_client.completions.create(
+ args["model"], messages, self.context_length, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message)
+ async for chunk in await self._async_client.completions.create(
prompt=self.__messages_to_prompt(messages),
**args
):
@@ -88,8 +101,8 @@ class AnthropicLLM(LLM):
args = self._transform_args(args)
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message)
- resp = (await self.async_client.completions.create(
+ args["model"], with_history, self.context_length, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message)
+ resp = (await self._async_client.completions.create(
prompt=self.__messages_to_prompt(messages),
**args
)).completion
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 21374359..7742e8c3 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -1,6 +1,7 @@
from functools import cached_property
import json
from typing import Any, Coroutine, Dict, Generator, List, Union
+from pydantic import ConfigDict
import aiohttp
from ...core.main import ChatMessage
@@ -11,15 +12,29 @@ SERVER_URL = "http://localhost:8000"
class GGML(LLM):
+ # this is model-specific
+ max_context_length: int = 2048
- def __init__(self, system_message: str = None):
- self.system_message = system_message
+ _client_session: aiohttp.ClientSession = None
- @cached_property
+ class Config:
+ arbitrary_types_allowed = True
+
+ async def start(self, **kwargs):
+ self._client_session = aiohttp.ClientSession()
+
+ async def stop(self):
+ await self._client_session.close()
+
+ @property
def name(self):
return "ggml"
@property
+ def context_length(self):
+ return self.max_context_length
+
+ @property
def default_args(self):
return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
@@ -33,54 +48,51 @@ class GGML(LLM):
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
-
- async with aiohttp.ClientSession() as session:
- async with session.post(f"{SERVER_URL}/v1/completions", json={
- "messages": messages,
- **args
- }) as resp:
- async for line in resp.content.iter_any():
- if line:
- try:
- yield line.decode("utf-8")
- except:
- raise Exception(str(line))
+ self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+
+ async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={
+ "messages": messages,
+ **args
+ }) as resp:
+ async for line in resp.content.iter_any():
+ if line:
+ try:
+ yield line.decode("utf-8")
+ except:
+ raise Exception(str(line))
async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
args["stream"] = True
- async with aiohttp.ClientSession() as session:
- async with session.post(f"{SERVER_URL}/v1/chat/completions", json={
- "messages": messages,
- **args
- }) as resp:
- # This is streaming application/json instaed of text/event-stream
- async for line in resp.content.iter_chunks():
- if line[1]:
- try:
- json_chunk = line[0].decode("utf-8")
- if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
- continue # hehe
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- yield json.loads(chunk[6:])["choices"][0]["delta"]
- except:
- raise Exception(str(line[0]))
+ async with self._client_session.post(f"{SERVER_URL}/v1/chat/completions", json={
+ "messages": messages,
+ **args
+ }) as resp:
+ # This is streaming application/json instead of text/event-stream
+ async for line in resp.content.iter_chunks():
+ if line[1]:
+ try:
+ json_chunk = line[0].decode("utf-8")
+ if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
+ continue
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ yield json.loads(chunk[6:])["choices"][0]["delta"]
+ except:
+ raise Exception(str(line[0]))
async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
- async with aiohttp.ClientSession() as session:
- async with session.post(f"{SERVER_URL}/v1/completions", json={
- "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
- **args
- }) as resp:
- try:
- return await resp.text()
- except:
- raise Exception(await resp.text())
+ async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={
+ "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+ **args
+ }) as resp:
+ try:
+ return await resp.text()
+ except:
+ raise Exception(await resp.text())
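
A pattern that repeats in this diff (here in GGML and again in ProxyServer below) is replacing a per-request aiohttp.ClientSession with a single session opened in start() and closed in stop(). Reduced to a standalone sketch (the class name and URL are placeholders):

```python
from typing import Optional

import aiohttp


class StreamingClient:
    """Minimal sketch of the session lifecycle GGML and ProxyServer adopt."""

    def __init__(self, base_url: str = "http://localhost:8000"):
        self.base_url = base_url
        self._client_session: Optional[aiohttp.ClientSession] = None

    async def start(self):
        # One session reused across requests, instead of one per call.
        self._client_session = aiohttp.ClientSession()

    async def stop(self):
        await self._client_session.close()

    async def stream(self, payload: dict):
        async with self._client_session.post(f"{self.base_url}/v1/completions", json=payload) as resp:
            async for chunk in resp.content.iter_any():
                if chunk:
                    yield chunk.decode("utf-8")
```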
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 36f03270..49f593d8 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
from ...core.main import ChatMessage
from ..llm import LLM
import requests
@@ -8,14 +8,18 @@ DEFAULT_MAX_TIME = 120.
class HuggingFaceInferenceAPI(LLM):
- api_key: str
model: str
- def __init__(self, api_key: str, model: str, system_message: str = None):
- self.api_key = api_key
+ requires_api_key: str = "HUGGING_FACE_TOKEN"
+ api_key: str = None
+
+ def __init__(self, model: str, system_message: str = None):
self.model = model
self.system_message = system_message # TODO: Nothing being done with this
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ self.api_key = api_key
+
def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
"""Return the completion of the text with the given temperature."""
API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
new file mode 100644
index 00000000..edf58fd7
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
@@ -0,0 +1,53 @@
+from typing import Any, Coroutine, Dict, Generator, List, Union, Optional, Callable
+
+from ...core.main import ChatMessage
+from . import LLM
+from .proxy_server import ProxyServer
+from .openai import OpenAI
+
+
+class MaybeProxyOpenAI(LLM):
+ model: str
+
+ requires_api_key: Optional[str] = "OPENAI_API_KEY"
+ requires_write_log: bool = True
+ requires_unique_id: bool = True
+ system_message: Union[str, None] = None
+
+ llm: Optional[LLM] = None
+
+ @property
+ def name(self):
+ return self.llm.name
+
+ @property
+ def context_length(self):
+ return self.llm.context_length
+
+ async def start(self, *, api_key: Optional[str] = None, unique_id: str, write_log: Callable[[str], None]):
+ if api_key is None or api_key.strip() == "":
+ self.llm = ProxyServer(model=self.model)
+ else:
+ self.llm = OpenAI(model=self.model)
+
+ await self.llm.start(api_key=api_key, write_log=write_log, unique_id=unique_id)
+
+ async def stop(self):
+ await self.llm.stop()
+
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+ return await self.llm.complete(prompt, with_history=with_history, **kwargs)
+
+ async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ resp = self.llm.stream_complete(
+ prompt, with_history=with_history, **kwargs)
+ async for item in resp:
+ yield item
+
+ async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ resp = self.llm.stream_chat(messages=messages, **kwargs)
+ async for item in resp:
+ yield item
+
+ def count_tokens(self, text: str):
+ return self.llm.count_tokens(text)
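
MaybeProxyOpenAI chooses its backend at start() time: with a missing or blank OPENAI_API_KEY it delegates to the hosted ProxyServer, otherwise to the direct OpenAI class, and then forwards every call to whichever it picked. A hedged usage sketch (the unique id and logger are placeholders; in practice sdk.start_model supplies them):

```python
from continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI

async def demo():
    llm = MaybeProxyOpenAI(model="gpt-4")

    # No api_key -> ProxyServer backend; a real key -> OpenAI backend.
    await llm.start(api_key=None,
                    unique_id="example-ide-id",      # placeholder
                    write_log=lambda msg: None)      # placeholder logger

    text = await llm.complete("Say hello in one word.")
    await llm.stop()
    return text
```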
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index d1ca4ef9..fce6e8ab 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,48 +1,83 @@
from functools import cached_property
import json
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional
-from ...core.main import ChatMessage
+from pydantic import BaseModel
import openai
+
+from ...core.main import ChatMessage
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
-from ...core.config import OpenAIServerInfo
+
+
+class OpenAIServerInfo(BaseModel):
+ api_base: Optional[str] = None
+ engine: Optional[str] = None
+ api_version: Optional[str] = None
+ api_type: Literal["azure", "openai"] = "openai"
+
+
+CHAT_MODELS = {
+ "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
+}
+MAX_TOKENS_FOR_MODEL = {
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k": 16384,
+ "gpt-4": 8192,
+}
+
+
+class AzureInfo(BaseModel):
+ endpoint: str
+ engine: str
+ api_version: str
class OpenAI(LLM):
- api_key: str
- default_model: str
+ model: str
+ openai_server_info: Optional[OpenAIServerInfo] = None
- def __init__(self, api_key: str, default_model: str, system_message: str = None, openai_server_info: OpenAIServerInfo = None, write_log: Callable[[str], None] = None):
- self.api_key = api_key
- self.default_model = default_model
- self.system_message = system_message
- self.openai_server_info = openai_server_info
+ requires_api_key = "OPENAI_API_KEY"
+ requires_write_log = True
+
+ system_message: Optional[str] = None
+ azure_info: Optional[AzureInfo] = None
+ write_log: Optional[Callable[[str], None]] = None
+ api_key: str = None
+
+ async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], **kwargs):
self.write_log = write_log
+ self.api_key = api_key
+ openai.api_key = self.api_key
- openai.api_key = api_key
+ if self.openai_server_info is not None:
+ openai.api_type = self.openai_server_info.api_type
+ if self.openai_server_info.api_base is not None:
+ openai.api_base = self.openai_server_info.api_base
+ if self.openai_server_info.api_version is not None:
+ openai.api_version = self.openai_server_info.api_version
- # Using an Azure OpenAI deployment
- if openai_server_info is not None:
- openai.api_type = openai_server_info.api_type
- if openai_server_info.api_base is not None:
- openai.api_base = openai_server_info.api_base
- if openai_server_info.api_version is not None:
- openai.api_version = openai_server_info.api_version
+ async def stop(self):
+ pass
- @cached_property
+ @property
def name(self):
- return self.default_model
+ return self.model
+
+ @property
+ def context_length(self):
+ return MAX_TOKENS_FOR_MODEL[self.model]
@property
def default_args(self):
- args = {**DEFAULT_ARGS, "model": self.default_model}
+ args = {**DEFAULT_ARGS, "model": self.model}
if self.openai_server_info is not None:
args["engine"] = self.openai_server_info.engine
return args
def count_tokens(self, text: str):
- return count_tokens(self.default_model, text)
+ return count_tokens(self.model, text)
async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args.copy()
@@ -51,7 +86,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
completion = ""
async for chunk in await openai.ChatCompletion.acreate(
@@ -78,12 +113,13 @@ class OpenAI(LLM):
args = self.default_args.copy()
args.update(kwargs)
args["stream"] = True
- args["model"] = self.default_model if self.default_model in CHAT_MODELS else "gpt-3.5-turbo-0613"
+ # TODO what to do here? why should we change to gpt-3.5-turbo-0613 if the user didn't ask for it?
+ args["model"] = self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613"
if not args["model"].endswith("0613") and "functions" in args:
del args["functions"]
messages = compile_chat_messages(
- args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
completion = ""
async for chunk in await openai.ChatCompletion.acreate(
@@ -100,7 +136,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
resp = (await openai.ChatCompletion.acreate(
messages=messages,
@@ -109,7 +145,7 @@ class OpenAI(LLM):
self.write_log(f"Completion: \n\n{resp}")
else:
prompt = prune_raw_prompt_from_top(
- args["model"], prompt, args["max_tokens"])
+ args["model"], self.context_length, prompt, args["max_tokens"])
self.write_log(f"Prompt:\n\n{prompt}")
resp = (await openai.Completion.acreate(
prompt=prompt,
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index f9e3fa01..1a48f213 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,7 +1,6 @@
-
import json
import traceback
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional
import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
@@ -16,26 +15,51 @@ ssl_context = ssl.create_default_context(cafile=ca_bundle_path)
# SERVER_URL = "http://127.0.0.1:8080"
SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app"
+MAX_TOKENS_FOR_MODEL = {
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k": 16384,
+ "gpt-4": 8192,
+}
+
class ProxyServer(LLM):
- unique_id: str
- name: str
- default_model: Literal["gpt-3.5-turbo", "gpt-4"]
- write_log: Callable[[str], None]
+ model: str
+ system_message: Optional[str]
- def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None, write_log: Callable[[str], None] = None):
- self.unique_id = unique_id
- self.default_model = default_model
- self.system_message = system_message
- self.name = default_model
+ unique_id: str = None
+ write_log: Callable[[str], None] = None
+ _client_session: aiohttp.ClientSession
+
+ requires_unique_id = True
+ requires_write_log = True
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], unique_id: str, **kwargs):
+ self._client_session = aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(ssl_context=ssl_context))
self.write_log = write_log
+ self.unique_id = unique_id
+
+ async def stop(self):
+ await self._client_session.close()
+
+ @property
+ def name(self):
+ return self.model
+
+ @property
+ def context_length(self):
+ return MAX_TOKENS_FOR_MODEL[self.model]
@property
def default_args(self):
- return {**DEFAULT_ARGS, "model": self.default_model}
+ return {**DEFAULT_ARGS, "model": self.model}
def count_tokens(self, text: str):
- return count_tokens(self.default_model, text)
+ return count_tokens(self.model, text)
def get_headers(self):
# headers with unique id
@@ -45,75 +69,72 @@ class ProxyServer(LLM):
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
- async with session.post(f"{SERVER_URL}/complete", json={
- "messages": messages,
- **args
- }, headers=self.get_headers()) as resp:
- if resp.status != 200:
- raise Exception(await resp.text())
-
- response_text = await resp.text()
- self.write_log(f"Completion: \n\n{response_text}")
- return response_text
+ async with self._client_session.post(f"{SERVER_URL}/complete", json={
+ "messages": messages,
+ **args
+ }, headers=self.get_headers()) as resp:
+ if resp.status != 200:
+ raise Exception(await resp.text())
+
+ response_text = await resp.text()
+ self.write_log(f"Completion: \n\n{response_text}")
+ return response_text
async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
- async with session.post(f"{SERVER_URL}/stream_chat", json={
- "messages": messages,
- **args
- }, headers=self.get_headers()) as resp:
- # This is streaming application/json instaed of text/event-stream
- completion = ""
- if resp.status != 200:
- raise Exception(await resp.text())
- async for line in resp.content.iter_chunks():
- if line[1]:
- try:
- json_chunk = line[0].decode("utf-8")
- json_chunk = "{}" if json_chunk == "" else json_chunk
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- loaded_chunk = json.loads(chunk)
- yield loaded_chunk
- if "content" in loaded_chunk:
- completion += loaded_chunk["content"]
- except Exception as e:
- posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", {
- "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))})
- else:
- break
-
- self.write_log(f"Completion: \n\n{completion}")
+ async with self._client_session.post(f"{SERVER_URL}/stream_chat", json={
+ "messages": messages,
+ **args
+ }, headers=self.get_headers()) as resp:
+ # This is streaming application/json instead of text/event-stream
+ completion = ""
+ if resp.status != 200:
+ raise Exception(await resp.text())
+ async for line in resp.content.iter_chunks():
+ if line[1]:
+ try:
+ json_chunk = line[0].decode("utf-8")
+ json_chunk = "{}" if json_chunk == "" else json_chunk
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ loaded_chunk = json.loads(chunk)
+ yield loaded_chunk
+ if "content" in loaded_chunk:
+ completion += loaded_chunk["content"]
+ except Exception as e:
+ posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", {
+ "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))})
+ else:
+ break
+
+ self.write_log(f"Completion: \n\n{completion}")
async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+ self.model, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
- async with session.post(f"{SERVER_URL}/stream_complete", json={
- "messages": messages,
- **args
- }, headers=self.get_headers()) as resp:
- completion = ""
- if resp.status != 200:
- raise Exception(await resp.text())
- async for line in resp.content.iter_any():
- if line:
- try:
- decoded_line = line.decode("utf-8")
- yield decoded_line
- completion += decoded_line
- except:
- raise Exception(str(line))
- self.write_log(f"Completion: \n\n{completion}")
+ async with self._client_session.post(f"{SERVER_URL}/stream_complete", json={
+ "messages": messages,
+ **args
+ }, headers=self.get_headers()) as resp:
+ completion = ""
+ if resp.status != 200:
+ raise Exception(await resp.text())
+ async for line in resp.content.iter_any():
+ if line:
+ try:
+ decoded_line = line.decode("utf-8")
+ yield decoded_line
+ completion += decoded_line
+ except:
+ raise Exception(str(line))
+ self.write_log(f"Completion: \n\n{completion}")
diff --git a/continuedev/src/continuedev/libs/llm/utils.py b/continuedev/src/continuedev/libs/llm/utils.py
deleted file mode 100644
index 76240d4e..00000000
--- a/continuedev/src/continuedev/libs/llm/utils.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from transformers import GPT2TokenizerFast
-
-gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-def count_tokens(text: str) -> int:
- return len(gpt2_tokenizer.encode(text))
-
-prices = {
- # All prices are per 1k tokens
- "fine-tune-train": {
- "davinci": 0.03,
- "curie": 0.03,
- "babbage": 0.0006,
- "ada": 0.0004,
- },
- "completion": {
- "davinci": 0.02,
- "curie": 0.002,
- "babbage": 0.0005,
- "ada": 0.0004,
- },
- "fine-tune-completion": {
- "davinci": 0.12,
- "curie": 0.012,
- "babbage": 0.0024,
- "ada": 0.0016,
- },
- "embedding": {
- "ada": 0.0004
- }
-}
-
-def get_price(text: str, model: str="davinci", task: str="completion") -> float:
- return count_tokens(text) * prices[task][model] / 1000 \ No newline at end of file
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index c58ae499..6add7b1a 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -2,43 +2,47 @@ import json
from typing import Dict, List, Union
from ...core.main import ChatMessage
from .templating import render_templated_string
+from ...libs.llm import LLM
import tiktoken
+# TODO move many of these into specific LLM.properties() function that
+# contains max tokens, whether it's a chat model or not, and default args (not all models
+# want to be run at 0.5 temp). It also lets custom models made for long contexts
+# exist here (like LLongMA)
aliases = {
"ggml": "gpt-3.5-turbo",
"claude-2": "gpt-3.5-turbo",
}
DEFAULT_MAX_TOKENS = 2048
-MAX_TOKENS_FOR_MODEL = {
- "gpt-3.5-turbo": 4096,
- "gpt-3.5-turbo-0613": 4096,
- "gpt-3.5-turbo-16k": 16384,
- "gpt-4": 8192,
- "ggml": 2048,
- "claude-2": 100000
-}
-CHAT_MODELS = {
- "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
-}
DEFAULT_ARGS = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1,
"frequency_penalty": 0, "presence_penalty": 0}
-def encoding_for_model(model: str):
- return tiktoken.encoding_for_model(aliases.get(model, model))
+def encoding_for_model(model_name: str):
+ try:
+ return tiktoken.encoding_for_model(aliases.get(model_name, model_name))
+ except:
+ return tiktoken.encoding_for_model("gpt-3.5-turbo")
-def count_tokens(model: str, text: Union[str, None]):
+def count_tokens(model_name: str, text: Union[str, None]):
if text is None:
return 0
- encoding = encoding_for_model(model)
+ encoding = encoding_for_model(model_name)
return len(encoding.encode(text, disallowed_special=()))
-def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: int):
- max_tokens = MAX_TOKENS_FOR_MODEL.get(
- model, DEFAULT_MAX_TOKENS) - tokens_for_completion
- encoding = encoding_for_model(model)
+def count_chat_message_tokens(model_name: str, chat_message: ChatMessage) -> int:
+ # Doing simpler, safer version of what is here:
+ # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+ # every message follows <|start|>{role/name}\n{content}<|end|>\n
+ TOKENS_PER_MESSAGE = 4
+ return count_tokens(model_name, chat_message.content) + TOKENS_PER_MESSAGE
+
+
+def prune_raw_prompt_from_top(model_name: str, context_length: int, prompt: str, tokens_for_completion: int):
+ max_tokens = context_length - tokens_for_completion
+ encoding = encoding_for_model(model_name)
tokens = encoding.encode(prompt, disallowed_special=())
if len(tokens) <= max_tokens:
return prompt
@@ -46,53 +50,45 @@ def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: in
return encoding.decode(tokens[-max_tokens:])
-def count_chat_message_tokens(model: str, chat_message: ChatMessage) -> int:
- # Doing simpler, safer version of what is here:
- # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
- # every message follows <|start|>{role/name}\n{content}<|end|>\n
- TOKENS_PER_MESSAGE = 4
- return count_tokens(model, chat_message.content) + TOKENS_PER_MESSAGE
-
-
-def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
+def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context_length: int, tokens_for_completion: int):
total_tokens = tokens_for_completion + \
- sum(count_chat_message_tokens(model, message)
+ sum(count_chat_message_tokens(model_name, message)
for message in chat_history)
# 1. Replace beyond last 5 messages with summary
i = 0
- while total_tokens > max_tokens and i < len(chat_history) - 5:
+ while total_tokens > context_length and i < len(chat_history) - 5:
message = chat_history[0]
- total_tokens -= count_tokens(model, message.content)
- total_tokens += count_tokens(model, message.summary)
+ total_tokens -= count_tokens(model_name, message.content)
+ total_tokens += count_tokens(model_name, message.summary)
message.content = message.summary
i += 1
# 2. Remove entire messages until the last 5
- while len(chat_history) > 5 and total_tokens > max_tokens and len(chat_history) > 0:
+ while len(chat_history) > 5 and total_tokens > context_length and len(chat_history) > 0:
message = chat_history.pop(0)
- total_tokens -= count_tokens(model, message.content)
+ total_tokens -= count_tokens(model_name, message.content)
# 3. Truncate message in the last 5, except last 1
i = 0
- while total_tokens > max_tokens and len(chat_history) > 0 and i < len(chat_history) - 1:
+ while total_tokens > context_length and len(chat_history) > 0 and i < len(chat_history) - 1:
message = chat_history[i]
- total_tokens -= count_tokens(model, message.content)
- total_tokens += count_tokens(model, message.summary)
+ total_tokens -= count_tokens(model_name, message.content)
+ total_tokens += count_tokens(model_name, message.summary)
message.content = message.summary
i += 1
# 4. Remove entire messages in the last 5, except last 1
- while total_tokens > max_tokens and len(chat_history) > 1:
+ while total_tokens > context_length and len(chat_history) > 1:
message = chat_history.pop(0)
- total_tokens -= count_tokens(model, message.content)
+ total_tokens -= count_tokens(model_name, message.content)
# 5. Truncate last message
- if total_tokens > max_tokens and len(chat_history) > 0:
+ if total_tokens > context_length and len(chat_history) > 0:
message = chat_history[0]
message.content = prune_raw_prompt_from_top(
- model, message.content, tokens_for_completion)
- total_tokens = max_tokens
+ model_name, context_length, message.content, tokens_for_completion)
+ total_tokens = context_length
return chat_history
@@ -101,7 +97,7 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
TOKEN_BUFFER_FOR_SAFETY = 100
-def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
+def compile_chat_messages(model_name: str, msgs: Union[List[ChatMessage], None], context_length: int, max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
"""
The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message
"""
@@ -125,10 +121,10 @@ def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_
function_tokens = 0
if functions is not None:
for function in functions:
- function_tokens += count_tokens(model, json.dumps(function))
+ function_tokens += count_tokens(model_name, json.dumps(function))
msgs_copy = prune_chat_history(
- model, msgs_copy, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
+ model_name, msgs_copy, context_length, function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
history = [msg.to_dict(with_functions=functions is not None)
for msg in msgs_copy]
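
After this change, compile_chat_messages and prune_raw_prompt_from_top no longer consult a MAX_TOKENS_FOR_MODEL table; callers pass the context length reported by the LLM object. A hedged call sketch with the new signature (the message contents are placeholders):

```python
from continuedev.core.main import ChatMessage
from continuedev.libs.util.count_tokens import DEFAULT_ARGS, compile_chat_messages

history = [ChatMessage(role="user", content="Earlier question...", summary="Earlier question")]

messages = compile_chat_messages(
    model_name="gpt-4",
    msgs=history,
    context_length=8192,                    # now supplied by llm.context_length
    max_tokens=DEFAULT_ARGS["max_tokens"],  # tokens reserved for the completion
    prompt="Summarize the repository layout.",
    functions=None,
    system_message=None,
)
```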
diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
index 433e309e..872f8d62 100644
--- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
+++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
@@ -27,7 +27,7 @@ class SetupPipelineStep(Step):
async def run(self, sdk: ContinueSDK):
sdk.context.set("api_description", self.api_description)
- source_name = (await sdk.models.gpt35.complete(
+ source_name = (await sdk.models.medium.complete(
f"Write a snake_case name for the data source described by {self.api_description}: ")).strip()
filename = f'{source_name}.py'
@@ -89,7 +89,7 @@ class ValidatePipelineStep(Step):
if "Traceback" in output or "SyntaxError" in output:
output = "Traceback" + output.split("Traceback")[-1]
file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename))
- suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.medium.complete(dedent(f"""\
```python
{file_content}
```
@@ -101,7 +101,7 @@ class ValidatePipelineStep(Step):
This is a brief summary of the error followed by a suggestion on how it can be fixed by editing the resource function:"""))
- api_documentation_url = await sdk.models.gpt35.complete(dedent(f"""\
+ api_documentation_url = await sdk.models.medium.complete(dedent(f"""\
The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this:
```python
{file_content}
@@ -151,7 +151,7 @@ class RunQueryStep(Step):
output = await sdk.run('.env/bin/python3 query.py', name="Run test query", description="Running `.env/bin/python3 query.py` to test that the data was loaded into DuckDB as expected", handle_error=False)
if "Traceback" in output or "SyntaxError" in output:
- suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.medium.complete(dedent(f"""\
```python
{await sdk.ide.readFile(os.path.join(sdk.ide.workspace_directory, "query.py"))}
```
diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
index 6ef5ffd6..c66cd629 100644
--- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
+++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
@@ -42,7 +42,7 @@ class WritePytestsRecipe(Step):
"{self.user_input}"
Here is a complete set of pytest unit tests:""")
- tests = await sdk.models.gpt35.complete(prompt)
+ tests = await sdk.models.medium.complete(prompt)
await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests))
diff --git a/continuedev/src/continuedev/plugins/steps/README.md b/continuedev/src/continuedev/plugins/steps/README.md
index 12073835..3f2f804c 100644
--- a/continuedev/src/continuedev/plugins/steps/README.md
+++ b/continuedev/src/continuedev/plugins/steps/README.md
@@ -33,7 +33,7 @@ If you'd like to override the default description of your step, which is just th
- Return a static string
- Store state in a class attribute (prepend with a double underscore, which signifies (through Pydantic) that this is not a parameter for the Step, just internal state) during the run method, and then grab this in the describe method.
-- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.gpt35.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
+- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
Here's an example:
diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py
index f72a8ec0..455d5a13 100644
--- a/continuedev/src/continuedev/plugins/steps/chat.py
+++ b/continuedev/src/continuedev/plugins/steps/chat.py
@@ -9,6 +9,7 @@ from .core.core import DisplayErrorStep, MessageStep
from ...core.main import FunctionCall, Models
from ...core.main import ChatMessage, Step, step_to_json_schema
from ...core.sdk import ContinueSDK
+from ...libs.llm.openai import OpenAI
import openai
import os
from dotenv import load_dotenv
@@ -41,7 +42,7 @@ class SimpleChatStep(Step):
self.description += chunk["content"]
await sdk.update_ui()
- self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete(
+ self.name = remove_quotes_and_escapes(await sdk.models.medium.complete(
f"Write a short title for the following chat message: {self.description}"))
self.chat_context.append(ChatMessage(
@@ -166,7 +167,10 @@ class ChatWithFunctions(Step):
msg_content = ""
msg_step = None
- async for msg_chunk in sdk.models.gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions):
+ gpt350613 = OpenAI(model="gpt-3.5-turbo-0613")
+ await sdk.start_model(gpt350613)
+
+ async for msg_chunk in gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions):
if sdk.current_step_was_deleted():
return
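A condensed sketch of the pattern this hunk introduces: instead of reaching for a pre-wired `sdk.models.gpt350613` attribute, the step constructs the exact checkpoint it needs and hands it to the SDK to start before streaming. The helper function and its `sdk`/`functions` parameters are illustrative; the `OpenAI` class, `sdk.start_model`, `sdk.get_chat_context`, and the chunk shape are taken from the diff:

```python
from continuedev.libs.llm.openai import OpenAI


async def stream_function_calls(sdk, functions):
    # Function calling needs a specific checkpoint, so build it on demand...
    gpt350613 = OpenAI(model="gpt-3.5-turbo-0613")
    # ...and let the SDK initialize it before use.
    await sdk.start_model(gpt350613)

    content = ""
    async for chunk in gpt350613.stream_chat(await sdk.get_chat_context(),
                                             functions=functions):
        if "content" in chunk:
            content += chunk["content"]
    return content
```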
diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py
index dbe8363e..658cc7f3 100644
--- a/continuedev/src/continuedev/plugins/steps/chroma.py
+++ b/continuedev/src/continuedev/plugins/steps/chroma.py
@@ -56,7 +56,7 @@ class AnswerQuestionChroma(Step):
Here is the answer:""")
- answer = await sdk.models.gpt35.complete(prompt)
+ answer = await sdk.models.medium.complete(prompt)
# Make paths relative to the workspace directory
answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "")
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index de7cf3ac..4476c7ae 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -11,11 +11,12 @@ from pydantic import validator
from ....libs.llm.ggml import GGML
from ....models.main import Range
+from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit
from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
-from ....core.observation import Observation, TextObservation, UserInputObservation
-from ....core.main import ChatMessage, ContinueCustomException, Step
-from ....libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS
+from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
+from ....core.main import ChatMessage, ContinueCustomException, Step, SequentialStep
+from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
from ....libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes
@@ -97,7 +98,7 @@ class ShellCommandsStep(Step):
return f"Error when running shell commands:\n```\n{self._err_text}\n```"
cmds_str = "\n".join(self.cmds)
- return await models.gpt35.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:")
+ return await models.medium.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:")
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
cwd = await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd
@@ -105,7 +106,7 @@ class ShellCommandsStep(Step):
for cmd in self.cmds:
output = await sdk.ide.runCommand(cmd)
if self.handle_error and output is not None and output_contains_error(output):
- suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.medium.complete(dedent(f"""\
While running the command `{cmd}`, the following error occurred:
```ascii
@@ -185,7 +186,7 @@ class DefaultModelEditCodeStep(Step):
else:
changes = '\n'.join(difflib.ndiff(
self._previous_contents.splitlines(), self._new_contents.splitlines()))
- description = await models.gpt3516k.complete(dedent(f"""\
+ description = await models.medium.complete(dedent(f"""\
Diff summary: "{self.user_input}"
```diff
@@ -193,7 +194,7 @@ class DefaultModelEditCodeStep(Step):
```
Please give a brief description of the changes made above using markdown bullet points. Be concise:"""))

- name = await models.gpt3516k.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:")
+ name = await models.medium.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:")
self.name = remove_quotes_and_escapes(name)
return f"{remove_quotes_and_escapes(description)}"
@@ -203,8 +204,7 @@ class DefaultModelEditCodeStep(Step):
# We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
# Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
model_to_use = sdk.models.default
- max_tokens = int(MAX_TOKENS_FOR_MODEL.get(
- model_to_use.name, DEFAULT_MAX_TOKENS) / 2)
+ max_tokens = int(model_to_use.context_length / 2)
TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE:
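The hunk above drops the `MAX_TOKENS_FOR_MODEL` lookup: each LLM object now carries its own `context_length`, so the budget calculation works for any configured model. A small stand-in illustrating the attribute-based budget; `FakeLLM` is a mock and its token estimate is deliberately crude:

```python
from dataclasses import dataclass


@dataclass
class FakeLLM:
    # Mock exposing only the two members the step relies on.
    name: str
    context_length: int

    def count_tokens(self, text: str) -> int:
        # Rough estimate for illustration; the real classes count actual tokens.
        return max(1, len(text) // 4)


model_to_use = FakeLLM(name="gpt-3.5-turbo", context_length=4096)
# Split the context window in half, mirroring the `/ 2` in the hunk above.
max_tokens = int(model_to_use.context_length / 2)
print(max_tokens)  # 2048
```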
@@ -222,8 +222,9 @@ class DefaultModelEditCodeStep(Step):
# If using 3.5 and overflows, upgrade to 3.5.16k
if model_to_use.name == "gpt-3.5-turbo":
- if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
- model_to_use = sdk.models.gpt3516k
+ if total_tokens > model_to_use.context_length:
+ model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613")
+ await sdk.start_model(model_to_use)
# Remove tokens from the end first, and then the start to clear space
# This part finds the start and end lines
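A condensed version of the overflow handling just above, using the `MaybeProxyOpenAI` class imported earlier in this file's diff; the standalone helper and its parameters are illustrative, not part of the commit:

```python
from continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI


async def upgrade_if_overflowing(sdk, model_to_use, total_tokens: int):
    # If the prompt no longer fits gpt-3.5-turbo's window, swap in the larger
    # checkpoint and have the SDK start it, as the hunk above does.
    if (model_to_use.name == "gpt-3.5-turbo"
            and total_tokens > model_to_use.context_length):
        model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613")
        await sdk.start_model(model_to_use)
    return model_to_use
```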
@@ -233,20 +234,20 @@ class DefaultModelEditCodeStep(Step):
cur_start_line = 0
cur_end_line = len(full_file_contents_lst) - 1
- if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens > model_to_use.context_length:
while cur_end_line > min_end_line:
total_tokens -= model_to_use.count_tokens(
full_file_contents_lst[cur_end_line])
cur_end_line -= 1
- if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens < model_to_use.context_length:
break
- if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens > model_to_use.context_length:
while cur_start_line < max_start_line:
cur_start_line += 1
total_tokens -= model_to_use.count_tokens(
full_file_contents_lst[cur_start_line])
- if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens < model_to_use.context_length:
break
# Now use the found start/end lines to get the prefix and suffix strings
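The pruning loops above now compare against `model_to_use.context_length` directly. A simplified, self-contained rendering of the same idea; the early-exit structure of the original is compressed into the loop conditions:

```python
def prune_to_window(lines, model, total_tokens, min_end_line, max_start_line):
    """Drop lines from the end, then from the start, until the estimated
    token count fits inside the model's context window."""
    cur_start_line, cur_end_line = 0, len(lines) - 1

    # Trim from the end of the file first...
    while total_tokens > model.context_length and cur_end_line > min_end_line:
        total_tokens -= model.count_tokens(lines[cur_end_line])
        cur_end_line -= 1

    # ...then from the start if more space is still needed.
    while total_tokens > model.context_length and cur_start_line < max_start_line:
        cur_start_line += 1
        total_tokens -= model.count_tokens(lines[cur_start_line])

    return cur_start_line, cur_end_line
```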
diff --git a/continuedev/src/continuedev/plugins/steps/draft/migration.py b/continuedev/src/continuedev/plugins/steps/draft/migration.py
index a76d491b..c38f54dc 100644
--- a/continuedev/src/continuedev/plugins/steps/draft/migration.py
+++ b/continuedev/src/continuedev/plugins/steps/draft/migration.py
@@ -13,7 +13,7 @@ class MigrationStep(Step):
recent_edits = await sdk.ide.get_recent_edits(self.edited_file)
recent_edits_string = "\n\n".join(
map(lambda x: x.to_string(), recent_edits))
- description = await sdk.models.gpt35.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
+ description = await sdk.models.medium.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
await sdk.run([
"cd libs",
"poetry run alembic revision --autogenerate -m " + description,
diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py
index 6997a547..ec670999 100644
--- a/continuedev/src/continuedev/plugins/steps/help.py
+++ b/continuedev/src/continuedev/plugins/steps/help.py
@@ -56,7 +56,7 @@ class HelpStep(Step):
summary="Help"
))
messages = await sdk.get_chat_context()
- generator = sdk.models.gpt4.stream_chat(messages)
+ generator = sdk.models.default.stream_chat(messages)
async for chunk in generator:
if "content" in chunk:
self.description += chunk["content"]
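HelpStep now streams from whatever the user configured as the default LLM object rather than a hard-coded `gpt4` attribute. A trimmed sketch of the streaming pattern; the wrapper function is illustrative, while `sdk.get_chat_context`, `sdk.models.default.stream_chat`, and the chunk shape come from the hunks in this diff:

```python
async def stream_default_answer(sdk) -> str:
    messages = await sdk.get_chat_context()
    description = ""
    # Stream from the user-configured default model instead of a fixed gpt4 slot.
    async for chunk in sdk.models.default.stream_chat(messages):
        if "content" in chunk:
            description += chunk["content"]
    return description
```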
diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
index b54d394a..3d8d96fb 100644
--- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
+++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
@@ -23,6 +23,6 @@ class NLMultiselectStep(Step):
if first_try is not None:
return first_try
- gpt_parsed = await sdk.models.gpt35.complete(
+ gpt_parsed = await sdk.models.default.complete(
f"These are the available options: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:")
return extract_option(gpt_parsed) or self.options[0]
diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py
index f28d9660..d2d6f4dd 100644
--- a/continuedev/src/continuedev/plugins/steps/main.py
+++ b/continuedev/src/continuedev/plugins/steps/main.py
@@ -101,7 +101,7 @@ class FasterEditHighlightedCodeStep(Step):
for rif in range_in_files:
rif_dict[rif.filepath] = rif.contents
- completion = await sdk.models.gpt35.complete(prompt)
+ completion = await sdk.models.medium.complete(prompt)
# Temporarily doing this to generate description.
self._prompt = prompt
@@ -169,7 +169,7 @@ class StarCoderEditHighlightedCodeStep(Step):
_prompt_and_completion: str = ""
async def describe(self, models: Models) -> Coroutine[str, None, None]:
- return await models.gpt35.complete(f"{self._prompt_and_completion}\n\nPlease give a brief description of the changes made above using markdown bullet points:")
+ return await models.medium.complete(f"{self._prompt_and_completion}\n\nPlease give a brief description of the changes made above using markdown bullet points:")
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
range_in_files = await sdk.get_code_context(only_editing=True)
diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py
index 8b2e7c2e..da6acdbf 100644
--- a/continuedev/src/continuedev/plugins/steps/react.py
+++ b/continuedev/src/continuedev/plugins/steps/react.py
@@ -27,7 +27,7 @@ class NLDecisionStep(Step):
Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""")
- resp = (await sdk.models.gpt35.complete(prompt)).lower()
+ resp = (await sdk.models.medium.complete(prompt)).lower()
step_to_run = None
for step in self.steps:
diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py
index 522a84a3..456dba84 100644
--- a/continuedev/src/continuedev/plugins/steps/search_directory.py
+++ b/continuedev/src/continuedev/plugins/steps/search_directory.py
@@ -42,7 +42,7 @@ class WriteRegexPatternStep(Step):
async def run(self, sdk: ContinueSDK):
# Ask the user for a regex pattern
- pattern = await sdk.models.gpt35.complete(dedent(f"""\
+ pattern = await sdk.models.medium.complete(dedent(f"""\
This is the user request:
{self.user_request}