 continuedev/src/continuedev/libs/llm/llamacpp.py     | 110
 continuedev/src/continuedev/libs/llm/prompts/chat.py |  56
 continuedev/src/continuedev/libs/llm/together.py     |  21
 extension/src/activation/environmentSetup.ts         |   7
 4 files changed, 115 insertions(+), 79 deletions(-)
diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py
index bdcf8612..9e424fde 100644
--- a/continuedev/src/continuedev/libs/llm/llamacpp.py
+++ b/continuedev/src/continuedev/libs/llm/llamacpp.py
@@ -1,5 +1,5 @@
+import asyncio
import json
-from textwrap import dedent
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
import aiohttp
@@ -12,50 +12,7 @@ from ..util.count_tokens import (
count_tokens,
format_chat_messages,
)
-
-
-def llama2_template_messages(msgs: ChatMessage) -> str:
- if len(msgs) == 0:
- return ""
-
- prompt = ""
- has_system = msgs[0]["role"] == "system"
- if has_system:
- system_message = dedent(
- f"""\
- <<SYS>>
- {msgs[0]["content"]}
- <</SYS>>
-
- """
- )
- if len(msgs) > 1:
- prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"
- else:
- prompt += f"[INST] {system_message} [/INST]"
- return
-
- for i in range(2 if has_system else 0, len(msgs)):
- if msgs[i]["role"] == "user":
- prompt += f"[INST] {msgs[i]['content']} [/INST]"
- else:
- prompt += msgs[i]["content"]
-
- return prompt
-
-
-def code_llama_template_messages(msgs: ChatMessage) -> str:
- return f"[INST] {msgs[-1]['content']} [/INST]"
-
-
-def code_llama_python_template_messages(msgs: ChatMessage) -> str:
- return dedent(
- f"""\
- [INST]
- You are an expert Python programmer and personal assistant, here is your task: {msgs[-1]['content']}
- Your answer should start with a [PYTHON] tag and end with a [/PYTHON] tag.
- [/INST]"""
- )
+from .prompts.chat import code_llama_template_messages
class LlamaCpp(LLM):
@@ -63,8 +20,10 @@ class LlamaCpp(LLM):
server_url: str = "http://localhost:8080"
verify_ssl: Optional[bool] = None
- template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages
- llama_cpp_args: Dict[str, Any] = {"stop": ["[INST]"]}
+ template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages
+ llama_cpp_args: Dict[str, Any] = {"stop": ["[INST]"], "grammar": "root ::= "}
+
+ use_command: Optional[str] = None
requires_write_log = True
write_log: Optional[Callable[[str], None]] = None
@@ -114,6 +73,23 @@ class LlamaCpp(LLM):
return args
+ async def stream_from_main(self, prompt: str):
+ cmd = self.use_command.split(" ") + ["-p", prompt]
+ process = await asyncio.create_subprocess_exec(
+ *cmd, stdout=asyncio.subprocess.PIPE
+ )
+
+ total = ""
+ async for line in process.stdout:
+ chunk = line.decode().strip()
+ if "llama_print_timings" in total + chunk:
+ process.terminate()
+ return
+ total += chunk
+ yield chunk
+
+ await process.wait()
+
async def stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
@@ -171,7 +147,7 @@ class LlamaCpp(LLM):
prompt = self.template_messages(messages)
headers = {"Content-Type": "application/json"}
- async def generator():
+ async def server_generator():
async with aiohttp.ClientSession(
connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
) as client_session:
@@ -189,6 +165,12 @@ class LlamaCpp(LLM):
"role": "assistant",
}
+ async def command_generator():
+ async for line in self.stream_from_main(prompt):
+ yield {"content": line, "role": "assistant"}
+
+ generator = command_generator if self.use_command else server_generator
+
# Because quite often the first attempt fails, and it works thereafter
self.write_log(f"Prompt: \n\n{prompt}")
completion = ""
@@ -205,15 +187,23 @@ class LlamaCpp(LLM):
args = {**self.default_args, **kwargs}
self.write_log(f"Prompt: \n\n{prompt}")
- async with aiohttp.ClientSession(
- connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
- ) as client_session:
- async with client_session.post(
- f"{self.server_url}/completion",
- json={"prompt": prompt, **self._transform_args(args)},
- headers={"Content-Type": "application/json"},
- ) as resp:
- json_resp = await resp.json()
- completion = json_resp["content"]
- self.write_log(f"Completion: \n\n{completion}")
- return completion
+
+ if self.use_command:
+ completion = ""
+ async for line in self.stream_from_main(prompt):
+ completion += line
+ self.write_log(f"Completion: \n\n{completion}")
+ return completion
+ else:
+ async with aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+ ) as client_session:
+ async with client_session.post(
+ f"{self.server_url}/completion",
+ json={"prompt": prompt, **self._transform_args(args)},
+ headers={"Content-Type": "application/json"},
+ ) as resp:
+ json_resp = await resp.json()
+ completion = json_resp["content"]
+ self.write_log(f"Completion: \n\n{completion}")
+ return completion
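
For reference, a minimal sketch of how the new use_command escape hatch might be exercised, assuming LlamaCpp can be instantiated directly with its fields as keyword arguments and that the import path matches the repo layout above; the binary and model paths are placeholders, and leaving use_command unset keeps the default server-based behavior:

# Hypothetical usage sketch (paths and flags are placeholders, not from this commit).
from continuedev.src.continuedev.libs.llm.llamacpp import LlamaCpp

# Default path: prompts are POSTed to the llama.cpp server at server_url.
server_llm = LlamaCpp(server_url="http://localhost:8080")

# New path: stream completions by spawning a local llama.cpp command instead;
# stream_from_main() appends "-p <prompt>" to this command and reads stdout
# until the "llama_print_timings" footer appears.
cli_llm = LlamaCpp(use_command="/path/to/llama.cpp/main -m /path/to/model.bin")
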
diff --git a/continuedev/src/continuedev/libs/llm/prompts/chat.py b/continuedev/src/continuedev/libs/llm/prompts/chat.py
new file mode 100644
index 00000000..110dfaae
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/prompts/chat.py
@@ -0,0 +1,56 @@
+from textwrap import dedent
+
+from ....core.main import ChatMessage
+
+
+def llama2_template_messages(msgs: ChatMessage) -> str:
+    if len(msgs) == 0:
+        return ""
+
+    prompt = ""
+    has_system = msgs[0]["role"] == "system"
+
+    if has_system and msgs[0]["content"].strip() == "":
+        has_system = False
+        msgs = msgs[1:]
+
+    if has_system:
+        system_message = dedent(
+            f"""\
+                <<SYS>>
+                {msgs[0]["content"]}
+                <</SYS>>
+
+                """
+        )
+        if len(msgs) > 1:
+            prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"
+        else:
+            prompt += f"[INST] {system_message} [/INST]"
+            return prompt
+
+    for i in range(2 if has_system else 0, len(msgs)):
+        if msgs[i]["role"] == "user":
+            prompt += f"[INST] {msgs[i]['content']} [/INST]"
+        else:
+            prompt += msgs[i]["content"]
+
+    return prompt
+
+
+def code_llama_template_messages(msgs: ChatMessage) -> str:
+    return f"[INST] {msgs[-1]['content']}\n[/INST]"
+
+
+def extra_space_template_messages(msgs: ChatMessage) -> str:
+    return f" {msgs[-1]['content']}"
+
+
+def code_llama_python_template_messages(msgs: ChatMessage) -> str:
+    return dedent(
+        f"""\
+            [INST]
+            You are an expert Python programmer and personal assistant, here is your task: {msgs[-1]['content']}
+            Your answer should start with a [PYTHON] tag and end with a [/PYTHON] tag.
+            [/INST]"""
+    )
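
To make the templates above concrete, a small illustration of the strings they produce; the import path is assumed from the repo layout shown in this diff, and the messages are plain dicts, which matches how the helpers index them:

# Illustration only: expected output of the new prompt templates.
from continuedev.src.continuedev.libs.llm.prompts.chat import (
    code_llama_template_messages,
    llama2_template_messages,
)

msgs = [
    {"role": "user", "content": "Write a haiku about autumn."},
    {"role": "assistant", "content": "Leaves drift on cold wind"},
    {"role": "user", "content": "Now one about winter."},
]

# User turns are wrapped in [INST] ... [/INST]; assistant turns pass through verbatim.
print(llama2_template_messages(msgs))
# [INST] Write a haiku about autumn. [/INST]Leaves drift on cold wind[INST] Now one about winter. [/INST]

# code_llama_template_messages keeps only the last message.
print(code_llama_template_messages(msgs))
# [INST] Now one about winter.
# [/INST]
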
diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py
index 4baf0b6c..ddae91a9 100644
--- a/continuedev/src/continuedev/libs/llm/together.py
+++ b/continuedev/src/continuedev/libs/llm/together.py
@@ -6,6 +6,7 @@ import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
+from .prompts.chat import llama2_template_messages
class TogetherLLM(LLM):
@@ -41,20 +42,6 @@ class TogetherLLM(LLM):
def count_tokens(self, text: str):
return count_tokens(self.name, text)
- def convert_to_prompt(self, chat_messages: List[ChatMessage]) -> str:
- system_message = None
- if chat_messages[0]["role"] == "system":
- system_message = chat_messages.pop(0)["content"]
-
- prompt = "\n"
- if system_message:
- prompt += f"<human>: Hi!\n<bot>: {system_message}\n"
- for message in chat_messages:
- prompt += f'<{"human" if message["role"] == "user" else "bot"}>: {message["content"]}\n'
-
- prompt += "<bot>:"
- return prompt
-
async def stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
@@ -75,7 +62,7 @@ class TogetherLLM(LLM):
async with self._client_session.post(
f"{self.base_url}/inference",
- json={"prompt": self.convert_to_prompt(messages), **args},
+ json={"prompt": llama2_template_messages(messages), **args},
headers={"Authorization": f"Bearer {self.api_key}"},
) as resp:
async for line in resp.content.iter_any():
@@ -102,7 +89,7 @@ class TogetherLLM(LLM):
async with self._client_session.post(
f"{self.base_url}/inference",
- json={"prompt": self.convert_to_prompt(messages), **args},
+ json={"prompt": llama2_template_messages(messages), **args},
headers={"Authorization": f"Bearer {self.api_key}"},
) as resp:
async for line in resp.content.iter_chunks():
@@ -141,7 +128,7 @@ class TogetherLLM(LLM):
)
async with self._client_session.post(
f"{self.base_url}/inference",
- json={"prompt": self.convert_to_prompt(messages), **args},
+ json={"prompt": llama2_template_messages(messages), **args},
headers={"Authorization": f"Bearer {self.api_key}"},
) as resp:
try:
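
The net effect on Together requests is the prompt string: the old <human>/<bot> transcript format is replaced by the shared Llama 2 template. A rough sketch of the body now POSTed to {base_url}/inference, with the model name and sampling args as placeholders rather than values taken from this commit:

# Sketch of the request body TogetherLLM now builds; model and args are placeholders.
from continuedev.src.continuedev.libs.llm.prompts.chat import llama2_template_messages

messages = [{"role": "user", "content": "Explain Python list comprehensions."}]

body = {
    # Previously built by convert_to_prompt() as "\n<human>: ...\n<bot>:"
    "prompt": llama2_template_messages(messages),  # "[INST] Explain Python list comprehensions. [/INST]"
    "model": "togethercomputer/llama-2-13b-chat",  # placeholder model name
    "max_tokens": 1024,                            # placeholder sampling arg
}
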
diff --git a/extension/src/activation/environmentSetup.ts b/extension/src/activation/environmentSetup.ts
index 2067f0fb..6b434756 100644
--- a/extension/src/activation/environmentSetup.ts
+++ b/extension/src/activation/environmentSetup.ts
@@ -330,8 +330,11 @@ export async function startContinuePythonServer(redownload: boolean = true) {
child.unref();
}
} catch (e: any) {
- console.log("Error starting server:", e);
- retry(e);
+ if (attempts < maxAttempts) {
+ retry(e);
+ } else {
+ throw e;
+ }
}
};