| | | |
|---|---|---|
| author | Nate Sesti <sestinj@gmail.com> | 2023-07-17 14:54:36 -0700 |
| committer | Nate Sesti <sestinj@gmail.com> | 2023-07-17 14:54:36 -0700 | 
| commit | 1c9034cddeab0c131babe741e9145cc276bd7521 (patch) | |
| tree | abf8a563f042335caa5df94dcd951e57964d4d4c /continuedev/src | |
| parent | 05d665e65aaef62254a4da9a7a381f9984ff0db5 (diff) | |
| download | sncontinue-1c9034cddeab0c131babe741e9145cc276bd7521.tar.gz sncontinue-1c9034cddeab0c131babe741e9145cc276bd7521.tar.bz2 sncontinue-1c9034cddeab0c131babe741e9145cc276bd7521.zip | |
anthropic support
Diffstat (limited to 'continuedev/src')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 6 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/anthropic.py | 50 |
| -rw-r--r-- | continuedev/src/continuedev/libs/util/count_tokens.py | 4 |
3 files changed, 39 insertions, 21 deletions
```diff
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index d3501f08..280fefa8 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -11,7 +11,7 @@ from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFi
 from ..models.filesystem import RangeInFile
 from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
 from ..libs.llm.openai import OpenAI
-from ..libs.llm.anthropic import Anthropic
+from ..libs.llm.anthropic import AnthropicLLM
 from ..libs.llm.ggml import GGML
 from .observation import Observation
 from ..server.ide_protocol import AbstractIdeProtocolServer
@@ -66,9 +66,9 @@ class Models:
         api_key = self.provider_keys["hf_inference_api"]
         return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)
 
-    def __load_anthropic_model(self, model: str) -> Anthropic:
+    def __load_anthropic_model(self, model: str) -> AnthropicLLM:
         api_key = self.provider_keys["anthropic"]
-        return Anthropic(api_key=api_key, model=model)
+        return AnthropicLLM(api_key, model, self.system_message)
 
     @cached_property
     def claude2(self):
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 2b8831f0..566f7150 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -3,7 +3,7 @@ from functools import cached_property
 import time
 from typing import Any, Coroutine, Dict, Generator, List, Union
 from ...core.main import ChatMessage
-from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
+from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic
 from ..llm import LLM
 from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
 
@@ -11,14 +11,14 @@ from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_
 class AnthropicLLM(LLM):
     api_key: str
     default_model: str
-    anthropic: Anthropic
+    async_client: AsyncAnthropic
 
     def __init__(self, api_key: str, default_model: str, system_message: str = None):
         self.api_key = api_key
         self.default_model = default_model
         self.system_message = system_message
 
-        self.anthropic = Anthropic(api_key)
+        self.async_client = AsyncAnthropic(api_key=api_key)
 
     @cached_property
     def name(self):
@@ -28,24 +28,39 @@ class AnthropicLLM(LLM):
     def default_args(self):
         return {**DEFAULT_ARGS, "model": self.default_model}
 
+    def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
+        args = args.copy()
+        if "max_tokens" in args:
+            args["max_tokens_to_sample"] = args["max_tokens"]
+            del args["max_tokens"]
+        if "frequency_penalty" in args:
+            del args["frequency_penalty"]
+        if "presence_penalty" in args:
+            del args["presence_penalty"]
+        return args
+
     def count_tokens(self, text: str):
         return count_tokens(self.default_model, text)
 
-    def __messages_to_prompt(self, messages: List[ChatMessage]) -> str:
+    def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
         prompt = ""
+
+        # Anthropic prompt must start with a Human turn
+        if len(messages) > 0 and messages[0]["role"] != "user" and messages[0]["role"] != "system":
+            prompt += f"{HUMAN_PROMPT} Hello."
         for msg in messages:
-            prompt += f"{HUMAN_PROMPT if msg.role == 'user' else AI_PROMPT} {msg.content} "
+            prompt += f"{HUMAN_PROMPT if (msg['role'] == 'user' or msg['role'] == 'system') else AI_PROMPT} {msg['content']} "
 
+        prompt += AI_PROMPT
         return prompt
 
     async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
+        args = self._transform_args(args)
 
-        async for chunk in await self.anthropic.completions.create(
-            model=args["model"],
-            max_tokens_to_sample=args["max_tokens"],
+        async for chunk in await self.async_client.completions.create(
             prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
             **args
         ):
@@ -55,25 +70,26 @@ class AnthropicLLM(LLM):
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
+        args = self._transform_args(args)
 
         messages = compile_chat_messages(
-            args["model"], messages, args["max_tokens"], functions=args.get("functions", None))
-        async for chunk in await self.anthropic.completions.create(
-            model=args["model"],
-            max_tokens_to_sample=args["max_tokens"],
+            args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None))
+        async for chunk in await self.async_client.completions.create(
             prompt=self.__messages_to_prompt(messages),
             **args
         ):
-            yield chunk.completion
+            yield {
+                "role": "assistant",
+                "content": chunk.completion
+            }
 
     async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
+        args = self._transform_args(args)
 
         messages = compile_chat_messages(
-            args["model"], with_history, args["max_tokens"], prompt, functions=None)
-        resp = (await self.anthropic.completions.create(
-            model=args["model"],
-            max_tokens_to_sample=args["max_tokens"],
+            args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None)
+        resp = (await self.async_client.completions.create(
             prompt=self.__messages_to_prompt(messages),
             **args
         )).completion
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 1ca98fe6..1d5d6729 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -6,6 +6,7 @@ import tiktoken
 
 aliases = {
     "ggml": "gpt-3.5-turbo",
+    "claude-2": "gpt-3.5-turbo",
 }
 DEFAULT_MAX_TOKENS = 2048
 MAX_TOKENS_FOR_MODEL = {
@@ -13,7 +14,8 @@ MAX_TOKENS_FOR_MODEL = {
     "gpt-3.5-turbo-0613": 4096,
     "gpt-3.5-turbo-16k": 16384,
     "gpt-4": 8192,
-    "ggml": 2048
+    "ggml": 2048,
+    "claude-2": 100000
 }
 CHAT_MODELS = {
     "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
```
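For readers skimming the diff: the new `AnthropicLLM` does two conversions before calling the Claude completions API. It renames the OpenAI-style `max_tokens` argument to Anthropic's `max_tokens_to_sample` (dropping the unsupported penalty arguments), and it flattens OpenAI-format chat messages into a single `HUMAN_PROMPT`/`AI_PROMPT` prompt that ends with an open Assistant turn. The sketch below is a standalone illustration of those two helpers, not part of the commit; the prompt constants normally come from the `anthropic` package, and in the real module the helpers live as methods on `AnthropicLLM`.

```python
# Minimal, standalone sketch (not part of the commit) of the two conversions
# AnthropicLLM now performs. HUMAN_PROMPT / AI_PROMPT mirror the constants the
# `anthropic` package exports; the helper names and example messages are
# illustrative only.
from typing import Any, Dict, List

HUMAN_PROMPT = "\n\nHuman:"    # anthropic.HUMAN_PROMPT
AI_PROMPT = "\n\nAssistant:"   # anthropic.AI_PROMPT


def transform_args(args: Dict[str, Any]) -> Dict[str, Any]:
    """Rename/drop OpenAI-style sampling args for the Claude completions API."""
    args = args.copy()
    if "max_tokens" in args:
        args["max_tokens_to_sample"] = args.pop("max_tokens")
    args.pop("frequency_penalty", None)  # not accepted by the Claude API
    args.pop("presence_penalty", None)
    return args


def messages_to_prompt(messages: List[Dict[str, str]]) -> str:
    """Flatten chat messages into Human/Assistant turns, ending with an open Assistant turn."""
    prompt = ""
    # The prompt must open with a Human turn
    if messages and messages[0]["role"] not in ("user", "system"):
        prompt += f"{HUMAN_PROMPT} Hello."
    for msg in messages:
        turn = HUMAN_PROMPT if msg["role"] in ("user", "system") else AI_PROMPT
        prompt += f"{turn} {msg['content']} "
    prompt += AI_PROMPT  # leave the final Assistant turn for the model to complete
    return prompt


if __name__ == "__main__":
    print(transform_args({"model": "claude-2", "max_tokens": 2048, "presence_penalty": 0.0}))
    # {'model': 'claude-2', 'max_tokens_to_sample': 2048}
    print(messages_to_prompt([
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Summarize this diff."},
    ]))
    # "\n\nHuman: You are a concise assistant. \n\nHuman: Summarize this diff. \n\nAssistant:"
```

Mapping both `user` and `system` messages onto the Human turn is a pragmatic choice here, since the prompt-based Claude completions API has no separate system-message slot. The switch from the synchronous `Anthropic` client to `AsyncAnthropic` is what lets `stream_complete` and `stream_chat` consume the response with `async for` when `stream=True` is passed.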
