author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-08-11 13:47:11 -0700
---|---|---
committer | GitHub <noreply@github.com> | 2023-08-11 13:47:11 -0700
commit | 49e323bd9312e49e7149c8851d55241b4e24ef18 (patch) |
tree | 41b89dba1c67e479236bdaa86d6606f99381c822 /continuedev/src |
parent | 48ee1334dfd21dbe55cf66f39da1249619103e81 (diff) |
parent | 34f77a7344bc527e0c08dea5820a01748f2f8481 (diff) |
Merge pull request #369 from continuedev/ci-testing
Ci testing
Diffstat (limited to 'continuedev/src')
-rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py | 5
-rw-r--r-- | continuedev/src/continuedev/libs/llm/replicate.py | 2
-rw-r--r-- | continuedev/src/continuedev/libs/llm/together.py | 122
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/help.py | 3
4 files changed, 129 insertions, 3 deletions
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 2f131354..25a61e63 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -82,7 +82,10 @@ class GGML(LLM):
                         chunks = json_chunk.split("\n")
                         for chunk in chunks:
                             if chunk.strip() != "":
-                                yield json.loads(chunk[6:])["choices"][0]["delta"]
+                                yield {
+                                    "role": "assistant",
+                                    "content": json.loads(chunk[6:])["choices"][0]["delta"]
+                                }
                     except:
                         raise Exception(str(line[0]))
diff --git a/continuedev/src/continuedev/libs/llm/replicate.py b/continuedev/src/continuedev/libs/llm/replicate.py
index 235fd906..0dd359e7 100644
--- a/continuedev/src/continuedev/libs/llm/replicate.py
+++ b/continuedev/src/continuedev/libs/llm/replicate.py
@@ -25,7 +25,7 @@ class ReplicateLLM(LLM):
 
     @property
     def default_args(self):
-        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+        return {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024}
 
     def count_tokens(self, text: str):
         return count_tokens(self.name, text)
diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py
new file mode 100644
index 00000000..c3f171c9
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/together.py
@@ -0,0 +1,122 @@
+import json
+from typing import Any, Coroutine, Dict, Generator, List, Union
+
+import aiohttp
+from ...core.main import ChatMessage
+from ..llm import LLM
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
+
+
+class TogetherLLM(LLM):
+    # this is model-specific
+    api_key: str
+    model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct"
+    max_context_length: int = 2048
+    base_url: str = "https://api.together.xyz"
+    verify_ssl: bool = True
+
+    _client_session: aiohttp.ClientSession = None
+
+    async def start(self, **kwargs):
+        self._client_session = aiohttp.ClientSession(
+            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl))
+
+    async def stop(self):
+        await self._client_session.close()
+
+    @property
+    def name(self):
+        return self.model
+
+    @property
+    def context_length(self):
+        return self.max_context_length
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024}
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.name, text)
+
+    def convert_to_prompt(self, chat_messages: List[ChatMessage]) -> str:
+        system_message = None
+        if chat_messages[0]["role"] == "system":
+            system_message = chat_messages.pop(0)["content"]
+
+        prompt = "\n"
+        if system_message:
+            prompt += f"<human>: Hi!\n<bot>: {system_message}\n"
+        for message in chat_messages:
+            prompt += f'<{"human" if message["role"] == "user" else "bot"}>: {message["content"]}\n'
+
+        prompt += "<bot>:"
+        return prompt
+
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = self.default_args.copy()
+        args.update(kwargs)
+        args["stream_tokens"] = True
+
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+
+        async with self._client_session.post(f"{self.base_url}/inference", json={
+            "prompt": self.convert_to_prompt(messages),
+            **args
+        }, headers={
+            "Authorization": f"Bearer {self.api_key}"
+        }) as resp:
+            async for line in resp.content.iter_any():
+                if line:
+                    try:
+                        yield line.decode("utf-8")
+                    except:
+                        raise Exception(str(line))
+
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+        args["stream_tokens"] = True
+
+        async with self._client_session.post(f"{self.base_url}/inference", json={
+            "prompt": self.convert_to_prompt(messages),
+            **args
+        }, headers={
+            "Authorization": f"Bearer {self.api_key}"
+        }) as resp:
+            async for line in resp.content.iter_chunks():
+                if line[1]:
+                    try:
+                        json_chunk = line[0].decode("utf-8")
+                        if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
+                            continue
+                        chunks = json_chunk.split("\n")
+                        for chunk in chunks:
+                            if chunk.strip() != "":
+                                yield {
+                                    "role": "assistant",
+                                    "content": json.loads(chunk[6:])["choices"][0]["text"]
+                                }
+                    except:
+                        raise Exception(str(line[0]))
+
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+        args = {**self.default_args, **kwargs}
+
+        messages = compile_chat_messages(args["model"], with_history, self.context_length,
+                                         args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+        async with self._client_session.post(f"{self.base_url}/inference", json={
+            "prompt": self.convert_to_prompt(messages),
+            **args
+        }, headers={
+            "Authorization": f"Bearer {self.api_key}"
+        }) as resp:
+            try:
+                text = await resp.text()
+                j = json.loads(text)
+                return j["output"]["choices"][0]["text"]
+            except:
+                raise Exception(await resp.text())
diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py
index ec670999..82f885d6 100644
--- a/continuedev/src/continuedev/plugins/steps/help.py
+++ b/continuedev/src/continuedev/plugins/steps/help.py
@@ -39,6 +39,7 @@ class HelpStep(Step):
         if question.strip() == "":
             self.description = help
         else:
+            self.description = "The following output is generated by a language model, which may hallucinate. Type just '/help'to see a fixed answer. You can also learn more by reading [the docs](https://continue.dev/docs).\n\n"
             prompt = dedent(f"""
                 Information:
 
@@ -48,7 +49,7 @@ class HelpStep(Step):
                 Please us the information below to provide a succinct answer to the following question: {question}
 
-                Do not cite any slash commands other than those you've been told about, which are: /edit and /feedback.""")
+                Do not cite any slash commands other than those you've been told about, which are: /edit and /feedback. Never refer or link to any URL.""")
 
         self.chat_context.append(ChatMessage(
             role="user",
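A note on the new `TogetherLLM.convert_to_prompt`: it flattens an OpenAI-style message list into the `<human>:`/`<bot>:` turn format that the default RedPajama-INCITE model expects, folding any system message in as the bot's reply to a synthetic "Hi!". Below is a self-contained copy of that logic, run on a hypothetical two-message conversation to show the resulting prompt:

```python
from typing import Dict, List


def convert_to_prompt(chat_messages: List[Dict[str, str]]) -> str:
    # Same transformation as TogetherLLM.convert_to_prompt in this commit,
    # with an extra empty-list guard for standalone use.
    system_message = None
    if chat_messages and chat_messages[0]["role"] == "system":
        system_message = chat_messages.pop(0)["content"]

    prompt = "\n"
    if system_message:
        # The system message becomes the bot's answer to a synthetic greeting.
        prompt += f"<human>: Hi!\n<bot>: {system_message}\n"
    for message in chat_messages:
        prompt += f'<{"human" if message["role"] == "user" else "bot"}>: {message["content"]}\n'

    prompt += "<bot>:"
    return prompt


print(convert_to_prompt([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]))
# <human>: Hi!
# <bot>: You are a helpful assistant.
# <human>: What is 2 + 2?
# <bot>:
```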
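The ggml.py change and `TogetherLLM.stream_chat` share the same server-sent-events parsing pattern: skip ping and `[DONE]` lines, split the buffer on newlines, strip the six-character `data: ` prefix with `chunk[6:]`, then JSON-decode and re-wrap the text as an assistant-role delta. A minimal sketch of that parsing in isolation; the sample payload is made up, shaped like Together's streaming response:

```python
import json
from typing import Dict, Iterator


def parse_sse_payload(raw: str) -> Iterator[Dict[str, str]]:
    # Mirrors the chunk handling in TogetherLLM.stream_chat.
    if raw.startswith(": ping - ") or raw.startswith("data: [DONE]"):
        return
    for chunk in raw.split("\n"):
        if chunk.strip() != "":
            # chunk[6:] drops the leading 'data: ' (six characters).
            payload = json.loads(chunk[6:])
            yield {
                "role": "assistant",
                "content": payload["choices"][0]["text"],
            }


# Hypothetical buffer containing two SSE events:
raw = 'data: {"choices": [{"text": "Hello"}]}\ndata: {"choices": [{"text": " world"}]}'
print(list(parse_sse_payload(raw)))
# [{'role': 'assistant', 'content': 'Hello'}, {'role': 'assistant', 'content': ' world'}]
```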
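Finally, a hedged sketch of how the new class would be driven end to end. It assumes the `LLM` base class accepts fields as keyword arguments (it is a pydantic-style model in this codebase) and supplies a default `system_message`; the API key is a placeholder:

```python
import asyncio

# Import path as of this commit; later refactors may move it.
from continuedev.src.continuedev.libs.llm.together import TogetherLLM


async def main() -> None:
    llm = TogetherLLM(api_key="YOUR_API_KEY")  # placeholder, not a real key
    await llm.start()  # opens the shared aiohttp session
    try:
        # stream_chat yields {"role": "assistant", "content": ...} deltas.
        async for delta in llm.stream_chat([{"role": "user", "content": "Hello!"}]):
            print(delta["content"], end="", flush=True)
    finally:
        await llm.stop()  # closes the session


asyncio.run(main())
```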