summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/core
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev/src/continuedev/core')
-rw-r--r--continuedev/src/continuedev/core/config.py2
-rw-r--r--continuedev/src/continuedev/core/sdk.py16
2 files changed, 16 insertions, 2 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 98615c64..6af0878d 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -81,7 +81,7 @@ class ContinueConfig(BaseModel):
disallowed_steps: Optional[List[str]] = []
allow_anonymous_telemetry: Optional[bool] = True
default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
- "gpt-4", "ggml"] = 'gpt-4'
+ "gpt-4", "claude-2", "ggml"] = 'gpt-4'
custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
name="test",
description="This is an example custom command. Use /config to edit it and create more",
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 7e612d3b..280fefa8 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -11,6 +11,7 @@ from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFi
from ..models.filesystem import RangeInFile
from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
from ..libs.llm.openai import OpenAI
+from ..libs.llm.anthropic import AnthropicLLM
from ..libs.llm.ggml import GGML
from .observation import Observation
from ..server.ide_protocol import AbstractIdeProtocolServer
@@ -27,7 +28,7 @@ ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"]
MODEL_PROVIDER_TO_ENV_VAR = {
"openai": "OPENAI_API_KEY",
"hf_inference_api": "HUGGING_FACE_TOKEN",
- "anthropic": "ANTHROPIC_API_KEY"
+ "anthropic": "ANTHROPIC_API_KEY",
}
@@ -43,6 +44,9 @@ class Models:
@classmethod
async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
+ if sdk.config.default_model == "claude-2":
+ with_providers.append("anthropic")
+
models = Models(sdk, with_providers)
for provider in with_providers:
if provider in MODEL_PROVIDER_TO_ENV_VAR:
@@ -62,6 +66,14 @@ class Models:
api_key = self.provider_keys["hf_inference_api"]
return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)
+ def __load_anthropic_model(self, model: str) -> AnthropicLLM:
+ api_key = self.provider_keys["anthropic"]
+ return AnthropicLLM(api_key, model, self.system_message)
+
+ @cached_property
+ def claude2(self):
+ return self.__load_anthropic_model("claude-2")
+
@cached_property
def starcoder(self):
return self.__load_hf_inference_api_model("bigcode/starcoder")
@@ -95,6 +107,8 @@ class Models:
return self.gpt3516k
elif model_name == "gpt-4":
return self.gpt4
+ elif model_name == "claude-2":
+ return self.claude2
elif model_name == "ggml":
return self.ggml
else: