Diffstat (limited to 'continuedev/src/continuedev/libs/llm/llamacpp.py')
-rw-r--r-- | continuedev/src/continuedev/libs/llm/llamacpp.py | 110
1 file changed, 50 insertions, 60 deletions
diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py
index bdcf8612..9e424fde 100644
--- a/continuedev/src/continuedev/libs/llm/llamacpp.py
+++ b/continuedev/src/continuedev/libs/llm/llamacpp.py
@@ -1,5 +1,5 @@
+import asyncio
 import json
-from textwrap import dedent
 from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
 
 import aiohttp
@@ -12,50 +12,7 @@ from ..util.count_tokens import (
     count_tokens,
     format_chat_messages,
 )
-
-
-def llama2_template_messages(msgs: ChatMessage) -> str:
-    if len(msgs) == 0:
-        return ""
-
-    prompt = ""
-    has_system = msgs[0]["role"] == "system"
-    if has_system:
-        system_message = dedent(
-            f"""\
-            <<SYS>>
-            {msgs[0]["content"]}
-            <</SYS>>
-
-            """
-        )
-        if len(msgs) > 1:
-            prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"
-        else:
-            prompt += f"[INST] {system_message} [/INST]"
-            return
-
-    for i in range(2 if has_system else 0, len(msgs)):
-        if msgs[i]["role"] == "user":
-            prompt += f"[INST] {msgs[i]['content']} [/INST]"
-        else:
-            prompt += msgs[i]["content"]
-
-    return prompt
-
-
-def code_llama_template_messages(msgs: ChatMessage) -> str:
-    return f"[INST] {msgs[-1]['content']} [/INST]"
-
-
-def code_llama_python_template_messages(msgs: ChatMessage) -> str:
-    return dedent(
-        f"""\
-        [INST]
-        You are an expert Python programmer and personal assistant, here is your task: {msgs[-1]['content']}
-        Your answer should start with a [PYTHON] tag and end with a [/PYTHON] tag.
-        [/INST]"""
-    )
+from .prompts.chat import code_llama_template_messages
 
 
 class LlamaCpp(LLM):
@@ -63,8 +20,10 @@ class LlamaCpp(LLM):
     server_url: str = "http://localhost:8080"
     verify_ssl: Optional[bool] = None
 
-    template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages
-    llama_cpp_args: Dict[str, Any] = {"stop": ["[INST]"]}
+    template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages
+    llama_cpp_args: Dict[str, Any] = {"stop": ["[INST]"], "grammar": "root ::= "}
+
+    use_command: Optional[str] = None
 
     requires_write_log = True
     write_log: Optional[Callable[[str], None]] = None
@@ -114,6 +73,23 @@
 
         return args
 
+    async def stream_from_main(self, prompt: str):
+        cmd = self.use_command.split(" ") + ["-p", prompt]
+        process = await asyncio.create_subprocess_exec(
+            *cmd, stdout=asyncio.subprocess.PIPE
+        )
+
+        total = ""
+        async for line in process.stdout:
+            chunk = line.decode().strip()
+            if "llama_print_timings" in total + chunk:
+                process.terminate()
+                return
+            total += chunk
+            yield chunk
+
+        await process.wait()
+
     async def stream_complete(
         self, prompt, with_history: List[ChatMessage] = None, **kwargs
     ) -> Generator[Union[Any, List, Dict], None, None]:
@@ -171,7 +147,7 @@
         prompt = self.template_messages(messages)
         headers = {"Content-Type": "application/json"}
 
-        async def generator():
+        async def server_generator():
             async with aiohttp.ClientSession(
                 connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
             ) as client_session:
@@ -189,6 +165,12 @@
                             "role": "assistant",
                         }
 
+        async def command_generator():
+            async for line in self.stream_from_main(prompt):
+                yield {"content": line, "role": "assistant"}
+
+        generator = command_generator if self.use_command else server_generator
+
         # Because quite often the first attempt fails, and it works thereafter
         self.write_log(f"Prompt: \n\n{prompt}")
         completion = ""
@@ -205,15 +187,23 @@
         args = {**self.default_args, **kwargs}
 
         self.write_log(f"Prompt: \n\n{prompt}")
-        async with aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
-        ) as client_session:
-            async with client_session.post(
-                f"{self.server_url}/completion",
-                json={"prompt": prompt, **self._transform_args(args)},
-                headers={"Content-Type": "application/json"},
-            ) as resp:
-                json_resp = await resp.json()
-                completion = json_resp["content"]
-                self.write_log(f"Completion: \n\n{completion}")
-                return completion
+
+        if self.use_command:
+            completion = ""
+            async for line in self.stream_from_main(prompt):
+                completion += line
+            self.write_log(f"Completion: \n\n{completion}")
+            return completion
+        else:
+            async with aiohttp.ClientSession(
+                connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+            ) as client_session:
+                async with client_session.post(
+                    f"{self.server_url}/completion",
+                    json={"prompt": prompt, **self._transform_args(args)},
+                    headers={"Content-Type": "application/json"},
+                ) as resp:
+                    json_resp = await resp.json()
+                    completion = json_resp["content"]
+                    self.write_log(f"Completion: \n\n{completion}")
+                    return completion