Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/core/config.py         |  2 +-
-rw-r--r--  continuedev/src/continuedev/core/sdk.py            | 16 +++++++-
-rw-r--r--  continuedev/src/continuedev/libs/llm/anthropic.py  | 81 ++++++++++++
3 files changed, 97 insertions, 2 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 98615c64..6af0878d 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -81,7 +81,7 @@ class ContinueConfig(BaseModel):
     disallowed_steps: Optional[List[str]] = []
     allow_anonymous_telemetry: Optional[bool] = True
     default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
-                           "gpt-4", "ggml"] = 'gpt-4'
+                           "gpt-4", "claude-2", "ggml"] = 'gpt-4'
     custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
         name="test",
         description="This is an example custom command. Use /config to edit it and create more",
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 7e612d3b..d3501f08 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -11,6 +11,7 @@ from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFi
 from ..models.filesystem import RangeInFile
 from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
 from ..libs.llm.openai import OpenAI
+from ..libs.llm.anthropic import AnthropicLLM
 from ..libs.llm.ggml import GGML
 from .observation import Observation
 from ..server.ide_protocol import AbstractIdeProtocolServer
@@ -27,7 +28,7 @@ ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"]
 MODEL_PROVIDER_TO_ENV_VAR = {
     "openai": "OPENAI_API_KEY",
     "hf_inference_api": "HUGGING_FACE_TOKEN",
-    "anthropic": "ANTHROPIC_API_KEY"
+    "anthropic": "ANTHROPIC_API_KEY",
 }
@@ -43,6 +44,9 @@ class Models:
     @classmethod
     async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
+        if sdk.config.default_model == "claude-2":
+            with_providers = with_providers + ["anthropic"]
+
         models = Models(sdk, with_providers)
         for provider in with_providers:
             if provider in MODEL_PROVIDER_TO_ENV_VAR:
@@ -62,6 +66,14 @@ class Models:
         api_key = self.provider_keys["hf_inference_api"]
         return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)

+    def __load_anthropic_model(self, model: str) -> AnthropicLLM:
+        api_key = self.provider_keys["anthropic"]
+        return AnthropicLLM(api_key=api_key, default_model=model)
+
+    @cached_property
+    def claude2(self):
+        return self.__load_anthropic_model("claude-2")
+
     @cached_property
     def starcoder(self):
         return self.__load_hf_inference_api_model("bigcode/starcoder")
@@ -95,6 +107,8 @@ class Models:
             return self.gpt3516k
         elif model_name == "gpt-4":
             return self.gpt4
+        elif model_name == "claude-2":
+            return self.claude2
         elif model_name == "ggml":
             return self.ggml
         else:
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
new file mode 100644
index 00000000..2b8831f0
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -0,0 +1,81 @@
+from functools import cached_property
+from typing import Any, Coroutine, Dict, Generator, List, Union
+
+from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic
+
+from ...core.main import ChatMessage
+from ..llm import LLM
+from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
+
+
+class AnthropicLLM(LLM):
+    api_key: str
+    default_model: str
+    anthropic: AsyncAnthropic
+
+    def __init__(self, api_key: str, default_model: str, system_message: str = None):
+        self.api_key = api_key
+        self.default_model = default_model
+        self.system_message = system_message
+
+        # Use the async client so completions can be awaited and streamed with `async for`.
+        self.anthropic = AsyncAnthropic(api_key=api_key)
+
+    @cached_property
+    def name(self):
+        return self.default_model
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.default_model}
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.default_model, text)
+
+    def __messages_to_prompt(self, messages: List[ChatMessage]) -> str:
+        prompt = ""
+        for msg in messages:
+            prompt += f"{HUMAN_PROMPT if msg.role == 'user' else AI_PROMPT} {msg.content}"
+
+        # The completions API requires the prompt to end with an Assistant turn.
+        return prompt + AI_PROMPT
+
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = self.default_args.copy()
+        args.update(kwargs)
+
+        # Pass arguments explicitly: Anthropic expects max_tokens_to_sample, not max_tokens.
+        async for chunk in await self.anthropic.completions.create(
+            model=args["model"],
+            max_tokens_to_sample=args["max_tokens"],
+            prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
+            stream=True,
+        ):
+            yield chunk.completion
+
+    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = self.default_args.copy()
+        args.update(kwargs)
+
+        messages = compile_chat_messages(
+            args["model"], messages, args["max_tokens"], functions=args.get("functions", None))
+        async for chunk in await self.anthropic.completions.create(
+            model=args["model"],
+            max_tokens_to_sample=args["max_tokens"],
+            prompt=self.__messages_to_prompt(messages),
+            stream=True,
+        ):
+            yield chunk.completion
+
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+        args = {**self.default_args, **kwargs}
+
+        messages = compile_chat_messages(
+            args["model"], with_history, args["max_tokens"], prompt, functions=None)
+        resp = (await self.anthropic.completions.create(
+            model=args["model"],
+            max_tokens_to_sample=args["max_tokens"],
+            prompt=self.__messages_to_prompt(messages),
+        )).completion
+
+        return resp
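As a standalone sanity check of the API usage this commit relies on (not part of the diff): a minimal sketch of streaming a claude-2 completion with the anthropic Python SDK's async client, assuming a 0.3-era SDK that exposes AsyncAnthropic, HUMAN_PROMPT, and AI_PROMPT. The api_key value and prompt text are placeholders; if no key is passed, the client falls back to the ANTHROPIC_API_KEY environment variable that MODEL_PROVIDER_TO_ENV_VAR maps to.

import asyncio

from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic


async def main():
    # "sk-ant-..." is a placeholder; omit api_key to use ANTHROPIC_API_KEY.
    client = AsyncAnthropic(api_key="sk-ant-...")

    # Prompts must start with HUMAN_PROMPT ("\n\nHuman:") and end with
    # AI_PROMPT ("\n\nAssistant:"), the same convention AnthropicLLM follows.
    stream = await client.completions.create(
        model="claude-2",
        max_tokens_to_sample=256,
        prompt=f"{HUMAN_PROMPT} Say hello.{AI_PROMPT}",
        stream=True,
    )
    async for chunk in stream:
        print(chunk.completion, end="", flush=True)


asyncio.run(main())

This mirrors the call pattern in stream_complete above: create(..., stream=True) is awaited once to obtain the stream, which is then consumed with async for, yielding incremental chunk.completion text.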
