author     Nate Sesti <sestinj@gmail.com>  2023-07-17 14:54:36 -0700
committer  Nate Sesti <sestinj@gmail.com>  2023-07-17 14:54:36 -0700
commit     1c9034cddeab0c131babe741e9145cc276bd7521 (patch)
tree       abf8a563f042335caa5df94dcd951e57964d4d4c /continuedev
parent     05d665e65aaef62254a4da9a7a381f9984ff0db5 (diff)
anthropic support
Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/src/continuedev/core/sdk.py                 |  6
-rw-r--r--  continuedev/src/continuedev/libs/llm/anthropic.py       | 50
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py   |  4
3 files changed, 39 insertions(+), 21 deletions(-)
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index d3501f08..280fefa8 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -11,7 +11,7 @@ from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFi
from ..models.filesystem import RangeInFile
from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
from ..libs.llm.openai import OpenAI
-from ..libs.llm.anthropic import Anthropic
+from ..libs.llm.anthropic import AnthropicLLM
from ..libs.llm.ggml import GGML
from .observation import Observation
from ..server.ide_protocol import AbstractIdeProtocolServer
@@ -66,9 +66,9 @@ class Models:
api_key = self.provider_keys["hf_inference_api"]
return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)
- def __load_anthropic_model(self, model: str) -> Anthropic:
+ def __load_anthropic_model(self, model: str) -> AnthropicLLM:
api_key = self.provider_keys["anthropic"]
- return Anthropic(api_key=api_key, model=model)
+ return AnthropicLLM(api_key, model, self.system_message)
@cached_property
def claude2(self):
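
The sdk.py hunk swaps the old Anthropic wrapper for the renamed AnthropicLLM and now forwards the system message as a third positional argument. A minimal sketch of constructing the class directly under the new signature; the key string and system message below are placeholders, not values from the commit:

    # hypothetical direct construction with the new positional signature
    # (import path assumes the continuedev/src layout is importable)
    from continuedev.libs.llm.anthropic import AnthropicLLM

    llm = AnthropicLLM(
        "sk-ant-...",                           # api_key (placeholder)
        "claude-2",                             # default_model
        "You are a concise coding assistant.",  # system_message (placeholder)
    )
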
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 2b8831f0..566f7150 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -3,7 +3,7 @@ from functools import cached_property
import time
from typing import Any, Coroutine, Dict, Generator, List, Union
from ...core.main import ChatMessage
-from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
+from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic
from ..llm import LLM
from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
@@ -11,14 +11,14 @@ from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_
class AnthropicLLM(LLM):
api_key: str
default_model: str
- anthropic: Anthropic
+ async_client: AsyncAnthropic
def __init__(self, api_key: str, default_model: str, system_message: str = None):
self.api_key = api_key
self.default_model = default_model
self.system_message = system_message
- self.anthropic = Anthropic(api_key)
+ self.async_client = AsyncAnthropic(api_key=api_key)
@cached_property
def name(self):
@@ -28,24 +28,39 @@ class AnthropicLLM(LLM):
def default_args(self):
return {**DEFAULT_ARGS, "model": self.default_model}
+ def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
+ args = args.copy()
+ if "max_tokens" in args:
+ args["max_tokens_to_sample"] = args["max_tokens"]
+ del args["max_tokens"]
+ if "frequency_penalty" in args:
+ del args["frequency_penalty"]
+ if "presence_penalty" in args:
+ del args["presence_penalty"]
+ return args
+
def count_tokens(self, text: str):
return count_tokens(self.default_model, text)
- def __messages_to_prompt(self, messages: List[ChatMessage]) -> str:
+ def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
prompt = ""
+
+ # Anthropic prompt must start with a Human turn
+ if len(messages) > 0 and messages[0]["role"] != "user" and messages[0]["role"] != "system":
+ prompt += f"{HUMAN_PROMPT} Hello."
for msg in messages:
- prompt += f"{HUMAN_PROMPT if msg.role == 'user' else AI_PROMPT} {msg.content} "
+ prompt += f"{HUMAN_PROMPT if (msg['role'] == 'user' or msg['role'] == 'system') else AI_PROMPT} {msg['content']} "
+ prompt += AI_PROMPT
return prompt
async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args.copy()
args.update(kwargs)
args["stream"] = True
+ args = self._transform_args(args)
- async for chunk in await self.anthropic.completions.create(
- model=args["model"],
- max_tokens_to_sample=args["max_tokens"],
+ async for chunk in await self.async_client.completions.create(
prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
**args
):
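
The helpers added above adapt OpenAI-style arguments and chat messages to Anthropic's completion endpoint: _transform_args renames max_tokens to max_tokens_to_sample and drops the penalty parameters Anthropic does not accept, while __messages_to_prompt flattens messages into HUMAN_PROMPT/AI_PROMPT turns and prepends a synthetic Human greeting if the conversation would otherwise open with an assistant turn. An illustrative trace, assuming the anthropic package's usual constants HUMAN_PROMPT = "\n\nHuman:" and AI_PROMPT = "\n\nAssistant:":

    # illustrative values only; not part of the commit
    args = {"model": "claude-2", "max_tokens": 1024, "frequency_penalty": 0.0}
    # _transform_args(args) -> {"model": "claude-2", "max_tokens_to_sample": 1024}

    messages = [
        {"role": "system", "content": "Be brief."},
        {"role": "user", "content": "Hi"},
    ]
    # __messages_to_prompt(messages) ->
    # "\n\nHuman: Be brief. \n\nHuman: Hi \n\nAssistant:"
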
@@ -55,25 +70,26 @@ class AnthropicLLM(LLM):
args = self.default_args.copy()
args.update(kwargs)
args["stream"] = True
+ args = self._transform_args(args)
messages = compile_chat_messages(
- args["model"], messages, args["max_tokens"], functions=args.get("functions", None))
- async for chunk in await self.anthropic.completions.create(
- model=args["model"],
- max_tokens_to_sample=args["max_tokens"],
+ args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None))
+ async for chunk in await self.async_client.completions.create(
prompt=self.__messages_to_prompt(messages),
**args
):
- yield chunk.completion
+ yield {
+ "role": "assistant",
+ "content": chunk.completion
+ }
async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
+ args = self._transform_args(args)
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None)
- resp = (await self.anthropic.completions.create(
- model=args["model"],
- max_tokens_to_sample=args["max_tokens"],
+ args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None)
+ resp = (await self.async_client.completions.create(
prompt=self.__messages_to_prompt(messages),
**args
)).completion
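
The hunk that yields {"role": "assistant", "content": ...} dicts is presumably the chat-streaming method (its def line falls outside the hunk context), letting callers consume Anthropic and OpenAI streams the same way; complete, shown in full, reads max_tokens_to_sample because _transform_args has already run. A hedged usage sketch of complete; the prompt text is a placeholder:

    import asyncio

    async def demo(llm: AnthropicLLM) -> None:
        # one-shot completion; returns the .completion text from the async client
        text = await llm.complete("Explain this change in one sentence.", with_history=[])
        print(text)

    # asyncio.run(demo(AnthropicLLM("sk-ant-...", "claude-2")))
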
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 1ca98fe6..1d5d6729 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -6,6 +6,7 @@ import tiktoken
aliases = {
"ggml": "gpt-3.5-turbo",
+ "claude-2": "gpt-3.5-turbo",
}
DEFAULT_MAX_TOKENS = 2048
MAX_TOKENS_FOR_MODEL = {
@@ -13,7 +14,8 @@ MAX_TOKENS_FOR_MODEL = {
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-16k": 16384,
"gpt-4": 8192,
- "ggml": 2048
+ "ggml": 2048,
+ "claude-2": 100000
}
CHAT_MODELS = {
"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"