path: root/server/continuedev/libs/llm/together.py
blob: 35b3a424439375e61737d7df13a674696a9d6c5f
import json
from typing import Callable

import aiohttp
from pydantic import Field

from ...core.main import ContinueCustomException
from ..util.logging import logger
from .base import LLM
from .prompts.chat import llama2_template_messages
from .prompts.edit import simplified_edit_prompt


class TogetherLLM(LLM):
    """
    The Together API is a cloud platform for running large AI models. You can sign up [here](https://api.together.xyz/signup), copy your API key from the initial welcome screen, and then start any model from the [Together Models list](https://docs.together.ai/docs/models-inference) by pressing its play button. Then change `~/.continue/config.py` to look like this:

    ```python title="~/.continue/config.py"
    from continuedev.core.models import Models
    from continuedev.libs.llm.together import TogetherLLM

    config = ContinueConfig(
        ...
        models=Models(
            default=TogetherLLM(
                api_key="<API_KEY>",
                model="togethercomputer/llama-2-13b-chat"
            )
        )
    )
    ```
    """

    api_key: str = Field(..., description="Together API key")

    model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct"
    base_url: str = Field(
        "https://api.together.xyz",
        description="The base URL for your Together API instance",
    )

    _client_session: aiohttp.ClientSession = None

    template_messages: Callable = llama2_template_messages

    prompt_templates = {
        "edit": simplified_edit_prompt,
    }

    async def start(self, **kwargs):
        await super().start(**kwargs)
        self._client_session = aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
            timeout=aiohttp.ClientTimeout(total=self.timeout),
        )

    async def stop(self):
        await self._client_session.close()

    async def _stream_complete(self, prompt, options):
        args = self.collect_args(options)

        async with self._client_session.post(
            f"{self.base_url}/inference",
            json={
                "prompt": prompt,
                "stream_tokens": True,
                **args,
            },
            headers={"Authorization": f"Bearer {self.api_key}"},
            proxy=self.proxy,
        ) as resp:
            async for line in resp.content.iter_chunks():
                if line[1]:
                    json_chunk = line[0].decode("utf-8")
                    # Skip keep-alive pings and the terminal [DONE] sentinel.
                    if json_chunk.startswith(": ping - ") or json_chunk.startswith(
                        "data: [DONE]"
                    ):
                        continue

                    # A single HTTP chunk may contain several SSE "data: ..." lines.
                    chunks = json_chunk.split("\n")
                    for chunk in chunks:
                        if chunk.strip() != "":
                            if chunk.startswith("data: "):
                                chunk = chunk[6:]
                            if chunk == "[DONE]":
                                break
                            try:
                                json_chunk = json.loads(chunk)
                            except Exception as e:
                                logger.warning(f"Invalid JSON chunk: {chunk}\n\n{e}")
                                continue
                            if "choices" in json_chunk:
                                yield json_chunk["choices"][0]["text"]

    async def _complete(self, prompt: str, options):
        args = self.collect_args(options)

        async with self._client_session.post(
            f"{self.base_url}/inference",
            json={"prompt": prompt, **args},
            headers={"Authorization": f"Bearer {self.api_key}"},
            proxy=self.proxy,
        ) as resp:
            text = await resp.text()
            j = json.loads(text)
            try:
                # A successful response carries the completion under output.choices;
                # anything else falls through to the error handling below.
                if "output" not in j or "choices" not in j["output"]:
                    raise Exception(text)
                return j["output"]["choices"][0]["text"]
            except Exception as e:
                if "error" in j:
                    if j["error"].startswith("invalid hexlify value"):
                        raise ContinueCustomException(
                            message=f"Invalid Together API key:\n\n{j['error']}",
                            title="Together API Error",
                        )
                    else:
                        raise ContinueCustomException(
                            message=j["error"], title="Together API Error"
                        )

                raise e
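

# ---------------------------------------------------------------------------
# Illustrative sketch only (not used by Continue): a minimal, standalone
# reproduction of the streaming request/response shape that _stream_complete
# above parses. It POSTs to the same /inference endpoint with
# "stream_tokens": True and reads the SSE-style "data: ..." lines. The model
# name, the "max_tokens" value, and the helper's name are assumptions made
# for the demo, not requirements of the wrapper class above.
async def _demo_stream_together(api_key: str, prompt: str) -> None:
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "https://api.together.xyz/inference",
            json={
                "model": "togethercomputer/RedPajama-INCITE-7B-Instruct",
                "prompt": prompt,
                "stream_tokens": True,
                "max_tokens": 64,  # assumption: keep the demo output short
            },
            headers={"Authorization": f"Bearer {api_key}"},
        ) as resp:
            async for data, _ in resp.content.iter_chunks():
                for line in data.decode("utf-8").split("\n"):
                    line = line.strip()
                    # Skip blanks and keep-alive pings.
                    if not line or line.startswith(": ping"):
                        continue
                    if line.startswith("data: "):
                        line = line[6:]
                    if line == "[DONE]":
                        return
                    try:
                        chunk = json.loads(line)
                    except json.JSONDecodeError:
                        continue
                    if "choices" in chunk:
                        print(chunk["choices"][0]["text"], end="", flush=True)


# Example invocation (requires a valid Together API key):
#   import asyncio
#   asyncio.run(_demo_stream_together("<API_KEY>", "Say hello"))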