diff options
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/llamacpp.py | 110 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/prompts/chat.py | 56 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/together.py | 21 | 
3 files changed, 110 insertions, 77 deletions
| diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py index bdcf8612..9e424fde 100644 --- a/continuedev/src/continuedev/libs/llm/llamacpp.py +++ b/continuedev/src/continuedev/libs/llm/llamacpp.py @@ -1,5 +1,5 @@ +import asyncio  import json -from textwrap import dedent  from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union  import aiohttp @@ -12,50 +12,7 @@ from ..util.count_tokens import (      count_tokens,      format_chat_messages,  ) - - -def llama2_template_messages(msgs: ChatMessage) -> str: -    if len(msgs) == 0: -        return "" - -    prompt = "" -    has_system = msgs[0]["role"] == "system" -    if has_system: -        system_message = dedent( -            f"""\ -                <<SYS>> -                {msgs[0]["content"]} -                <</SYS>> -                 -                """ -        ) -        if len(msgs) > 1: -            prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]" -        else: -            prompt += f"[INST] {system_message} [/INST]" -            return - -    for i in range(2 if has_system else 0, len(msgs)): -        if msgs[i]["role"] == "user": -            prompt += f"[INST] {msgs[i]['content']} [/INST]" -        else: -            prompt += msgs[i]["content"] - -    return prompt - - -def code_llama_template_messages(msgs: ChatMessage) -> str: -    return f"[INST] {msgs[-1]['content']} [/INST]" - - -def code_llama_python_template_messages(msgs: ChatMessage) -> str: -    return dedent( -        f"""\ -        [INST] -        You are an expert Python programmer and personal assistant, here is your task: {msgs[-1]['content']} -        Your answer should start with a [PYTHON] tag and end with a [/PYTHON] tag. -        [/INST]""" -    ) +from .prompts.chat import code_llama_template_messages  class LlamaCpp(LLM): @@ -63,8 +20,10 @@ class LlamaCpp(LLM):      server_url: str = "http://localhost:8080"      verify_ssl: Optional[bool] = None -    template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages -    llama_cpp_args: Dict[str, Any] = {"stop": ["[INST]"]} +    template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages +    llama_cpp_args: Dict[str, Any] = {"stop": ["[INST]"], "grammar": "root ::= "} + +    use_command: Optional[str] = None      requires_write_log = True      write_log: Optional[Callable[[str], None]] = None @@ -114,6 +73,23 @@ class LlamaCpp(LLM):          return args +    async def stream_from_main(self, prompt: str): +        cmd = self.use_command.split(" ") + ["-p", prompt] +        process = await asyncio.create_subprocess_exec( +            *cmd, stdout=asyncio.subprocess.PIPE +        ) + +        total = "" +        async for line in process.stdout: +            chunk = line.decode().strip() +            if "llama_print_timings" in total + chunk: +                process.terminate() +                return +            total += chunk +            yield chunk + +        await process.wait() +      async def stream_complete(          self, prompt, with_history: List[ChatMessage] = None, **kwargs      ) -> Generator[Union[Any, List, Dict], None, None]: @@ -171,7 +147,7 @@ class LlamaCpp(LLM):          prompt = self.template_messages(messages)          headers = {"Content-Type": "application/json"} -        async def generator(): +        async def server_generator():              async with aiohttp.ClientSession(                  connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)              ) as client_session: @@ -189,6 +165,12 @@ class LlamaCpp(LLM):                              "role": "assistant",                          } +        async def command_generator(): +            async for line in self.stream_from_main(prompt): +                yield {"content": line, "role": "assistant"} + +        generator = command_generator if self.use_command else server_generator +          # Because quite often the first attempt fails, and it works thereafter          self.write_log(f"Prompt: \n\n{prompt}")          completion = "" @@ -205,15 +187,23 @@ class LlamaCpp(LLM):          args = {**self.default_args, **kwargs}          self.write_log(f"Prompt: \n\n{prompt}") -        async with aiohttp.ClientSession( -            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) -        ) as client_session: -            async with client_session.post( -                f"{self.server_url}/completion", -                json={"prompt": prompt, **self._transform_args(args)}, -                headers={"Content-Type": "application/json"}, -            ) as resp: -                json_resp = await resp.json() -                completion = json_resp["content"] -                self.write_log(f"Completion: \n\n{completion}") -                return completion + +        if self.use_command: +            completion = "" +            async for line in self.stream_from_main(prompt): +                completion += line +            self.write_log(f"Completion: \n\n{completion}") +            return completion +        else: +            async with aiohttp.ClientSession( +                connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) +            ) as client_session: +                async with client_session.post( +                    f"{self.server_url}/completion", +                    json={"prompt": prompt, **self._transform_args(args)}, +                    headers={"Content-Type": "application/json"}, +                ) as resp: +                    json_resp = await resp.json() +                    completion = json_resp["content"] +                    self.write_log(f"Completion: \n\n{completion}") +                    return completion diff --git a/continuedev/src/continuedev/libs/llm/prompts/chat.py b/continuedev/src/continuedev/libs/llm/prompts/chat.py new file mode 100644 index 00000000..110dfaae --- /dev/null +++ b/continuedev/src/continuedev/libs/llm/prompts/chat.py @@ -0,0 +1,56 @@ +from textwrap import dedent + +from ....core.main import ChatMessage + + +def llama2_template_messages(msgs: ChatMessage) -> str: +    if len(msgs) == 0: +        return "" + +    prompt = "" +    has_system = msgs[0]["role"] == "system" + +    if has_system and msgs[0]["content"].strip() == "": +        has_system = False +        msgs = msgs[1:] + +    if has_system: +        system_message = dedent( +            f"""\ +                <<SYS>> +                {msgs[0]["content"]} +                <</SYS>> +                 +                """ +        ) +        if len(msgs) > 1: +            prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]" +        else: +            prompt += f"[INST] {system_message} [/INST]" +            return + +    for i in range(2 if has_system else 0, len(msgs)): +        if msgs[i]["role"] == "user": +            prompt += f"[INST] {msgs[i]['content']} [/INST]" +        else: +            prompt += msgs[i]["content"] + +    return prompt + + +def code_llama_template_messages(msgs: ChatMessage) -> str: +    return f"[INST] {msgs[-1]['content']}\n[/INST]" + + +def extra_space_template_messages(msgs: ChatMessage) -> str: +    return f" {msgs[-1]['content']}" + + +def code_llama_python_template_messages(msgs: ChatMessage) -> str: +    return dedent( +        f"""\ +        [INST] +        You are an expert Python programmer and personal assistant, here is your task: {msgs[-1]['content']} +        Your answer should start with a [PYTHON] tag and end with a [/PYTHON] tag. +        [/INST]""" +    ) diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py index 4baf0b6c..ddae91a9 100644 --- a/continuedev/src/continuedev/libs/llm/together.py +++ b/continuedev/src/continuedev/libs/llm/together.py @@ -6,6 +6,7 @@ import aiohttp  from ...core.main import ChatMessage  from ..llm import LLM  from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens +from .prompts.chat import llama2_template_messages  class TogetherLLM(LLM): @@ -41,20 +42,6 @@ class TogetherLLM(LLM):      def count_tokens(self, text: str):          return count_tokens(self.name, text) -    def convert_to_prompt(self, chat_messages: List[ChatMessage]) -> str: -        system_message = None -        if chat_messages[0]["role"] == "system": -            system_message = chat_messages.pop(0)["content"] - -        prompt = "\n" -        if system_message: -            prompt += f"<human>: Hi!\n<bot>: {system_message}\n" -        for message in chat_messages: -            prompt += f'<{"human" if message["role"] == "user" else "bot"}>: {message["content"]}\n' - -        prompt += "<bot>:" -        return prompt -      async def stream_complete(          self, prompt, with_history: List[ChatMessage] = None, **kwargs      ) -> Generator[Union[Any, List, Dict], None, None]: @@ -75,7 +62,7 @@ class TogetherLLM(LLM):          async with self._client_session.post(              f"{self.base_url}/inference", -            json={"prompt": self.convert_to_prompt(messages), **args}, +            json={"prompt": llama2_template_messages(messages), **args},              headers={"Authorization": f"Bearer {self.api_key}"},          ) as resp:              async for line in resp.content.iter_any(): @@ -102,7 +89,7 @@ class TogetherLLM(LLM):          async with self._client_session.post(              f"{self.base_url}/inference", -            json={"prompt": self.convert_to_prompt(messages), **args}, +            json={"prompt": llama2_template_messages(messages), **args},              headers={"Authorization": f"Bearer {self.api_key}"},          ) as resp:              async for line in resp.content.iter_chunks(): @@ -141,7 +128,7 @@ class TogetherLLM(LLM):          )          async with self._client_session.post(              f"{self.base_url}/inference", -            json={"prompt": self.convert_to_prompt(messages), **args}, +            json={"prompt": llama2_template_messages(messages), **args},              headers={"Authorization": f"Bearer {self.api_key}"},          ) as resp:              try: | 
