diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-07-16 00:21:56 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-07-16 00:21:56 -0700 |
commit | 868e0b7ef5357b89186119c3c2fa8bd427b8db30 (patch) | |
tree | 46c7cc8d6bae7fdbcbcec341e9b9e6f789f96de0 /continuedev/src/continuedev/core | |
parent | 71a869bda2018d8fcfff56f7eccfff2943c30ee0 (diff) | |
download | sncontinue-868e0b7ef5357b89186119c3c2fa8bd427b8db30.tar.gz sncontinue-868e0b7ef5357b89186119c3c2fa8bd427b8db30.tar.bz2 sncontinue-868e0b7ef5357b89186119c3c2fa8bd427b8db30.zip |
Anthropic support
Diffstat (limited to 'continuedev/src/continuedev/core')
-rw-r--r-- | continuedev/src/continuedev/core/config.py | 2 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 16 |
2 files changed, 16 insertions, 2 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 6e430c04..05ba48c6 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -76,7 +76,7 @@ class ContinueConfig(BaseModel): server_url: Optional[str] = None allow_anonymous_telemetry: Optional[bool] = True default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k", - "gpt-4"] = 'gpt-4' + "gpt-4", "claude-2"] = 'gpt-4' custom_commands: Optional[List[CustomCommand]] = [CustomCommand( name="test", description="This is an example custom command. Use /config to edit it and create more", diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index d73561d2..28487600 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -11,6 +11,7 @@ from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFi from ..models.filesystem import RangeInFile from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI from ..libs.llm.openai import OpenAI +from ..libs.llm.anthropic import Anthropic from .observation import Observation from ..server.ide_protocol import AbstractIdeProtocolServer from .main import Context, ContinueCustomException, History, Step, ChatMessage @@ -26,7 +27,7 @@ ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"] MODEL_PROVIDER_TO_ENV_VAR = { "openai": "OPENAI_API_KEY", "hf_inference_api": "HUGGING_FACE_TOKEN", - "anthropic": "ANTHROPIC_API_KEY" + "anthropic": "ANTHROPIC_API_KEY", } @@ -40,6 +41,9 @@ class Models: @classmethod async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models": + if sdk.config.default_model == "claude-2": + with_providers.append("anthropic") + models = Models(sdk, with_providers) for provider in with_providers: if provider in MODEL_PROVIDER_TO_ENV_VAR: @@ -59,6 +63,14 @@ class Models: api_key = self.provider_keys["hf_inference_api"] return HuggingFaceInferenceAPI(api_key=api_key, model=model) + def __load_anthropic_model(self, model: str) -> Anthropic: + api_key = self.provider_keys["anthropic"] + return Anthropic(api_key=api_key, model=model) + + @cached_property + def claude2(self): + return self.__load_anthropic_model("claude-2") + @cached_property def starcoder(self): return self.__load_hf_inference_api_model("bigcode/starcoder") @@ -88,6 +100,8 @@ class Models: return self.gpt3516k elif model_name == "gpt-4": return self.gpt4 + elif model_name == "claude-2": + return self.claude2 else: raise Exception(f"Unknown model {model_name}") |