diff options
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/poetry.lock | 135 | ||||
-rw-r--r-- | continuedev/pyproject.toml | 1 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/config.py | 2 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 16 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/anthropic.py | 97 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/util/count_tokens.py | 4 |
6 files changed, 251 insertions, 4 deletions
diff --git a/continuedev/poetry.lock b/continuedev/poetry.lock index 625aabc9..e8927fe7 100644 --- a/continuedev/poetry.lock +++ b/continuedev/poetry.lock @@ -125,6 +125,26 @@ files = [ frozenlist = ">=1.1.0" [[package]] +name = "anthropic" +version = "0.3.4" +description = "Client library for the anthropic API" +category = "main" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "anthropic-0.3.4-py3-none-any.whl", hash = "sha256:7b0396f663b0e4eaaf485ae59a0be014cddfc0f0b8f4dad79bb35d8f28439097"}, + {file = "anthropic-0.3.4.tar.gz", hash = "sha256:36184840bd33184697666d4f1ec951d78ef5da22e87d936cd3c04b611d84e93c"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<4" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<2.0.0" +tokenizers = ">=0.13.0" +typing-extensions = ">=4.1.1,<5" + +[[package]] name = "anyio" version = "3.6.2" description = "High level compatibility layer for multiple asynchronous event loop implementations" @@ -387,6 +407,18 @@ files = [ dev = ["pytest (>=3.7)"] [[package]] +name = "distro" +version = "1.8.0" +description = "Distro - an OS platform information API" +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.8.0-py3-none-any.whl", hash = "sha256:99522ca3e365cac527b44bde033f64c6945d90eb9f769703caaec52b09bbd3ff"}, + {file = "distro-1.8.0.tar.gz", hash = "sha256:02e111d1dc6a50abb8eed6bf31c3e48ed8b0830d1ea2a1b78c61765c2513fdd8"}, +] + +[[package]] name = "fastapi" version = "0.95.1" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" @@ -601,6 +633,52 @@ files = [ ] [[package]] +name = "httpcore" +version = "0.17.3" +description = "A minimal low-level HTTP client." +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "httpcore-0.17.3-py3-none-any.whl", hash = "sha256:c2789b767ddddfa2a5782e3199b2b7f6894540b17b16ec26b2c4d8e103510b87"}, + {file = "httpcore-0.17.3.tar.gz", hash = "sha256:a6f30213335e34c1ade7be6ec7c47f19f50c56db36abef1a9dfa3815b1cb3888"}, +] + +[package.dependencies] +anyio = ">=3.0,<5.0" +certifi = "*" +h11 = ">=0.13,<0.15" +sniffio = ">=1.0.0,<2.0.0" + +[package.extras] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] + +[[package]] +name = "httpx" +version = "0.24.1" +description = "The next generation HTTP client." +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "httpx-0.24.1-py3-none-any.whl", hash = "sha256:06781eb9ac53cde990577af654bd990a4949de37a28bdb4a230d434f3a30b9bd"}, + {file = "httpx-0.24.1.tar.gz", hash = "sha256:5853a43053df830c20f8110c5e69fe44d035d850b2dfe795e196f00fdb774bdd"}, +] + +[package.dependencies] +certifi = "*" +httpcore = ">=0.15.0,<0.18.0" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] + +[[package]] name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" @@ -1578,6 +1656,61 @@ requests = ">=2.26.0" blobfile = ["blobfile (>=2)"] [[package]] +name = "tokenizers" +version = "0.13.3" +description = "Fast and Customizable Tokenizers" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "tokenizers-0.13.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:f3835c5be51de8c0a092058a4d4380cb9244fb34681fd0a295fbf0a52a5fdf33"}, + {file = "tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4ef4c3e821730f2692489e926b184321e887f34fb8a6b80b8096b966ba663d07"}, + {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5fd1a6a25353e9aa762e2aae5a1e63883cad9f4e997c447ec39d071020459bc"}, + {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee0b1b311d65beab83d7a41c56a1e46ab732a9eed4460648e8eb0bd69fc2d059"}, + {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ef4215284df1277dadbcc5e17d4882bda19f770d02348e73523f7e7d8b8d396"}, + {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4d53976079cff8a033f778fb9adca2d9d69d009c02fa2d71a878b5f3963ed30"}, + {file = "tokenizers-0.13.3-cp310-cp310-win32.whl", hash = "sha256:1f0e3b4c2ea2cd13238ce43548959c118069db7579e5d40ec270ad77da5833ce"}, + {file = "tokenizers-0.13.3-cp310-cp310-win_amd64.whl", hash = "sha256:89649c00d0d7211e8186f7a75dfa1db6996f65edce4b84821817eadcc2d3c79e"}, + {file = "tokenizers-0.13.3-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:56b726e0d2bbc9243872b0144515ba684af5b8d8cd112fb83ee1365e26ec74c8"}, + {file = "tokenizers-0.13.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:cc5c022ce692e1f499d745af293ab9ee6f5d92538ed2faf73f9708c89ee59ce6"}, + {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f55c981ac44ba87c93e847c333e58c12abcbb377a0c2f2ef96e1a266e4184ff2"}, + {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f247eae99800ef821a91f47c5280e9e9afaeed9980fc444208d5aa6ba69ff148"}, + {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b3e3215d048e94f40f1c95802e45dcc37c5b05eb46280fc2ccc8cd351bff839"}, + {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ba2b0bf01777c9b9bc94b53764d6684554ce98551fec496f71bc5be3a03e98b"}, + {file = "tokenizers-0.13.3-cp311-cp311-win32.whl", hash = "sha256:cc78d77f597d1c458bf0ea7c2a64b6aa06941c7a99cb135b5969b0278824d808"}, + {file = "tokenizers-0.13.3-cp311-cp311-win_amd64.whl", hash = "sha256:ecf182bf59bd541a8876deccf0360f5ae60496fd50b58510048020751cf1724c"}, + {file = "tokenizers-0.13.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:0527dc5436a1f6bf2c0327da3145687d3bcfbeab91fed8458920093de3901b44"}, + {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07cbb2c307627dc99b44b22ef05ff4473aa7c7cc1fec8f0a8b37d8a64b1a16d2"}, + {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4560dbdeaae5b7ee0d4e493027e3de6d53c991b5002d7ff95083c99e11dd5ac0"}, + {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64064bd0322405c9374305ab9b4c07152a1474370327499911937fd4a76d004b"}, + {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8c6e2ab0f2e3d939ca66aa1d596602105fe33b505cd2854a4c1717f704c51de"}, + {file = "tokenizers-0.13.3-cp37-cp37m-win32.whl", hash = "sha256:6cc29d410768f960db8677221e497226e545eaaea01aa3613fa0fdf2cc96cff4"}, + {file = "tokenizers-0.13.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fc2a7fdf864554a0dacf09d32e17c0caa9afe72baf9dd7ddedc61973bae352d8"}, + {file = "tokenizers-0.13.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:8791dedba834c1fc55e5f1521be325ea3dafb381964be20684b92fdac95d79b7"}, + {file = "tokenizers-0.13.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:d607a6a13718aeb20507bdf2b96162ead5145bbbfa26788d6b833f98b31b26e1"}, + {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3791338f809cd1bf8e4fee6b540b36822434d0c6c6bc47162448deee3f77d425"}, + {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2f35f30e39e6aab8716f07790f646bdc6e4a853816cc49a95ef2a9016bf9ce6"}, + {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:310204dfed5aa797128b65d63538a9837cbdd15da2a29a77d67eefa489edda26"}, + {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0f9b92ea052305166559f38498b3b0cae159caea712646648aaa272f7160963"}, + {file = "tokenizers-0.13.3-cp38-cp38-win32.whl", hash = "sha256:9a3fa134896c3c1f0da6e762d15141fbff30d094067c8f1157b9fdca593b5806"}, + {file = "tokenizers-0.13.3-cp38-cp38-win_amd64.whl", hash = "sha256:8e7b0cdeace87fa9e760e6a605e0ae8fc14b7d72e9fc19c578116f7287bb873d"}, + {file = "tokenizers-0.13.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:00cee1e0859d55507e693a48fa4aef07060c4bb6bd93d80120e18fea9371c66d"}, + {file = "tokenizers-0.13.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:a23ff602d0797cea1d0506ce69b27523b07e70f6dda982ab8cf82402de839088"}, + {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70ce07445050b537d2696022dafb115307abdffd2a5c106f029490f84501ef97"}, + {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:280ffe95f50eaaf655b3a1dc7ff1d9cf4777029dbbc3e63a74e65a056594abc3"}, + {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97acfcec592f7e9de8cadcdcda50a7134423ac8455c0166b28c9ff04d227b371"}, + {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd7730c98a3010cd4f523465867ff95cd9d6430db46676ce79358f65ae39797b"}, + {file = "tokenizers-0.13.3-cp39-cp39-win32.whl", hash = "sha256:48625a108029cb1ddf42e17a81b5a3230ba6888a70c9dc14e81bc319e812652d"}, + {file = "tokenizers-0.13.3-cp39-cp39-win_amd64.whl", hash = "sha256:bc0a6f1ba036e482db6453571c9e3e60ecd5489980ffd95d11dc9f960483d783"}, + {file = "tokenizers-0.13.3.tar.gz", hash = "sha256:2e546dbb68b623008a5442353137fbb0123d311a6d7ba52f2667c8862a75af2e"}, +] + +[package.extras] +dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] +docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] +testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] + +[[package]] name = "tqdm" version = "4.65.0" description = "Fast, Extensible Progress Meter" @@ -1882,4 +2015,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "82510deb9f4afb5bc38db0dfd88ad88005fa0b6221c24e8c1700c006360f3f88" +content-hash = "87dbf6d1e56ce6ba81a01a59c0de2d3717925bac9639710bf3ff3ce30f5f5e2c" diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml index 3077de1c..6a646cbe 100644 --- a/continuedev/pyproject.toml +++ b/continuedev/pyproject.toml @@ -24,6 +24,7 @@ tiktoken = "^0.4.0" jsonref = "^1.1.0" jsonschema = "^4.17.3" directory-tree = "^0.0.3.1" +anthropic = "^0.3.4" chevron = "^0.14.0" [tool.poetry.scripts] diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 98615c64..6af0878d 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -81,7 +81,7 @@ class ContinueConfig(BaseModel): disallowed_steps: Optional[List[str]] = [] allow_anonymous_telemetry: Optional[bool] = True default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k", - "gpt-4", "ggml"] = 'gpt-4' + "gpt-4", "claude-2", "ggml"] = 'gpt-4' custom_commands: Optional[List[CustomCommand]] = [CustomCommand( name="test", description="This is an example custom command. Use /config to edit it and create more", diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 7e612d3b..280fefa8 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -11,6 +11,7 @@ from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFi from ..models.filesystem import RangeInFile from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI from ..libs.llm.openai import OpenAI +from ..libs.llm.anthropic import AnthropicLLM from ..libs.llm.ggml import GGML from .observation import Observation from ..server.ide_protocol import AbstractIdeProtocolServer @@ -27,7 +28,7 @@ ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"] MODEL_PROVIDER_TO_ENV_VAR = { "openai": "OPENAI_API_KEY", "hf_inference_api": "HUGGING_FACE_TOKEN", - "anthropic": "ANTHROPIC_API_KEY" + "anthropic": "ANTHROPIC_API_KEY", } @@ -43,6 +44,9 @@ class Models: @classmethod async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models": + if sdk.config.default_model == "claude-2": + with_providers.append("anthropic") + models = Models(sdk, with_providers) for provider in with_providers: if provider in MODEL_PROVIDER_TO_ENV_VAR: @@ -62,6 +66,14 @@ class Models: api_key = self.provider_keys["hf_inference_api"] return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message) + def __load_anthropic_model(self, model: str) -> AnthropicLLM: + api_key = self.provider_keys["anthropic"] + return AnthropicLLM(api_key, model, self.system_message) + + @cached_property + def claude2(self): + return self.__load_anthropic_model("claude-2") + @cached_property def starcoder(self): return self.__load_hf_inference_api_model("bigcode/starcoder") @@ -95,6 +107,8 @@ class Models: return self.gpt3516k elif model_name == "gpt-4": return self.gpt4 + elif model_name == "claude-2": + return self.claude2 elif model_name == "ggml": return self.ggml else: diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py new file mode 100644 index 00000000..566f7150 --- /dev/null +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -0,0 +1,97 @@ + +from functools import cached_property +import time +from typing import Any, Coroutine, Dict, Generator, List, Union +from ...core.main import ChatMessage +from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic +from ..llm import LLM +from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top + + +class AnthropicLLM(LLM): + api_key: str + default_model: str + async_client: AsyncAnthropic + + def __init__(self, api_key: str, default_model: str, system_message: str = None): + self.api_key = api_key + self.default_model = default_model + self.system_message = system_message + + self.async_client = AsyncAnthropic(api_key=api_key) + + @cached_property + def name(self): + return self.default_model + + @property + def default_args(self): + return {**DEFAULT_ARGS, "model": self.default_model} + + def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]: + args = args.copy() + if "max_tokens" in args: + args["max_tokens_to_sample"] = args["max_tokens"] + del args["max_tokens"] + if "frequency_penalty" in args: + del args["frequency_penalty"] + if "presence_penalty" in args: + del args["presence_penalty"] + return args + + def count_tokens(self, text: str): + return count_tokens(self.default_model, text) + + def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str: + prompt = "" + + # Anthropic prompt must start with a Human turn + if len(messages) > 0 and messages[0]["role"] != "user" and messages[0]["role"] != "system": + prompt += f"{HUMAN_PROMPT} Hello." + for msg in messages: + prompt += f"{HUMAN_PROMPT if (msg['role'] == 'user' or msg['role'] == 'system') else AI_PROMPT} {msg['content']} " + + prompt += AI_PROMPT + return prompt + + async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + args = self.default_args.copy() + args.update(kwargs) + args["stream"] = True + args = self._transform_args(args) + + async for chunk in await self.async_client.completions.create( + prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}", + **args + ): + yield chunk.completion + + async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + args = self.default_args.copy() + args.update(kwargs) + args["stream"] = True + args = self._transform_args(args) + + messages = compile_chat_messages( + args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None)) + async for chunk in await self.async_client.completions.create( + prompt=self.__messages_to_prompt(messages), + **args + ): + yield { + "role": "assistant", + "content": chunk.completion + } + + async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]: + args = {**self.default_args, **kwargs} + args = self._transform_args(args) + + messages = compile_chat_messages( + args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None) + resp = (await self.async_client.completions.create( + prompt=self.__messages_to_prompt(messages), + **args + )).completion + + return resp diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index 1ca98fe6..1d5d6729 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -6,6 +6,7 @@ import tiktoken aliases = { "ggml": "gpt-3.5-turbo", + "claude-2": "gpt-3.5-turbo", } DEFAULT_MAX_TOKENS = 2048 MAX_TOKENS_FOR_MODEL = { @@ -13,7 +14,8 @@ MAX_TOKENS_FOR_MODEL = { "gpt-3.5-turbo-0613": 4096, "gpt-3.5-turbo-16k": 16384, "gpt-4": 8192, - "ggml": 2048 + "ggml": 2048, + "claude-2": 100000 } CHAT_MODELS = { "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613" |