1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
|
import json
from typing import Callable
import aiohttp
from pydantic import Field
from ...core.main import ContinueCustomException
from ..util.logging import logger
from .base import LLM
from .prompts.chat import llama2_template_messages
from .prompts.edit import simplified_edit_prompt
class TogetherLLM(LLM):
    """
    The Together API is a cloud platform for running large AI models. You can sign up [here](https://api.together.xyz/signup), copy your API key on the initial welcome screen, and then hit the play button on any model from the [Together Models list](https://docs.together.ai/docs/models-inference). Change `~/.continue/config.py` to look like this:
    ```python title="~/.continue/config.py"
    from continuedev.core.models import Models
    from continuedev.libs.llm.together import TogetherLLM
    config = ContinueConfig(
        ...
        models=Models(
            default=TogetherLLM(
                api_key="<API_KEY>",
                model="togethercomputer/llama-2-13b-chat"
            )
        )
    )
    ```
    """

    # Required credential for the Together API.
    api_key: str = Field(..., description="Together API key")
    # Default model identifier used when the caller does not override it.
    model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct"
    base_url: str = Field(
        "https://api.together.xyz",
        description="The base URL for your Together API instance",
    )

    # Shared aiohttp session, created in start() and closed in stop().
    _client_session: aiohttp.ClientSession = None

    template_messages: Callable = llama2_template_messages
    prompt_templates = {
        "edit": simplified_edit_prompt,
    }

    async def start(self, **kwargs):
        """Initialize the base LLM, then open the shared HTTP session.

        verify_ssl/timeout are presumably provided by the LLM base class —
        they are read here but not defined in this file.
        """
        await super().start(**kwargs)
        self._client_session = aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
            timeout=aiohttp.ClientTimeout(total=self.timeout),
        )

    async def stop(self):
        """Close the HTTP session. Safe to call even if start() never ran."""
        if self._client_session is not None:
            await self._client_session.close()

    async def _stream_complete(self, prompt, options):
        """Yield generated text fragments from a streaming /inference call.

        The response is a server-sent-event stream of `data: {...}` lines;
        ping comments and the `[DONE]` sentinel are skipped.
        """
        args = self.collect_args(options)
        async with self._client_session.post(
            f"{self.base_url}/inference",
            json={
                "prompt": prompt,
                "stream_tokens": True,
                **args,
            },
            headers={"Authorization": f"Bearer {self.api_key}"},
            proxy=self.proxy,
        ) as resp:
            async for line in resp.content.iter_chunks():
                # iter_chunks() yields (data, end_of_http_chunk) pairs.
                if line[1]:
                    json_chunk = line[0].decode("utf-8")
                    if json_chunk.startswith(": ping - ") or json_chunk.startswith(
                        "data: [DONE]"
                    ):
                        continue
                    # One network chunk may carry several SSE lines.
                    chunks = json_chunk.split("\n")
                    for chunk in chunks:
                        if chunk.strip() != "":
                            if chunk.startswith("data: "):
                                chunk = chunk[6:]
                            if chunk == "[DONE]":
                                # End-of-stream sentinel for this chunk batch.
                                break
                            try:
                                json_chunk = json.loads(chunk)
                            except Exception as e:
                                logger.warning(f"Invalid JSON chunk: {chunk}\n\n{e}")
                                continue
                            if "choices" in json_chunk:
                                yield json_chunk["choices"][0]["text"]

    async def _complete(self, prompt: str, options):
        """Return the full completion text from a non-streaming /inference call.

        Raises:
            ContinueCustomException: when the API reports an error
                (with a friendlier message for invalid API keys).
            Exception: for any other unrecognized response shape.
        """
        args = self.collect_args(options)
        async with self._client_session.post(
            f"{self.base_url}/inference",
            json={"prompt": prompt, **args},
            headers={"Authorization": f"Bearer {self.api_key}"},
            proxy=self.proxy,
        ) as resp:
            text = await resp.text()
            j = json.loads(text)
            # Check "output" exists BEFORE indexing into it. The original
            # evaluated j["output"] first, so error responses without an
            # "output" key raised KeyError and were only routed to the
            # error-reporting path via a broad except that re-parsed the body.
            if "output" in j and "choices" in j["output"]:
                return j["output"]["choices"][0]["text"]
            if "error" in j:
                if j["error"].startswith("invalid hexlify value"):
                    # This error string is what Together returns for a
                    # malformed API key.
                    raise ContinueCustomException(
                        message=f"Invalid Together API key:\n\n{j['error']}",
                        title="Together API Error",
                    )
                raise ContinueCustomException(
                    message=j["error"], title="Together API Error"
                )
            # Unrecognized payload: surface the raw body for debugging.
            raise Exception(text)
|