author    Nate Sesti <sestinj@gmail.com>  2023-08-09 22:04:22 -0700
committer Nate Sesti <sestinj@gmail.com>  2023-08-09 22:04:22 -0700
commit    8456b24318b13ea5d5dabec2328dd854f8a492b4
tree      1eecd5118d7c31903fc15f327c9801aba5f3b32f
parent    e8ebff1e6b07dfaafff81ee7013bb019cbfe2075
feat: :sparkles: support for Together.ai models
 continuedev/src/continuedev/libs/llm/ggml.py                |   5
 continuedev/src/continuedev/libs/llm/replicate.py           |   2
 continuedev/src/continuedev/libs/llm/together.py            | 118
 extension/react-app/src/redux/slices/serverStateReducer.ts  |   6
 4 files changed, 126 insertions(+), 5 deletions(-)
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 2f131354..25a61e63 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -82,7 +82,10 @@ class GGML(LLM):
                         chunks = json_chunk.split("\n")
                         for chunk in chunks:
                             if chunk.strip() != "":
-                                yield json.loads(chunk[6:])["choices"][0]["delta"]
+                                yield {
+                                    "role": "assistant",
+                                    "content": json.loads(chunk[6:])["choices"][0]["delta"]
+                                }
                     except:
                         raise Exception(str(line[0]))
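
The `chunk[6:]` slice above strips the leading `data: ` prefix from each server-sent-events line before JSON-decoding it, and the change wraps the decoded delta in an explicit chat-message dict. A minimal sketch of that parsing step in isolation (the payload here is hypothetical):

    import json

    # Hypothetical SSE line in the shape the GGML server streams
    raw_chunk = 'data: {"choices": [{"delta": "Hello"}]}'

    if raw_chunk.strip() != "":
        delta = json.loads(raw_chunk[6:])["choices"][0]["delta"]  # drop "data: "
        print({"role": "assistant", "content": delta})
        # -> {'role': 'assistant', 'content': 'Hello'}
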
diff --git a/continuedev/src/continuedev/libs/llm/replicate.py b/continuedev/src/continuedev/libs/llm/replicate.py
index 235fd906..0dd359e7 100644
--- a/continuedev/src/continuedev/libs/llm/replicate.py
+++ b/continuedev/src/continuedev/libs/llm/replicate.py
@@ -25,7 +25,7 @@ class ReplicateLLM(LLM):
     @property
     def default_args(self):
-        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+        return {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024}
 
     def count_tokens(self, text: str):
         return count_tokens(self.name, text)
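
The one-line fix works because later keys win in a dict literal, so the per-instance value overrides anything in `DEFAULT_ARGS`; the diff swaps the display name for the raw model identifier the request payload expects. A small sketch of the merge semantics (the default values below are placeholders, not the library's real ones):

    DEFAULT_ARGS = {"temperature": 0.5, "max_tokens": 600}  # placeholder defaults

    args = {**DEFAULT_ARGS, "model": "replicate/llama-2-70b-chat", "max_tokens": 1024}
    print(args["model"])       # replicate/llama-2-70b-chat
    print(args["max_tokens"])  # 1024 -- the later key overrides the default
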
diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py
new file mode 100644
index 00000000..1cc0a711
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/together.py
@@ -0,0 +1,118 @@
+import json
+from typing import Any, Coroutine, Dict, Generator, List, Union
+
+import aiohttp
+from ...core.main import ChatMessage
+from ..llm import LLM
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
+
+
+class TogetherLLM(LLM):
+    # this is model-specific
+    api_key: str
+    model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct"
+    max_context_length: int = 2048
+    base_url: str = "https://api.together.xyz"
+    verify_ssl: bool = True
+
+    _client_session: aiohttp.ClientSession = None
+
+    async def start(self, **kwargs):
+        self._client_session = aiohttp.ClientSession(
+            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl))
+
+    async def stop(self):
+        await self._client_session.close()
+
+    @property
+    def name(self):
+        return self.model
+
+    @property
+    def context_length(self):
+        return self.max_context_length
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024}
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.name, text)
+
+    def convert_to_prompt(self, chat_messages: List[ChatMessage]) -> str:
+        system_message = None
+        # guard against an empty history before peeking at the first message
+        if len(chat_messages) > 0 and chat_messages[0]["role"] == "system":
+            system_message = chat_messages.pop(0)["content"]
+
+        prompt = "\n"
+        if system_message:
+            prompt += f"<human>: Hi!\n<bot>: {system_message}\n"
+        for message in chat_messages:
+            prompt += f'<{"human" if message["role"] == "user" else "bot"}>: {message["content"]}\n'
+        return prompt
+
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
+        # stream_tokens requests server-sent-events streaming; set it after
+        # the merge so kwargs cannot silently discard it
+        args["stream_tokens"] = True
+
+        messages = compile_chat_messages(
+            self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+
+        async with self._client_session.post(f"{self.base_url}/inference", json={
+            "prompt": self.convert_to_prompt(messages),
+            **args
+        }, headers={
+            "Authorization": f"Bearer {self.api_key}"
+        }) as resp:
+            async for line in resp.content.iter_any():
+                if line:
+                    try:
+                        yield line.decode("utf-8")
+                    except Exception:
+                        raise Exception(str(line))
+
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+        args["stream_tokens"] = True
+
+        async with self._client_session.post(f"{self.base_url}/inference", json={
+            "prompt": self.convert_to_prompt(messages),
+            **args
+        }, headers={
+            "Authorization": f"Bearer {self.api_key}"
+        }) as resp:
+            async for line in resp.content.iter_chunks():
+                if line[1]:
+                    try:
+                        json_chunk = line[0].decode("utf-8")
+                        if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
+                            continue
+                        chunks = json_chunk.split("\n")
+                        for chunk in chunks:
+                            if chunk.strip() != "":
+                                yield {
+                                    "role": "assistant",
+                                    "content": json.loads(chunk[6:])["choices"][0]["text"]
+                                }
+                    except Exception:
+                        raise Exception(str(line[0]))
+
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+        args = {**self.default_args, **kwargs}
+
+        messages = compile_chat_messages(args["model"], with_history, self.context_length,
+                                         args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+        async with self._client_session.post(f"{self.base_url}/inference", json={
+            "prompt": self.convert_to_prompt(messages),
+            **args
+        }, headers={
+            "Authorization": f"Bearer {self.api_key}"
+        }) as resp:
+            try:
+                return await resp.text()
+            except Exception:
+                raise Exception(await resp.text())
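
A minimal driver for the new class, assuming the usual flow where `start()` opens the HTTP session before any request and `stop()` closes it, and that the LLM base class accepts keyword construction; the API key is a placeholder and the standalone `asyncio.run` harness is illustrative only, since inside Continue the runtime manages this lifecycle:

    import asyncio

    async def main():
        llm = TogetherLLM(api_key="<YOUR_TOGETHER_API_KEY>")  # placeholder key
        await llm.start()
        try:
            # stream_chat compiles the history, flattens it through
            # convert_to_prompt into "<human>:"/"<bot>:" turns, and yields
            # assistant-message dicts as tokens arrive
            async for update in llm.stream_chat(
                    [{"role": "user", "content": "Hello!"}]):
                print(update["content"], end="", flush=True)
        finally:
            await llm.stop()

    asyncio.run(main())
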
diff --git a/extension/react-app/src/redux/slices/serverStateReducer.ts b/extension/react-app/src/redux/slices/serverStateReducer.ts
index bd60f1c7..a20476b2 100644
--- a/extension/react-app/src/redux/slices/serverStateReducer.ts
+++ b/extension/react-app/src/redux/slices/serverStateReducer.ts
@@ -9,9 +9,9 @@ const initialState: FullState = {
       name: "Welcome to Continue",
       hide: false,
       description: `- Highlight code section and ask a question or give instructions
-  - Use \`cmd+m\` (Mac) / \`ctrl+m\` (Windows) to open Continue
-  - Use \`/help\` to ask questions about how to use Continue
-  - [Customize Continue](https://continue.dev/docs/customization) (e.g. use your own API key) by typing '/config'.`,
+- Use \`cmd+m\` (Mac) / \`ctrl+m\` (Windows) to open Continue
+- Use \`/help\` to ask questions about how to use Continue
+- [Customize Continue](https://continue.dev/docs/customization) (e.g. use your own API key) by typing '/config'.`,
       system_message: null,
       chat_context: [],
       manage_own_chat_context: false,