author     Ty Dunn <ty@tydunn.com>    2023-07-17 19:25:32 -0500
committer  Ty Dunn <ty@tydunn.com>    2023-07-17 19:25:32 -0500
commit     0ec688123abb8dd63fe5ae8c191a69746829884a (patch)
tree       d98a78480e3a34b43f67cd1fffd5d5d5fdd09406 /continuedev/src
parent     043bdade5605a1c509a9f1927ebbe54db7d900f4 (diff)
parent     8498ab7fd2945703f4ad59dabf51cb851db4f64d (diff)
Merge branch 'main' of github.com:continuedev/continue
Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/core/config.py             |  2
-rw-r--r--  continuedev/src/continuedev/core/sdk.py                | 16
-rw-r--r--  continuedev/src/continuedev/libs/llm/anthropic.py      | 97
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py  |  4
4 files changed, 116 insertions, 3 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 98615c64..6af0878d 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -81,7 +81,7 @@ class ContinueConfig(BaseModel):
disallowed_steps: Optional[List[str]] = []
allow_anonymous_telemetry: Optional[bool] = True
default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
- "gpt-4", "ggml"] = 'gpt-4'
+ "gpt-4", "claude-2", "ggml"] = 'gpt-4'
custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
name="test",
description="This is an example custom command. Use /config to edit it and create more",
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 7e612d3b..280fefa8 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -11,6 +11,7 @@ from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFi
from ..models.filesystem import RangeInFile
from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
from ..libs.llm.openai import OpenAI
+from ..libs.llm.anthropic import AnthropicLLM
from ..libs.llm.ggml import GGML
from .observation import Observation
from ..server.ide_protocol import AbstractIdeProtocolServer
@@ -27,7 +28,7 @@ ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"]
MODEL_PROVIDER_TO_ENV_VAR = {
"openai": "OPENAI_API_KEY",
"hf_inference_api": "HUGGING_FACE_TOKEN",
- "anthropic": "ANTHROPIC_API_KEY"
+ "anthropic": "ANTHROPIC_API_KEY",
}
@@ -43,6 +44,9 @@ class Models:
@classmethod
async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
+ if sdk.config.default_model == "claude-2":
+ with_providers.append("anthropic")
+
models = Models(sdk, with_providers)
for provider in with_providers:
if provider in MODEL_PROVIDER_TO_ENV_VAR:
@@ -62,6 +66,14 @@ class Models:
api_key = self.provider_keys["hf_inference_api"]
return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)
+ def __load_anthropic_model(self, model: str) -> AnthropicLLM:
+ api_key = self.provider_keys["anthropic"]
+ return AnthropicLLM(api_key, model, self.system_message)
+
+ @cached_property
+ def claude2(self):
+ return self.__load_anthropic_model("claude-2")
+
@cached_property
def starcoder(self):
return self.__load_hf_inference_api_model("bigcode/starcoder")
@@ -95,6 +107,8 @@ class Models:
return self.gpt3516k
elif model_name == "gpt-4":
return self.gpt4
+ elif model_name == "claude-2":
+ return self.claude2
elif model_name == "ggml":
return self.ggml
else:
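
The sdk.py change wires "claude-2" through provider selection: Models.create appends the "anthropic" provider when the configured default model is claude-2, the API key is then looked up via MODEL_PROVIDER_TO_ENV_VAR, and the model name resolves to the cached claude2 property. A standalone sketch of that flow, with the Models/SDK plumbing stubbed out:

```python
import os

# Copied from the diff above.
MODEL_PROVIDER_TO_ENV_VAR = {
    "openai": "OPENAI_API_KEY",
    "hf_inference_api": "HUGGING_FACE_TOKEN",
    "anthropic": "ANTHROPIC_API_KEY",
}

def providers_for(default_model: str, with_providers=("openai",)) -> list:
    # Mirrors Models.create: choosing claude-2 pulls in the anthropic provider.
    providers = list(with_providers)
    if default_model == "claude-2":
        providers.append("anthropic")
    return providers

for provider in providers_for("claude-2"):
    env_var = MODEL_PROVIDER_TO_ENV_VAR[provider]
    print(f"{provider}: expects ${env_var} ({'set' if os.getenv(env_var) else 'unset'})")
```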
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
new file mode 100644
index 00000000..566f7150
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -0,0 +1,97 @@
+
+from functools import cached_property
+import time
+from typing import Any, Coroutine, Dict, Generator, List, Union
+from ...core.main import ChatMessage
+from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic
+from ..llm import LLM
+from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+
+
+class AnthropicLLM(LLM):
+ api_key: str
+ default_model: str
+ async_client: AsyncAnthropic
+
+ def __init__(self, api_key: str, default_model: str, system_message: str = None):
+ self.api_key = api_key
+ self.default_model = default_model
+ self.system_message = system_message
+
+ self.async_client = AsyncAnthropic(api_key=api_key)
+
+ @cached_property
+ def name(self):
+ return self.default_model
+
+ @property
+ def default_args(self):
+ return {**DEFAULT_ARGS, "model": self.default_model}
+
+ def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
+ args = args.copy()
+ if "max_tokens" in args:
+ args["max_tokens_to_sample"] = args["max_tokens"]
+ del args["max_tokens"]
+ if "frequency_penalty" in args:
+ del args["frequency_penalty"]
+ if "presence_penalty" in args:
+ del args["presence_penalty"]
+ return args
+
+ def count_tokens(self, text: str):
+ return count_tokens(self.default_model, text)
+
+ def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
+ prompt = ""
+
+ # Anthropic prompt must start with a Human turn
+ if len(messages) > 0 and messages[0]["role"] != "user" and messages[0]["role"] != "system":
+ prompt += f"{HUMAN_PROMPT} Hello."
+ for msg in messages:
+ prompt += f"{HUMAN_PROMPT if (msg['role'] == 'user' or msg['role'] == 'system') else AI_PROMPT} {msg['content']} "
+
+ prompt += AI_PROMPT
+ return prompt
+
+ async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ args = self.default_args.copy()
+ args.update(kwargs)
+ args["stream"] = True
+ args = self._transform_args(args)
+
+ async for chunk in await self.async_client.completions.create(
+ prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
+ **args
+ ):
+ yield chunk.completion
+
+ async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ args = self.default_args.copy()
+ args.update(kwargs)
+ args["stream"] = True
+ args = self._transform_args(args)
+
+ messages = compile_chat_messages(
+ args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None))
+ async for chunk in await self.async_client.completions.create(
+ prompt=self.__messages_to_prompt(messages),
+ **args
+ ):
+ yield {
+ "role": "assistant",
+ "content": chunk.completion
+ }
+
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+ args = {**self.default_args, **kwargs}
+ args = self._transform_args(args)
+
+ messages = compile_chat_messages(
+ args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None)
+ resp = (await self.async_client.completions.create(
+ prompt=self.__messages_to_prompt(messages),
+ **args
+ )).completion
+
+ return resp
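
The new anthropic.py adapts the chat-style interface to Anthropic's completions API: max_tokens is renamed to max_tokens_to_sample, the unsupported penalty arguments are dropped, and chat messages are flattened into alternating Human/Assistant turns (HUMAN_PROMPT is "\n\nHuman:" and AI_PROMPT is "\n\nAssistant:" in the anthropic package). A minimal usage sketch, assuming the anthropic package is installed and ANTHROPIC_API_KEY is set:

```python
import asyncio
import os

# Import path mirrors the file added in this commit.
from continuedev.src.continuedev.libs.llm.anthropic import AnthropicLLM

async def main():
    llm = AnthropicLLM(os.environ["ANTHROPIC_API_KEY"], "claude-2")

    # stream_complete wraps the raw prompt as
    #   "\n\nHuman: <prompt> \n\nAssistant:"
    # and yields completion chunks as they stream back.
    async for chunk in llm.stream_complete("Summarize this merge in one sentence."):
        print(chunk, end="", flush=True)

asyncio.run(main())
```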
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 1ca98fe6..1d5d6729 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -6,6 +6,7 @@ import tiktoken
aliases = {
"ggml": "gpt-3.5-turbo",
+ "claude-2": "gpt-3.5-turbo",
}
DEFAULT_MAX_TOKENS = 2048
MAX_TOKENS_FOR_MODEL = {
@@ -13,7 +14,8 @@ MAX_TOKENS_FOR_MODEL = {
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-16k": 16384,
"gpt-4": 8192,
- "ggml": 2048
+ "ggml": 2048,
+ "claude-2": 100000
}
CHAT_MODELS = {
"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"