-rw-r--r-- | continuedev/src/continuedev/libs/llm/llamacpp.py | 110
-rw-r--r-- | continuedev/src/continuedev/libs/llm/prompts/chat.py | 56
-rw-r--r-- | continuedev/src/continuedev/libs/llm/together.py | 21
-rw-r--r-- | extension/src/activation/environmentSetup.ts | 7
4 files changed, 115 insertions, 79 deletions
diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py
index bdcf8612..9e424fde 100644
--- a/continuedev/src/continuedev/libs/llm/llamacpp.py
+++ b/continuedev/src/continuedev/libs/llm/llamacpp.py
@@ -1,5 +1,5 @@
+import asyncio
 import json
-from textwrap import dedent
 from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union

 import aiohttp
@@ -12,50 +12,7 @@ from ..util.count_tokens import (
     count_tokens,
     format_chat_messages,
 )
-
-
-def llama2_template_messages(msgs: ChatMessage) -> str:
-    if len(msgs) == 0:
-        return ""
-
-    prompt = ""
-    has_system = msgs[0]["role"] == "system"
-    if has_system:
-        system_message = dedent(
-            f"""\
-            <<SYS>>
-            {msgs[0]["content"]}
-            <</SYS>>
-
-            """
-        )
-        if len(msgs) > 1:
-            prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"
-        else:
-            prompt += f"[INST] {system_message} [/INST]"
-        return
-
-    for i in range(2 if has_system else 0, len(msgs)):
-        if msgs[i]["role"] == "user":
-            prompt += f"[INST] {msgs[i]['content']} [/INST]"
-        else:
-            prompt += msgs[i]["content"]
-
-    return prompt
-
-
-def code_llama_template_messages(msgs: ChatMessage) -> str:
-    return f"[INST] {msgs[-1]['content']} [/INST]"
-
-
-def code_llama_python_template_messages(msgs: ChatMessage) -> str:
-    return dedent(
-        f"""\
-        [INST]
-        You are an expert Python programmer and personal assistant, here is your task: {msgs[-1]['content']}
-        Your answer should start with a [PYTHON] tag and end with a [/PYTHON] tag.
-        [/INST]"""
-    )
+from .prompts.chat import code_llama_template_messages


 class LlamaCpp(LLM):
@@ -63,8 +20,10 @@ class LlamaCpp(LLM):
     server_url: str = "http://localhost:8080"
     verify_ssl: Optional[bool] = None

-    template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages
-    llama_cpp_args: Dict[str, Any] = {"stop": ["[INST]"]}
+    template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages
+    llama_cpp_args: Dict[str, Any] = {"stop": ["[INST]"], "grammar": "root ::= "}
+
+    use_command: Optional[str] = None

     requires_write_log = True
     write_log: Optional[Callable[[str], None]] = None
@@ -114,6 +73,23 @@ class LlamaCpp(LLM):

         return args

+    async def stream_from_main(self, prompt: str):
+        cmd = self.use_command.split(" ") + ["-p", prompt]
+        process = await asyncio.create_subprocess_exec(
+            *cmd, stdout=asyncio.subprocess.PIPE
+        )
+
+        total = ""
+        async for line in process.stdout:
+            chunk = line.decode().strip()
+            if "llama_print_timings" in total + chunk:
+                process.terminate()
+                return
+            total += chunk
+            yield chunk
+
+        await process.wait()
+
     async def stream_complete(
         self, prompt, with_history: List[ChatMessage] = None, **kwargs
     ) -> Generator[Union[Any, List, Dict], None, None]:
@@ -171,7 +147,7 @@ class LlamaCpp(LLM):
         prompt = self.template_messages(messages)
         headers = {"Content-Type": "application/json"}

-        async def generator():
+        async def server_generator():
             async with aiohttp.ClientSession(
                 connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
             ) as client_session:
@@ -189,6 +165,12 @@ class LlamaCpp(LLM):
                         "role": "assistant",
                     }

+        async def command_generator():
+            async for line in self.stream_from_main(prompt):
+                yield {"content": line, "role": "assistant"}
+
+        generator = command_generator if self.use_command else server_generator
+
         # Because quite often the first attempt fails, and it works thereafter
         self.write_log(f"Prompt: \n\n{prompt}")
         completion = ""
@@ -205,15 +187,23 @@ class LlamaCpp(LLM):
         args = {**self.default_args, **kwargs}

         self.write_log(f"Prompt: \n\n{prompt}")
