from functools import cached_property
import time
from typing import (
    Any,
    AsyncGenerator,
    Coroutine,
    Dict,
    Generator,
    List,
    Optional,
    Union,
)

from ...core.main import ChatMessage
from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic
from ..llm import LLM
from ..util.count_tokens import (
    DEFAULT_MAX_TOKENS,
    compile_chat_messages,
    CHAT_MODELS,
    DEFAULT_ARGS,
    count_tokens,
    prune_raw_prompt_from_top,
)


class AnthropicLLM(LLM):
    """LLM implementation that talks to Anthropic's text-completions endpoint.

    Chat-style message lists are flattened into a single prompt string made of
    alternating ``HUMAN_PROMPT`` / ``AI_PROMPT`` turns before being sent through
    an ``AsyncAnthropic`` client.
    """

    api_key: str
    default_model: str
    async_client: AsyncAnthropic

    def __init__(
        self,
        api_key: str,
        default_model: str,
        system_message: Optional[str] = None,
    ):
        """Store the configuration and create the async Anthropic client.

        Args:
            api_key: Key handed to ``AsyncAnthropic``.
            default_model: Model name used for requests and token counting.
            system_message: Optional system message forwarded to
                ``compile_chat_messages`` by ``stream_chat`` and ``complete``.
        """
        self.api_key = api_key
        self.default_model = default_model
        self.system_message = system_message
        self.async_client = AsyncAnthropic(api_key=api_key)

    @cached_property
    def name(self):
        """Return the model name (cached on first access)."""
        return self.default_model

    @property
    def default_args(self):
        """Return a fresh dict of ``DEFAULT_ARGS`` plus this instance's model."""
        return {**DEFAULT_ARGS, "model": self.default_model}

    def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``args`` adapted for the Anthropic client.

        ``max_tokens`` is renamed to ``max_tokens_to_sample`` and the
        ``frequency_penalty`` / ``presence_penalty`` keys are dropped. The
        input dict is not mutated.
        """
        args = args.copy()
        if "max_tokens" in args:
            args["max_tokens_to_sample"] = args.pop("max_tokens")
        args.pop("frequency_penalty", None)
        args.pop("presence_penalty", None)
        return args

    def _build_request_args(
        self, overrides: Dict[str, Any], stream: bool = False
    ) -> Dict[str, Any]:
        """Merge defaults with per-call overrides and adapt them for Anthropic.

        When ``stream`` is true, ``"stream": True`` is set after the overrides
        are applied, so a caller-supplied ``stream`` value cannot disable it.
        """
        args = {**self.default_args, **overrides}
        if stream:
            args["stream"] = True
        return self._transform_args(args)

    def count_tokens(self, text: str):
        """Count the tokens in ``text`` using the default model."""
        return count_tokens(self.default_model, text)

    def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Flatten role/content message dicts into a single prompt string.

        ``user`` and ``system`` messages become Human turns; every other role
        becomes an Assistant turn. The result always ends with ``AI_PROMPT``
        so that the model produces the next Assistant turn.

        Args:
            messages: Dicts with at least ``"role"`` and ``"content"`` keys.

        Returns:
            The assembled prompt string.
        """
        human_roles = ("user", "system")
        parts = []

        # Anthropic prompt must start with a Human turn. This also has to hold
        # when there are no messages at all; otherwise the prompt would consist
        # of nothing but the trailing AI_PROMPT.
        if not messages or messages[0]["role"] not in human_roles:
            parts.append(f"{HUMAN_PROMPT} Hello.")

        for msg in messages:
            speaker = HUMAN_PROMPT if msg["role"] in human_roles else AI_PROMPT
            parts.append(f"{speaker} {msg['content']} ")

        parts.append(AI_PROMPT)
        return "".join(parts)

    async def stream_complete(
        self,
        prompt,
        with_history: Optional[List[ChatMessage]] = None,
        **kwargs,
    ) -> AsyncGenerator[Union[Any, List, Dict], None]:
        """Stream completion text for a single prompt.

        Args:
            prompt: Text placed in a single Human turn.
            with_history: Accepted for interface compatibility; not used here.
                NOTE(review): ``complete`` does use the history and the system
                message - confirm whether ignoring them here is intentional.
            **kwargs: Overrides applied on top of ``default_args``.

        Yields:
            The ``completion`` attribute of each streamed chunk.
        """
        args = self._build_request_args(kwargs, stream=True)

        async for chunk in await self.async_client.completions.create(
            prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
            **args,
        ):
            yield chunk.completion

    async def stream_chat(
        self,
        messages: Optional[List[ChatMessage]] = None,
        **kwargs,
    ) -> AsyncGenerator[Union[Any, List, Dict], None]:
        """Stream an assistant reply for a chat history.

        Args:
            messages: Chat history passed to ``compile_chat_messages``.
            **kwargs: Overrides applied on top of ``default_args``.

        Yields:
            ``{"role": "assistant", "content": <chunk text>}`` for every
            streamed chunk.
        """
        args = self._build_request_args(kwargs, stream=True)

        # NOTE(review): if a caller supplies "functions" it is also forwarded
        # to the Anthropic client through **args below - confirm the client
        # tolerates that key.
        messages = compile_chat_messages(
            args["model"],
            messages,
            args["max_tokens_to_sample"],
            functions=args.get("functions", None),
            system_message=self.system_message,
        )

        async for chunk in await self.async_client.completions.create(
            prompt=self.__messages_to_prompt(messages),
            **args,
        ):
            yield {"role": "assistant", "content": chunk.completion}

    async def complete(
        self,
        prompt: str,
        with_history: Optional[List[ChatMessage]] = None,
        **kwargs,
    ) -> str:
        """Return the full completion for ``prompt`` in one (non-streamed) call.

        Args:
            prompt: The new prompt, passed to ``compile_chat_messages``
                together with ``with_history`` and the system message.
            with_history: Optional prior chat messages.
            **kwargs: Overrides applied on top of ``default_args``.

        Returns:
            The ``completion`` attribute of the client's response.
        """
        args = self._build_request_args(kwargs)

        messages = compile_chat_messages(
            args["model"],
            with_history,
            args["max_tokens_to_sample"],
            prompt,
            functions=None,
            system_message=self.system_message,
        )
        response = await self.async_client.completions.create(
            prompt=self.__messages_to_prompt(messages),
            **args,
        )
        return response.completion