From 062b0872797fb4734ed36ea3a14f653dc685a86a Mon Sep 17 00:00:00 2001
From: Nate Sesti
Date: Sun, 16 Jul 2023 00:21:56 -0700
Subject: Anthropic support

---
 continuedev/src/continuedev/core/config.py        |  2 +-
 continuedev/src/continuedev/core/sdk.py           | 16 ++++-
 continuedev/src/continuedev/libs/llm/anthropic.py | 81 +++++++++++++++++++++++
 continuedev/src/continuedev/steps/chat.py         |  2 +-
 4 files changed, 98 insertions(+), 3 deletions(-)
 create mode 100644 continuedev/src/continuedev/libs/llm/anthropic.py

diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 6e430c04..05ba48c6 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -76,7 +76,7 @@ class ContinueConfig(BaseModel):
     server_url: Optional[str] = None
     allow_anonymous_telemetry: Optional[bool] = True
     default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
-                           "gpt-4"] = 'gpt-4'
+                           "gpt-4", "claude-2"] = 'gpt-4'
     custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
         name="test",
         description="This is an example custom command. Use /config to edit it and create more",
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index d73561d2..28487600 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -11,6 +11,7 @@ from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFi
 from ..models.filesystem import RangeInFile
 from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
 from ..libs.llm.openai import OpenAI
+from ..libs.llm.anthropic import Anthropic
 from .observation import Observation
 from ..server.ide_protocol import AbstractIdeProtocolServer
 from .main import Context, ContinueCustomException, History, Step, ChatMessage
@@ -26,7 +27,7 @@ ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"]
 MODEL_PROVIDER_TO_ENV_VAR = {
     "openai": "OPENAI_API_KEY",
     "hf_inference_api": "HUGGING_FACE_TOKEN",
-    "anthropic": "ANTHROPIC_API_KEY"
+    "anthropic": "ANTHROPIC_API_KEY",
 }


@@ -40,6 +41,9 @@ class Models:

     @classmethod
     async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
+        if sdk.config.default_model == "claude-2":
+            with_providers.append("anthropic")
+
         models = Models(sdk, with_providers)
         for provider in with_providers:
             if provider in MODEL_PROVIDER_TO_ENV_VAR:
@@ -59,6 +63,14 @@ class Models:
         api_key = self.provider_keys["hf_inference_api"]
         return HuggingFaceInferenceAPI(api_key=api_key, model=model)

+    def __load_anthropic_model(self, model: str) -> Anthropic:
+        api_key = self.provider_keys["anthropic"]
+        return Anthropic(api_key=api_key, model=model)
+
+    @cached_property
+    def claude2(self):
+        return self.__load_anthropic_model("claude-2")
+
     @cached_property
     def starcoder(self):
         return self.__load_hf_inference_api_model("bigcode/starcoder")
@@ -88,6 +100,8 @@ class Models:
             return self.gpt3516k
         elif model_name == "gpt-4":
             return self.gpt4
+        elif model_name == "claude-2":
+            return self.claude2
         else:
             raise Exception(f"Unknown model {model_name}")
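Taken together, the config.py and sdk.py changes above make "claude-2" a selectable default model: when it is chosen, Models.create() appends the "anthropic" provider, whose key is resolved through MODEL_PROVIDER_TO_ENV_VAR to the ANTHROPIC_API_KEY environment variable. A rough sketch of the user-facing side, assuming the usual Python config entry point (the import path and everything other than default_model are illustrative, not part of this diff):

    # ~/.continue/config.py (illustrative)
    from continuedev.core.config import ContinueConfig

    config = ContinueConfig(
        # With this commit applied, "claude-2" passes the Literal check above and
        # Models.create() loads the Anthropic provider using ANTHROPIC_API_KEY.
        default_model="claude-2",
    )
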
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
new file mode 100644
index 00000000..2b8831f0
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -0,0 +1,81 @@
+
+from functools import cached_property
+import time
+from typing import Any, Coroutine, Dict, Generator, List, Union
+from ...core.main import ChatMessage
+from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
+from ..llm import LLM
+from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+
+
+class AnthropicLLM(LLM):
+    api_key: str
+    default_model: str
+    anthropic: Anthropic
+
+    def __init__(self, api_key: str, default_model: str, system_message: str = None):
+        self.api_key = api_key
+        self.default_model = default_model
+        self.system_message = system_message
+
+        self.anthropic = Anthropic(api_key)
+
+    @cached_property
+    def name(self):
+        return self.default_model
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.default_model}
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.default_model, text)
+
+    def __messages_to_prompt(self, messages: List[ChatMessage]) -> str:
+        prompt = ""
+        for msg in messages:
+            prompt += f"{HUMAN_PROMPT if msg.role == 'user' else AI_PROMPT} {msg.content} "
+
+        return prompt
+
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = self.default_args.copy()
+        args.update(kwargs)
+        args["stream"] = True
+
+        async for chunk in await self.anthropic.completions.create(
+            model=args["model"],
+            max_tokens_to_sample=args["max_tokens"],
+            prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
+            **args
+        ):
+            yield chunk.completion
+
+    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = self.default_args.copy()
+        args.update(kwargs)
+        args["stream"] = True
+
+        messages = compile_chat_messages(
+            args["model"], messages, args["max_tokens"], functions=args.get("functions", None))
+        async for chunk in await self.anthropic.completions.create(
+            model=args["model"],
+            max_tokens_to_sample=args["max_tokens"],
+            prompt=self.__messages_to_prompt(messages),
+            **args
+        ):
+            yield chunk.completion
+
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+        args = {**self.default_args, **kwargs}
+
+        messages = compile_chat_messages(
+            args["model"], with_history, args["max_tokens"], prompt, functions=None)
+        resp = (await self.anthropic.completions.create(
+            model=args["model"],
+            max_tokens_to_sample=args["max_tokens"],
+            prompt=self.__messages_to_prompt(messages),
+            **args
+        )).completion
+
+        return resp
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index 14a1cd41..3751dec2 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -28,7 +28,7 @@ class SimpleChatStep(Step):
         completion = ""
         messages = self.messages or await sdk.get_chat_context()

-        generator = sdk.models.gpt4.stream_chat(messages, temperature=0.5)
+        generator = sdk.models.default.stream_chat(messages, temperature=0.5)
         try:
             async for chunk in generator:
                 if sdk.current_step_was_deleted():
--
cgit v1.2.3-70-g09d2
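The prompt this first version sends to Claude is a single string built from HUMAN_PROMPT and AI_PROMPT, the "\n\nHuman:" / "\n\nAssistant:" markers exported by the anthropic package. A small sketch of the shape __messages_to_prompt produces (plain dicts stand in for ChatMessage here; note that this version neither forces the prompt to open with a Human turn nor appends a trailing AI_PROMPT, which the follow-up commit below addresses):

    from anthropic import HUMAN_PROMPT, AI_PROMPT  # "\n\nHuman:" and "\n\nAssistant:"

    messages = [
        {"role": "user", "content": "What does sdk.py do?"},
        {"role": "assistant", "content": "It wires models into the SDK."},
        {"role": "user", "content": "Add Claude support."},
    ]

    prompt = ""
    for msg in messages:
        # user turns become "\n\nHuman: ...", assistant turns "\n\nAssistant: ..."
        prompt += f"{HUMAN_PROMPT if msg['role'] == 'user' else AI_PROMPT} {msg['content']} "
    print(prompt)
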
From a0d5cf94c8bef5726dd23711ddbef58813f42fc2 Mon Sep 17 00:00:00 2001
From: Nate Sesti
Date: Mon, 17 Jul 2023 14:54:36 -0700
Subject: anthropic support

---
 continuedev/src/continuedev/core/sdk.py           |  6 +--
 continuedev/src/continuedev/libs/llm/anthropic.py | 50 ++++++++++++++--------
 .../src/continuedev/libs/util/count_tokens.py     |  4 +-
 .../react-app/src/components/StepContainer.tsx    |  1 -
 4 files changed, 39 insertions(+), 22 deletions(-)

diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index d3501f08..280fefa8 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -11,7 +11,7 @@ from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFi
 from ..models.filesystem import RangeInFile
 from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
 from ..libs.llm.openai import OpenAI
-from ..libs.llm.anthropic import Anthropic
+from ..libs.llm.anthropic import AnthropicLLM
 from ..libs.llm.ggml import GGML
 from .observation import Observation
 from ..server.ide_protocol import AbstractIdeProtocolServer
@@ -66,9 +66,9 @@ class Models:
         api_key = self.provider_keys["hf_inference_api"]
         return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)

-    def __load_anthropic_model(self, model: str) -> Anthropic:
+    def __load_anthropic_model(self, model: str) -> AnthropicLLM:
         api_key = self.provider_keys["anthropic"]
-        return Anthropic(api_key=api_key, model=model)
+        return AnthropicLLM(api_key, model, self.system_message)

     @cached_property
     def claude2(self):
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 2b8831f0..566f7150 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -3,7 +3,7 @@ from functools import cached_property
 import time
 from typing import Any, Coroutine, Dict, Generator, List, Union
 from ...core.main import ChatMessage
-from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
+from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic
 from ..llm import LLM
 from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top

@@ -11,14 +11,14 @@ from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_
 class AnthropicLLM(LLM):
     api_key: str
     default_model: str
-    anthropic: Anthropic
+    async_client: AsyncAnthropic

     def __init__(self, api_key: str, default_model: str, system_message: str = None):
         self.api_key = api_key
         self.default_model = default_model
         self.system_message = system_message

-        self.anthropic = Anthropic(api_key)
+        self.async_client = AsyncAnthropic(api_key=api_key)

     @cached_property
     def name(self):
@@ -28,24 +28,39 @@ class AnthropicLLM(LLM):
     def default_args(self):
         return {**DEFAULT_ARGS, "model": self.default_model}

+    def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
+        args = args.copy()
+        if "max_tokens" in args:
+            args["max_tokens_to_sample"] = args["max_tokens"]
+            del args["max_tokens"]
+        if "frequency_penalty" in args:
+            del args["frequency_penalty"]
+        if "presence_penalty" in args:
+            del args["presence_penalty"]
+        return args
+
     def count_tokens(self, text: str):
         return count_tokens(self.default_model, text)

-    def __messages_to_prompt(self, messages: List[ChatMessage]) -> str:
+    def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
         prompt = ""
+
+        # Anthropic prompt must start with a Human turn
+        if len(messages) > 0 and messages[0]["role"] != "user" and messages[0]["role"] != "system":
+            prompt += f"{HUMAN_PROMPT} Hello."
         for msg in messages:
-            prompt += f"{HUMAN_PROMPT if msg.role == 'user' else AI_PROMPT} {msg.content} "
+            prompt += f"{HUMAN_PROMPT if (msg['role'] == 'user' or msg['role'] == 'system') else AI_PROMPT} {msg['content']} "

+        prompt += AI_PROMPT
         return prompt

     async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
+        args = self._transform_args(args)

-        async for chunk in await self.anthropic.completions.create(
-            model=args["model"],
-            max_tokens_to_sample=args["max_tokens"],
+        async for chunk in await self.async_client.completions.create(
             prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
             **args
         ):
@@ -55,25 +70,26 @@ class AnthropicLLM(LLM):
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
+        args = self._transform_args(args)

         messages = compile_chat_messages(
-            args["model"], messages, args["max_tokens"], functions=args.get("functions", None))
-        async for chunk in await self.anthropic.completions.create(
-            model=args["model"],
-            max_tokens_to_sample=args["max_tokens"],
+            args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None))
+        async for chunk in await self.async_client.completions.create(
             prompt=self.__messages_to_prompt(messages),
             **args
         ):
-            yield chunk.completion
+            yield {
+                "role": "assistant",
+                "content": chunk.completion
+            }

     async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
+        args = self._transform_args(args)

         messages = compile_chat_messages(
-            args["model"], with_history, args["max_tokens"], prompt, functions=None)
-        resp = (await self.anthropic.completions.create(
-            model=args["model"],
-            max_tokens_to_sample=args["max_tokens"],
+            args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None)
+        resp = (await self.async_client.completions.create(
             prompt=self.__messages_to_prompt(messages),
             **args
         )).completion
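_transform_args exists because default_args is OpenAI-shaped: Anthropic's completions endpoint wants max_tokens_to_sample, and the frequency/presence penalty knobs have no Anthropic equivalent, so they are dropped before the remaining args are splatted into completions.create(). A rough before/after of that translation (the values are arbitrary examples; DEFAULT_ARGS itself lives in libs/util/count_tokens.py):

    openai_style = {
        "model": "claude-2",
        "max_tokens": 1024,
        "temperature": 0.5,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "stream": True,
    }

    anthropic_style = openai_style.copy()
    # max_tokens -> max_tokens_to_sample; the penalty keys are removed entirely
    anthropic_style["max_tokens_to_sample"] = anthropic_style.pop("max_tokens")
    anthropic_style.pop("frequency_penalty", None)
    anthropic_style.pop("presence_penalty", None)
    # {'model': 'claude-2', 'temperature': 0.5, 'stream': True, 'max_tokens_to_sample': 1024}
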
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 1ca98fe6..1d5d6729 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -6,6 +6,7 @@ import tiktoken

 aliases = {
     "ggml": "gpt-3.5-turbo",
+    "claude-2": "gpt-3.5-turbo",
 }
 DEFAULT_MAX_TOKENS = 2048
 MAX_TOKENS_FOR_MODEL = {
@@ -13,7 +14,8 @@
     "gpt-3.5-turbo-0613": 4096,
     "gpt-3.5-turbo-16k": 16384,
     "gpt-4": 8192,
-    "ggml": 2048
+    "ggml": 2048,
+    "claude-2": 100000
 }
 CHAT_MODELS = {
     "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
diff --git a/extension/react-app/src/components/StepContainer.tsx b/extension/react-app/src/components/StepContainer.tsx
index 9ab7430c..93b90f0d 100644
--- a/extension/react-app/src/components/StepContainer.tsx
+++ b/extension/react-app/src/components/StepContainer.tsx
@@ -39,7 +39,6 @@ interface StepContainerProps {
 const MainDiv = styled.div<{ stepDepth: number; inFuture: boolean }>`
   opacity: ${(props) => (props.inFuture ? 0.3 : 1)};
   animation: ${appear} 0.3s ease-in-out;
-  /* padding-left: ${(props) => props.stepDepth * 20}px; */
   overflow: hidden;
   margin-left: 0px;
   margin-right: 0px;
--
cgit v1.2.3-70-g09d2
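Because tiktoken has no encoding registered for claude-2, the aliases entry above lets token counting fall back to the gpt-3.5-turbo tokenizer as an approximation, while MAX_TOKENS_FOR_MODEL records Claude 2's 100k context window. A sketch of that lookup (the helper name is illustrative; the real implementation is count_tokens() in libs/util/count_tokens.py):

    import tiktoken

    aliases = {"ggml": "gpt-3.5-turbo", "claude-2": "gpt-3.5-turbo"}
    MAX_TOKENS_FOR_MODEL = {"gpt-4": 8192, "ggml": 2048, "claude-2": 100000}

    def count_tokens_approx(model: str, text: str) -> int:
        # claude-2 is mapped to the gpt-3.5-turbo encoding before asking tiktoken
        encoding = tiktoken.encoding_for_model(aliases.get(model, model))
        return len(encoding.encode(text))

    print(count_tokens_approx("claude-2", "Hello, Claude!"))
    print(MAX_TOKENS_FOR_MODEL["claude-2"])  # 100000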