-        async with aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
-        ) as client_session:
-            async with client_session.post(
-                f"{self.server_url}/completion",
-                json={"prompt": prompt, **self._transform_args(args)},
-                headers={"Content-Type": "application/json"},
-            ) as resp:
-                json_resp = await resp.json()
-                completion = json_resp["content"]
-                self.write_log(f"Completion: \n\n{completion}")
-                return completion
+
+        if self.use_command:
+            completion = ""
+            async for line in self.stream_from_main(prompt):
+                completion += line
+            self.write_log(f"Completion: \n\n{completion}")
+            return completion
+        else:
+            async with aiohttp.ClientSession(
+                connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+            ) as client_session:
+                async with client_session.post(
+                    f"{self.server_url}/completion",
+                    json={"prompt": prompt, **self._transform_args(args)},
+                    headers={"Content-Type": "application/json"},
+                ) as resp:
+                    json_resp = await resp.json()
+                    completion = json_resp["content"]
+                    self.write_log(f"Completion: \n\n{completion}")
+                    return completion
diff --git a/continuedev/src/continuedev/libs/llm/prompts/chat.py b/continuedev/src/continuedev/libs/llm/prompts/chat.py
new file mode 100644
index 00000000..110dfaae
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/prompts/chat.py
@@ -0,0 +1,56 @@
+from textwrap import dedent
+
+from ....core.main import ChatMessage
+
+
+def llama2_template_messages(msgs: ChatMessage) -> str:
+    if len(msgs) == 0:
+        return ""
+
+    prompt = ""
+    has_system = msgs[0]["role"] == "system"
+
+    if has_system and msgs[0]["content"].strip() == "":
+        has_system = False
+        msgs = msgs[1:]
+
+    if has_system:
+        system_message = dedent(
+            f"""\
+            <<SYS>>
+            {msgs[0]["content"]}
+            <</SYS>>
+
+            """
+        )
+        if len(msgs) > 1:
+            prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"
+        else:
+            prompt += f"[INST] {system_message} [/INST]"
+        return
+
+    for i in range(2 if has_system else 0, len(msgs)):
+        if msgs[i]["role"] == "user":
+            prompt += f"[INST] {msgs[i]['content']} [/INST]"
+        else:
+            prompt += msgs[i]["content"]
+
+    return prompt
+
+
+def code_llama_template_messages(msgs: ChatMessage) -> str:
+    return f"[INST] {msgs[-1]['content']}\n[/INST]"
+
+
+def extra_space_template_messages(msgs: ChatMessage) -> str:
+    return f" {msgs[-1]['content']}"
+
+
+def code_llama_python_template_messages(msgs: ChatMessage) -> str:
+    return dedent(
+        f"""\
+        [INST]
+        You are an expert Python programmer and personal assistant, here is your task: {msgs[-1]['content']}
+        Your answer should start with a [PYTHON] tag and end with a [/PYTHON] tag.
+        [/INST]"""
+    )
diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py
index 4baf0b6c..ddae91a9 100644
--- a/continuedev/src/continuedev/libs/llm/together.py
+++ b/continuedev/src/continuedev/libs/llm/together.py
@@ -6,6 +6,7 @@ import aiohttp
 from ...core.main import ChatMessage
 from ..llm import LLM
 from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
+from .prompts.chat import llama2_template_messages


 class TogetherLLM(LLM):
@@ -41,20 +42,6 @@ class TogetherLLM(LLM):
     def count_tokens(self, text: str):
         return count_tokens(self.name, text)

-    def convert_to_prompt(self, chat_messages: List[ChatMessage]) -> str:
-        system_message = None
-        if chat_messages[0]["role"] == "system":
-            system_message = chat_messages.pop(0)["content"]
-
-        prompt = "\n"
-        if system_message:
-            prompt += f"<human>: Hi!\n<bot>: {system_message}\n"
-        for message in chat_messages:
-            prompt += f'<{"human" if message["role"] == "user" else "bot"}>: {message["content"]}\n'
-
-        prompt += "<bot>:"
-        return prompt
-
     async def stream_complete(
         self, prompt, with_history: List[ChatMessage] = None, **kwargs
     ) -> Generator[Union[Any, List, Dict], None, None]:
@@ -75,7 +62,7 @@ class TogetherLLM(LLM):

         async with self._client_session.post(
             f"{self.base_url}/inference",
-            json={"prompt": self.convert_to_prompt(messages), **args},
+            json={"prompt": llama2_template_messages(messages), **args},
             headers={"Authorization": f"Bearer {self.api_key}"},
         ) as resp:
             async for line in resp.content.iter_any():
@@ -102,7 +89,7 @@ class TogetherLLM(LLM):

         async with self._client_session.post(
             f"{self.base_url}/inference",
-            json={"prompt": self.convert_to_prompt(messages), **args},
+            json={"prompt": llama2_template_messages(messages), **args},
             headers={"Authorization": f"Bearer {self.api_key}"},
         ) as resp:
             async for line in resp.content.iter_chunks():
@@ -141,7 +128,7 @@ class TogetherLLM(LLM):
         )
         async with self._client_session.post(
             f"{self.base_url}/inference",
-            json={"prompt": self.convert_to_prompt(messages), **args},
+            json={"prompt": llama2_template_messages(messages), **args},
             headers={"Authorization": f"Bearer {self.api_key}"},
         ) as resp:
             try:
diff --git a/extension/src/activation/environmentSetup.ts b/extension/src/activation/environmentSetup.ts
index 2067f0fb..6b434756 100644
--- a/extension/src/activation/environmentSetup.ts
+++ b/extension/src/activation/environmentSetup.ts
@@ -330,8 +330,11 @@ export async function startContinuePythonServer(redownload: boolean = true) {
         child.unref();
       }
     } catch (e: any) {
-      console.log("Error starting server:", e);
-      retry(e);
+      if (attempts < maxAttempts) {
+        retry(e);
+      } else {
+        throw e;
+      }
     }
   };