| field | value | date |
|---|---|---|
| author | Nate Sesti <sestinj@gmail.com> | 2023-08-09 22:04:22 -0700 |
| committer | Nate Sesti <sestinj@gmail.com> | 2023-08-09 22:04:22 -0700 |
| commit | 8456b24318b13ea5d5dabec2328dd854f8a492b4 | |
| tree | 1eecd5118d7c31903fc15f327c9801aba5f3b32f /continuedev | |
| parent | e8ebff1e6b07dfaafff81ee7013bb019cbfe2075 | |
feat: :sparkles: support for Together.ai models
Diffstat (limited to 'continuedev')
| mode | path | lines changed |
|---|---|---|
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py | 5 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/replicate.py | 2 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/together.py | 118 |
3 files changed, 123 insertions, 2 deletions
```diff
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 2f131354..25a61e63 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -82,7 +82,10 @@ class GGML(LLM):
                         chunks = json_chunk.split("\n")
                         for chunk in chunks:
                             if chunk.strip() != "":
-                                yield json.loads(chunk[6:])["choices"][0]["delta"]
+                                yield {
+                                    "role": "assistant",
+                                    "content": json.loads(chunk[6:])["choices"][0]["delta"]
+                                }
                     except:
                         raise Exception(str(line[0]))
 
diff --git a/continuedev/src/continuedev/libs/llm/replicate.py b/continuedev/src/continuedev/libs/llm/replicate.py
index 235fd906..0dd359e7 100644
--- a/continuedev/src/continuedev/libs/llm/replicate.py
+++ b/continuedev/src/continuedev/libs/llm/replicate.py
@@ -25,7 +25,7 @@ class ReplicateLLM(LLM):
 
     @property
     def default_args(self):
-        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+        return {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024}
 
     def count_tokens(self, text: str):
         return count_tokens(self.name, text)
diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py
new file mode 100644
index 00000000..1cc0a711
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/together.py
@@ -0,0 +1,118 @@
+import json
+from typing import Any, Coroutine, Dict, Generator, List, Union
+
+import aiohttp
+from ...core.main import ChatMessage
+from ..llm import LLM
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
+
+
+class TogetherLLM(LLM):
+    # this is model-specific
+    api_key: str
+    model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct"
+    max_context_length: int = 2048
+    base_url: str = "https://api.together.xyz"
+    verify_ssl: bool = True
+
+    _client_session: aiohttp.ClientSession = None
+
+    async def start(self, **kwargs):
+        self._client_session = aiohttp.ClientSession(
+            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl))
+
+    async def stop(self):
+        await self._client_session.close()
+
+    @property
+    def name(self):
+        return self.model
+
+    @property
+    def context_length(self):
+        return self.max_context_length
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024}
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.name, text)
+
+    def convert_to_prompt(self, chat_messages: List[ChatMessage]) -> str:
+        system_message = None
+        if chat_messages[0]["role"] == "system":
+            system_message = chat_messages.pop(0)["content"]
+
+        prompt = "\n"
+        if system_message:
+            prompt += f"<human>: Hi!\n<bot>: {system_message}\n"
+        for message in chat_messages:
+            prompt += f'<{"human" if message["role"] == "user" else "bot"}>: {message["content"]}\n'
+        return prompt
+
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = self.default_args.copy()
+        args.update(kwargs)
+        args["stream_tokens"] = True
+
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+
+        async with self._client_session.post(f"{self.base_url}/inference", json={
+            "prompt": self.convert_to_prompt(messages),
+            **args
+        }, headers={
+            "Authorization": f"Bearer {self.api_key}"
+        }) as resp:
+            async for line in resp.content.iter_any():
+                if line:
+                    try:
+                        yield line.decode("utf-8")
+                    except:
+                        raise Exception(str(line))
+
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+        args["stream_tokens"] = True
+
+        async with self._client_session.post(f"{self.base_url}/inference", json={
+            "prompt": self.convert_to_prompt(messages),
+            **args
+        }, headers={
+            "Authorization": f"Bearer {self.api_key}"
+        }) as resp:
+            async for line in resp.content.iter_chunks():
+                if line[1]:
+                    try:
+                        json_chunk = line[0].decode("utf-8")
+                        if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
+                            continue
+                        chunks = json_chunk.split("\n")
+                        for chunk in chunks:
+                            if chunk.strip() != "":
+                                yield {
+                                    "role": "assistant",
+                                    "content": json.loads(chunk[6:])["choices"][0]["text"]
+                                }
+                    except:
+                        raise Exception(str(line[0]))
+
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+        args = {**self.default_args, **kwargs}
+
+        messages = compile_chat_messages(args["model"], with_history, self.context_length,
+                                         args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+        async with self._client_session.post(f"{self.base_url}/inference", json={
+            "prompt": self.convert_to_prompt(messages),
+            **args
+        }, headers={
+            "Authorization": f"Bearer {self.api_key}"
+        }) as resp:
+            try:
+                return await resp.text()
+            except:
+                raise Exception(await resp.text())
```
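The new `TogetherLLM.convert_to_prompt` flattens an OpenAI-style message list into the `<human>:`/`<bot>:` turn format that the default `togethercomputer/RedPajama-INCITE-7B-Instruct` model expects, and injects any system message as the bot's reply to a dummy greeting. Below is a minimal standalone sketch of that conversion; the sample messages are illustrative only, and unlike the committed version it copies the list instead of popping from the caller's.

```python
from typing import Dict, List


def convert_to_prompt(chat_messages: List[Dict[str, str]]) -> str:
    """Mirror of the conversion in together.py: optional system message first,
    then alternating <human>/<bot> turns, each on its own line."""
    messages = list(chat_messages)  # avoid mutating the caller's list
    system_message = None
    if messages and messages[0]["role"] == "system":
        system_message = messages.pop(0)["content"]

    prompt = "\n"
    if system_message:
        # The system prompt becomes the bot's answer to a dummy greeting.
        prompt += f"<human>: Hi!\n<bot>: {system_message}\n"
    for message in messages:
        speaker = "human" if message["role"] == "user" else "bot"
        prompt += f"<{speaker}>: {message['content']}\n"
    return prompt


if __name__ == "__main__":
    print(convert_to_prompt([
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Write a haiku about diffs."},
    ]))
    # Prints:
    #
    # <human>: Hi!
    # <bot>: You are a concise assistant.
    # <human>: Write a haiku about diffs.
```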
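`stream_chat` expects the `/inference` endpoint to stream event lines of the form `data: {json}`, interleaved with `: ping - ...` keep-alives and a terminating `data: [DONE]`, which is why it skips those markers and slices off the first six characters (`"data: "`) before calling `json.loads`. The sketch below isolates that parsing step; the example payload is made up, and the wire format is inferred from the parsing code in the diff rather than from Together's documentation.

```python
import json
from typing import Dict, Iterator


def parse_stream_chunk(raw: bytes) -> Iterator[Dict[str, str]]:
    """Parse one network chunk the way TogetherLLM.stream_chat does:
    skip pings and [DONE], strip the 6-character "data: " prefix from
    each non-empty line, and yield assistant-role deltas."""
    text = raw.decode("utf-8")
    if text.startswith(": ping - ") or text.startswith("data: [DONE]"):
        return
    for line in text.split("\n"):
        if line.strip() == "":
            continue
        payload = json.loads(line[6:])  # drop the leading "data: "
        yield {"role": "assistant", "content": payload["choices"][0]["text"]}


# Illustrative chunk only.
example = b'data: {"choices": [{"text": "Hello"}]}\n'
for delta in parse_stream_chunk(example):
    print(delta)  # {'role': 'assistant', 'content': 'Hello'}
```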
