author | Ty Dunn <ty@tydunn.com> | 2023-09-07 11:19:38 -0700
---|---|---
committer | GitHub <noreply@github.com> | 2023-09-07 11:19:38 -0700
commit | 8dad79af8b18c08e270382ce1b18a3956fa59626 (patch) |
tree | 2317e2a5e43624e54e7894e016b9a84a08d1cec9 |
parent | 65887473f4c6711d2a64a087f835f86556c75dff (diff) |
adding support for Hugging Face Inference Endpoints (#460)
* stream complete sketch
* correct structure but issues
* refactor: :art: clean up hf_inference_api.py
* fix: :bug: quick fix in hf_inference_api.py
* feat: :memo: update documentation code for hf_inference_api
* hf docs
* now working
---------
Co-authored-by: Nate Sesti <sestinj@gmail.com>
-rw-r--r-- | continuedev/poetry.lock | 146
-rw-r--r-- | continuedev/pyproject.toml | 1
-rw-r--r-- | continuedev/requirements.txt | 3
-rw-r--r-- | continuedev/src/continuedev/core/models.py | 3
-rw-r--r-- | continuedev/src/continuedev/libs/llm/hf_inference_api.py | 86
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/setup_model.py | 1
-rw-r--r-- | docs/docs/customization.md | 19
-rw-r--r-- | extension/react-app/src/components/ModelSelect.tsx | 8 |
8 files changed, 212 insertions, 55 deletions
diff --git a/continuedev/poetry.lock b/continuedev/poetry.lock
index 6808a83e..23ff9094 100644
--- a/continuedev/poetry.lock
+++ b/continuedev/poetry.lock
@@ -635,6 +635,24 @@ doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-markdo
 test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==23.1.0)", "coverage[toml] (>=6.5.0,<8.0)", "databases[sqlite] (>=0.3.2,<0.7.0)", "email-validator (>=1.1.1,<2.0.0)", "flask (>=1.1.2,<3.0.0)", "httpx (>=0.23.0,<0.24.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.982)", "orjson (>=3.2.1,<4.0.0)", "passlib[bcrypt] (>=1.7.2,<2.0.0)", "peewee (>=3.13.3,<4.0.0)", "pytest (>=7.1.3,<8.0.0)", "python-jose[cryptography] (>=3.3.0,<4.0.0)", "python-multipart (>=0.0.5,<0.0.7)", "pyyaml (>=5.3.1,<7.0.0)", "ruff (==0.0.138)", "sqlalchemy (>=1.3.18,<1.4.43)", "types-orjson (==3.6.2)", "types-ujson (==5.7.0.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0)"]

 [[package]]
+name = "filelock"
+version = "3.12.3"
+description = "A platform independent file lock."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "filelock-3.12.3-py3-none-any.whl", hash = "sha256:f067e40ccc40f2b48395a80fcbd4728262fab54e232e090a4063ab804179efeb"},
+    {file = "filelock-3.12.3.tar.gz", hash = "sha256:0ecc1dd2ec4672a10c8550a8182f1bd0c0a5088470ecd5a125e45f49472fac3d"},
+]
+
+[package.dependencies]
+typing-extensions = {version = ">=4.7.1", markers = "python_version < \"3.11\""}
+
+[package.extras]
+docs = ["furo (>=2023.7.26)", "sphinx (>=7.1.2)", "sphinx-autodoc-typehints (>=1.24)"]
+testing = ["covdefaults (>=2.3)", "coverage (>=7.3)", "diff-cover (>=7.7)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "pytest-timeout (>=2.1)"]
+
+[[package]]
 name = "frozenlist"
 version = "1.4.0"
 description = "A list-like structure which implements collections.abc.MutableSequence"
@@ -705,6 +723,41 @@ files = [
 ]

 [[package]]
+name = "fsspec"
+version = "2023.9.0"
+description = "File-system specification"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "fsspec-2023.9.0-py3-none-any.whl", hash = "sha256:d55b9ab2a4c1f2b759888ae9f93e40c2aa72c0808132e87e282b549f9e6c4254"},
+    {file = "fsspec-2023.9.0.tar.gz", hash = "sha256:4dbf0fefee035b7c6d3bbbe6bc99b2f201f40d4dca95b67c2b719be77bcd917f"},
+]
+
+[package.extras]
+abfs = ["adlfs"]
+adl = ["adlfs"]
+arrow = ["pyarrow (>=1)"]
+dask = ["dask", "distributed"]
+devel = ["pytest", "pytest-cov"]
+dropbox = ["dropbox", "dropboxdrivefs", "requests"]
+full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"]
+fuse = ["fusepy"]
+gcs = ["gcsfs"]
+git = ["pygit2"]
+github = ["requests"]
+gs = ["gcsfs"]
+gui = ["panel"]
+hdfs = ["pyarrow (>=1)"]
+http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"]
+libarchive = ["libarchive-c"]
+oci = ["ocifs"]
+s3 = ["s3fs"]
+sftp = ["paramiko"]
+smb = ["smbprotocol"]
+ssh = ["paramiko"]
+tqdm = ["tqdm"]
+
+[[package]]
 name = "h11"
 version = "0.14.0"
 description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
@@ -760,6 +813,38 @@ http2 = ["h2 (>=3,<5)"]
 socks = ["socksio (==1.*)"]

 [[package]]
+name = "huggingface-hub"
+version = "0.16.4"
+description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
+optional = false
+python-versions = ">=3.7.0"
+files = [
+    {file = "huggingface_hub-0.16.4-py3-none-any.whl", hash = "sha256:0d3df29932f334fead024afc7cb4cc5149d955238b8b5e42dcf9740d6995a349"},
+    {file = "huggingface_hub-0.16.4.tar.gz", hash = "sha256:608c7d4f3d368b326d1747f91523dbd1f692871e8e2e7a4750314a2dd8b63e14"},
+]
+
+[package.dependencies]
+filelock = "*"
+fsspec = "*"
+packaging = ">=20.9"
+pyyaml = ">=5.1"
+requests = "*"
+tqdm = ">=4.42.1"
+typing-extensions = ">=3.7.4.3"
+
+[package.extras]
+all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"]
+cli = ["InquirerPy (==0.3.4)"]
+dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"]
+fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
+inference = ["aiohttp", "pydantic"]
+quality = ["black (>=23.1,<24.0)", "mypy (==0.982)", "ruff (>=0.0.241)"]
+tensorflow = ["graphviz", "pydot", "tensorflow"]
+testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
+torch = ["torch"]
+typing = ["pydantic", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
+
+[[package]]
 name = "idna"
 version = "3.4"
 description = "Internationalized Domain Names in Applications (IDNA)"
@@ -1358,6 +1443,65 @@ websockets = ["websockets (>=10.3)"]
 yapf = ["whatthepatch (>=1.0.2,<2.0.0)", "yapf (>=0.33.0)"]

 [[package]]
+name = "pyyaml"
+version = "6.0.1"
+description = "YAML parser and emitter for Python"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"},
+    {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"},
+    {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
+    {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
+    {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
+    {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
+    {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
+    {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
+    {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
+    {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"},
+    {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
+    {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
+    {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
+    {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
+    {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
+    {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
+    {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
+    {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
+    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
+    {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
+    {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
+    {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
+    {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
+    {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
+    {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
+    {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"},
+    {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"},
+    {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"},
+    {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"},
+    {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"},
+    {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"},
+    {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"},
+    {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"},
+    {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"},
+    {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"},
+    {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
+    {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
+    {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
+    {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
+    {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
+    {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
+    {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
+    {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"},
+    {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
+    {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
+    {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
+    {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
+    {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
+    {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
+    {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
+]
+
+[[package]]
 name = "redbaron"
 version = "0.9.2"
 description = "Abstraction on top of baron, a FST for python to make writing refactoring code a realistic task"
@@ -2267,4 +2411,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.8.1"
-content-hash = "d67290b217f5991c09c3073c8a97f8ffc5b51d69c7802db5f65462a029e5f99e"
+content-hash = "85e02897d3dd2c8b9bce7e15985eb7aecbd52c166478fceca445787fdde2d60b"
diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml
index 134f9e16..d6cadfbc 100644
--- a/continuedev/pyproject.toml
+++ b/continuedev/pyproject.toml
@@ -33,6 +33,7 @@ bs4 = "^0.0.1"
 replicate = "^0.11.0"
 redbaron = "^0.9.2"
 python-lsp-server = "^1.7.4"
+huggingface-hub = "^0.16.4"

 [tool.poetry.scripts]
 typegen = "src.continuedev.models.generate_json_schema:main"
diff --git a/continuedev/requirements.txt b/continuedev/requirements.txt
index 91120c59..a689f085 100644
--- a/continuedev/requirements.txt
+++ b/continuedev/requirements.txt
@@ -23,4 +23,5 @@ ripgrepy==2.0.0
 replicate==0.11.0
 bs4==0.0.1
 redbaron==0.9.2
-python-lsp-server==1.2.0
\ No newline at end of file
+python-lsp-server==1.2.0
+huggingface-hub==0.16.4
\ No newline at end of file
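On the dependency side, the change is just the `huggingface-hub` client library, pinned at `^0.16.4` in Poetry and `==0.16.4` in `requirements.txt`; everything the new LLM class needs from it is `InferenceClient`. A quick sanity check after installing (a sketch for reference, not part of the commit):

```python
# Verify the pinned client library is importable and exposes InferenceClient,
# which the new HuggingFaceInferenceAPI class below relies on.
import huggingface_hub
from huggingface_hub import InferenceClient

print(huggingface_hub.__version__)  # expected to be 0.16.x with the pins above
print(InferenceClient)
```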
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index 9816d5d9..f24c81ca 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -11,6 +11,7 @@ from ..libs.llm.ollama import Ollama
 from ..libs.llm.openai import OpenAI
 from ..libs.llm.replicate import ReplicateLLM
 from ..libs.llm.together import TogetherLLM
+from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI


 class ContinueSDK(BaseModel):
@@ -37,6 +38,7 @@ MODEL_CLASSES = {
         ReplicateLLM,
         Ollama,
         LlamaCpp,
+        HuggingFaceInferenceAPI,
     ]
 }

@@ -49,6 +51,7 @@ MODEL_MODULE_NAMES = {
     "ReplicateLLM": "replicate",
     "Ollama": "ollama",
     "LlamaCpp": "llamacpp",
+    "HuggingFaceInferenceAPI": "hf_inference_api",
 }
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 43aac148..b8fc49a9 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,73 +1,53 @@
-from typing import List, Optional
+from typing import Callable, Dict, List
+from ..llm import LLM, CompletionOptions
-import aiohttp
-import requests
-
-from ...core.main import ChatMessage
-from ..llm import LLM
+from huggingface_hub import InferenceClient
+from .prompts.chat import llama2_template_messages
 from .prompts.edit import simplified_edit_prompt

-DEFAULT_MAX_TIME = 120.0
-
 class HuggingFaceInferenceAPI(LLM):
+    model: str = "Hugging Face Inference API"
     hf_token: str
-    self_hosted_url: str = None
-
-    verify_ssl: Optional[bool] = None
+    endpoint_url: str = None
-    _client_session: aiohttp.ClientSession = None
+    template_messages: Callable[[List[Dict[str, str]]], str] | None = llama2_template_messages

     prompt_templates = {
-        "edit": simplified_edit_prompt,
+        "edit": simplified_edit_prompt,
     }

     class Config:
         arbitrary_types_allowed = True

-    async def start(self, **kwargs):
-        await super().start(**kwargs)
-        self._client_session = aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
-            timeout=aiohttp.ClientTimeout(total=self.timeout),
-        )
+    def collect_args(self, options: CompletionOptions):
+        options.stop = None
+        args = super().collect_args(options)

-    async def stop(self):
-        await self._client_session.close()
+        if "max_tokens" in args:
+            args["max_new_tokens"] = args["max_tokens"]
+            del args["max_tokens"]
+        if "stop" in args:
+            args["stop_sequences"] = args["stop"]
+            del args["stop"]
+        if "model" in args:
+            del args["model"]
+        return args

-    async def _complete(self, prompt: str, options):
-        """Return the completion of the text with the given temperature."""
-        API_URL = (
-            self.base_url or f"https://api-inference.huggingface.co/models/{self.model}"
-        )
-        headers = {"Authorization": f"Bearer {self.hf_token}"}
-        response = requests.post(
-            API_URL,
-            headers=headers,
-            json={
-                "inputs": prompt,
-                "parameters": {
-                    "max_new_tokens": min(
-                        250, self.max_context_length - self.count_tokens(prompt)
-                    ),
-                    "max_time": DEFAULT_MAX_TIME,
-                    "return_full_text": False,
-                },
-            },
-        )
-        data = response.json()
-
-        # Error if the response is not a list
-        if not isinstance(data, list):
-            raise Exception("Hugging Face returned an error response: \n\n", data)
+    async def _stream_complete(self, prompt, options):
+        args = self.collect_args(options)

-        return data[0]["generated_text"]
+        client = InferenceClient(self.endpoint_url, token=self.hf_token)

-    async def _stream_chat(self, messages: List[ChatMessage], options):
-        response = await self._complete(messages[-1].content, messages[:-1])
-        yield {"content": response, "role": "assistant"}
+        stream = client.text_generation(prompt, stream=True, details=True)

-    async def _stream_complete(self, prompt, options):
-        response = await self._complete(prompt, options)
-        yield response
+        for r in stream:
+            # skip special tokens
+            if r.token.special:
+                continue
+            # stop if we encounter a stop sequence
+            if options.stop is not None:
+                if r.token.text in options.stop:
+                    break
+            yield r.token.text
\ No newline at end of file
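Stripped of the Continue plumbing, the new `_stream_complete` is a thin wrapper around `InferenceClient.text_generation` with `stream=True, details=True`, filtering special tokens and applying stop sequences client-side. A minimal standalone sketch of that pattern (the endpoint URL, token, prompt, and stop list are placeholders, not values from the commit):

```python
from huggingface_hub import InferenceClient

client = InferenceClient("<INFERENCE_API_ENDPOINT_URL>", token="<HUGGING_FACE_TOKEN>")

stop = ["</s>"]  # hypothetical stop sequence
for r in client.text_generation(
    "def fibonacci(n):", stream=True, details=True, max_new_tokens=128
):
    if r.token.special:       # skip BOS/EOS-style special tokens, as the class does
        continue
    if r.token.text in stop:  # client-side stop handling, mirroring the diff
        break
    print(r.token.text, end="", flush=True)
```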
diff --git a/continuedev/src/continuedev/plugins/steps/setup_model.py b/continuedev/src/continuedev/plugins/steps/setup_model.py
index 3dae39c2..7fa34907 100644
--- a/continuedev/src/continuedev/plugins/steps/setup_model.py
+++ b/continuedev/src/continuedev/plugins/steps/setup_model.py
@@ -13,6 +13,7 @@ MODEL_CLASS_TO_MESSAGE = {
     "GGML": "GGML models can be run locally using the `llama-cpp-python` library. To learn how to set up a local llama-cpp-python server, read [here](https://github.com/continuedev/ggml-server-example). Once it is started on port 8000, you're all set!",
     "TogetherLLM": "To get started using models from Together, first obtain your Together API key from [here](https://together.ai). Paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then, on their models page, press 'start' on the model of your choice and make sure the `model=` parameter in the config file for the `TogetherLLM` class reflects the name of this model. Finally, reload the VS Code window for changes to take effect.",
     "LlamaCpp": "To get started with this model, clone the [`llama.cpp` repo](https://github.com/ggerganov/llama.cpp) and follow the instructions to set up the server [here](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#build). Any of the parameters described in the README can be passed to the `llama_cpp_args` field in the `LlamaCpp` class in `config.py`.",
+    "HuggingFaceInferenceAPI": "To get started with the HuggingFace Inference API, first deploy a model and obtain your API key from [here](https://huggingface.co/inference-api). Paste it into the `hf_token` field at config.models.default.hf_token in `config.py`. Finally, reload the VS Code window for changes to take effect."
 }
diff --git a/docs/docs/customization.md b/docs/docs/customization.md
index 5fc3eab5..fb7dc0c5 100644
--- a/docs/docs/customization.md
+++ b/docs/docs/customization.md
@@ -21,6 +21,7 @@ Open-Source Models (not local)

 - [TogetherLLM](#together) - Use any model from the [Together Models list](https://docs.together.ai/docs/models-inference) with your Together API key.
 - [ReplicateLLM](#replicate) - Use any open-source model from the [Replicate Streaming List](https://replicate.com/collections/streaming-language-models) with your Replicate API key.
+- [HuggingFaceInferenceAPI](#huggingface) - Use any open-source model from the [Hugging Face Inference API](https://huggingface.co/inference-api) with your Hugging Face token.

 ## Change the default LLM

@@ -206,6 +207,24 @@ config = ContinueConfig(

 If you don't specify the `model` parameter, it will default to `replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781`.

+### Hugging Face
+
+Hugging Face Inference API is a great option for newly released language models. Sign up for an account and add billing [here](https://huggingface.co/settings/billing), access the Inference Endpoints [here](https://ui.endpoints.huggingface.co), click on “New endpoint”, and fill out the form (e.g. select a model like [WizardCoder-Python-34B-V1.0](https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0)), and then deploy your model by clicking “Create Endpoint”. Change `~/.continue/config.py` to look like this:
+
+```python
+from continuedev.src.continuedev.core.models import Models
+from continuedev.src.continuedev.libs.llm.hf_inference_api import HuggingFaceInferenceAPI
+
+config = ContinueConfig(
+    ...
+    models=Models(
+        default=HuggingFaceInferenceAPI(
+            endpoint_url: "<INFERENCE_API_ENDPOINT_URL>",
+            hf_token: "<HUGGING_FACE_TOKEN>",
+        )
+)
+```
+
 ### Self-hosting an open-source model

 If you want to self-host on Colab, RunPod, HuggingFace, Haven, or another hosting provider you will need to wire up a new LLM class. It only needs to implement 3 primary methods: `stream_complete`, `complete`, and `stream_chat`, and you can see examples in `continuedev/src/continuedev/libs/llm`.
diff --git a/extension/react-app/src/components/ModelSelect.tsx b/extension/react-app/src/components/ModelSelect.tsx
index e6c0e9c9..6a7692ff 100644
--- a/extension/react-app/src/components/ModelSelect.tsx
+++ b/extension/react-app/src/components/ModelSelect.tsx
@@ -60,6 +60,14 @@ const MODEL_INFO: { title: string; class: string; args: any }[] = [
     args: {},
   },
   {
+    title: "HuggingFace Inference API",
+    class: "HuggingFaceInferenceAPI",
+    args: {
+      endpoint_url: "<INFERENCE_API_ENDPOINT_URL>",
+      hf_token: "<HUGGING_FACE_TOKEN>",
+    },
+  },
+  {
     title: "Other OpenAI-compatible API",
     class: "GGML",
     args: {
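Note that the `config.py` snippet added to `customization.md` above uses `:` where Python keyword arguments require `=` and is missing a closing parenthesis; a syntactically valid version of the same example would look roughly like this (the `ContinueConfig` import path is an assumption, and the placeholders are kept from the docs):

```python
from continuedev.src.continuedev.core.config import ContinueConfig  # assumed import path
from continuedev.src.continuedev.core.models import Models
from continuedev.src.continuedev.libs.llm.hf_inference_api import HuggingFaceInferenceAPI

config = ContinueConfig(
    models=Models(
        default=HuggingFaceInferenceAPI(
            endpoint_url="<INFERENCE_API_ENDPOINT_URL>",  # keyword arguments use "=", not ":"
            hf_token="<HUGGING_FACE_TOKEN>",
        )
    )
)
```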