-rw-r--r--  continuedev/src/continuedev/libs/llm/llamacpp.py          219
-rw-r--r--  continuedev/src/continuedev/plugins/steps/setup_model.py    1
-rw-r--r--  extension/react-app/src/components/ModelSelect.tsx          5
3 files changed, 225 insertions, 0 deletions
diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py
new file mode 100644
index 00000000..bdcf8612
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/llamacpp.py
@@ -0,0 +1,219 @@
+import json
+from textwrap import dedent
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
+
+import aiohttp
+
+from ...core.main import ChatMessage
+from ..llm import LLM
+from ..util.count_tokens import (
+    DEFAULT_ARGS,
+    compile_chat_messages,
+    count_tokens,
+    format_chat_messages,
+)
+
+
+def llama2_template_messages(msgs: List[ChatMessage]) -> str:
+    """Format chat messages into the Llama 2 prompt format ([INST] blocks, <<SYS>> system tags)."""
+    if len(msgs) == 0:
+        return ""
+
+    prompt = ""
+    has_system = msgs[0]["role"] == "system"
+    if has_system:
+        system_message = dedent(
+            f"""\
+            <<SYS>>
+            {msgs[0]["content"]}
+            <</SYS>>
+
+            """
+        )
+        if len(msgs) > 1:
+            prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"
+        else:
+            prompt += f"[INST] {system_message} [/INST]"
+            return prompt
+
+    for i in range(2 if has_system else 0, len(msgs)):
+        if msgs[i]["role"] == "user":
+            prompt += f"[INST] {msgs[i]['content']} [/INST]"
+        else:
+            prompt += msgs[i]["content"]
+
+    return prompt
+
+
+def code_llama_template_messages(msgs: List[ChatMessage]) -> str:
+    return f"[INST] {msgs[-1]['content']} [/INST]"
+
+
+def code_llama_python_template_messages(msgs: List[ChatMessage]) -> str:
+    return dedent(
+        f"""\
+        [INST]
+        You are an expert Python programmer and personal assistant, here is your task: {msgs[-1]['content']}
+        Your answer should start with a [PYTHON] tag and end with a [/PYTHON] tag.
+        [/INST]"""
+    )
+
+
+class LlamaCpp(LLM):
+    max_context_length: int = 2048
+    server_url: str = "http://localhost:8080"
+    verify_ssl: Optional[bool] = None
+
+    template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages
+    llama_cpp_args: Dict[str, Any] = {"stop": ["[INST]"]}
+
+    requires_write_log = True
+    write_log: Optional[Callable[[str], None]] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def dict(self, **kwargs):
+        d = super().dict(**kwargs)
+        d.pop("template_messages")
+        return d
+
+    async def start(self, write_log: Callable[[str], None], **kwargs):
+        self.write_log = write_log
+
+    async def stop(self):
+        # HTTP sessions are created per request below, so there is nothing to close here.
+        pass
+
+    @property
+    def name(self):
+        return "llamacpp"
+
+    @property
+    def context_length(self):
+        return self.max_context_length
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.name, text)
+
+    def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
+        # The llama.cpp server calls the token limit n_predict; frequency and
+        # presence penalties are not forwarded.
+        args = args.copy()
+        if "max_tokens" in args:
+            args["n_predict"] = args["max_tokens"]
+            del args["max_tokens"]
+        if "frequency_penalty" in args:
+            del args["frequency_penalty"]
+        if "presence_penalty" in args:
+            del args["presence_penalty"]
+
+        for k, v in self.llama_cpp_args.items():
+            if k not in args:
+                args[k] = v
+
+        return args
+
+    async def stream_complete(
+        self, prompt, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
+        args["stream"] = True
+
+        messages = compile_chat_messages(
+            self.name,
+            with_history,
+            self.context_length,
+            args["max_tokens"],
+            prompt,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
+
+        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+        completion = ""
+        async with aiohttp.ClientSession(
+            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+        ) as client_session:
+            async with client_session.post(
+                f"{self.server_url}/completion",
+                json={
+                    "prompt": self.template_messages(messages),
+                    **self._transform_args(args),
+                },
+                headers={"Content-Type": "application/json"},
+            ) as resp:
+                async for line in resp.content.iter_any():
+                    if line:
+                        chunk = line.decode("utf-8")
+                        yield chunk
+                        completion += chunk
+
+        self.write_log(f"Completion: \n\n{completion}")
+
+    async def stream_chat(
+        self, messages: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name,
+            messages,
+            self.context_length,
+            args["max_tokens"],
+            None,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
+        args["stream"] = True
+
+        prompt = self.template_messages(messages)
+        headers = {"Content-Type": "application/json"}
+
+        async def generator():
+            async with aiohttp.ClientSession(
+                connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+            ) as client_session:
+                async with client_session.post(
+                    f"{self.server_url}/completion",
+                    json={"prompt": prompt, **self._transform_args(args)},
+                    headers=headers,
+                ) as resp:
+                    async for line in resp.content:
+                        content = line.decode("utf-8")
+                        if content.strip() == "":
+                            continue
+                        # Each streamed line is a server-sent event of the form
+                        # "data: {...}"; strip the "data: " prefix before parsing.
+                        yield {
+                            "content": json.loads(content[6:])["content"],
+                            "role": "assistant",
+                        }
+
+        self.write_log(f"Prompt: \n\n{prompt}")
+        completion = ""
+        async for chunk in generator():
+            yield chunk
+            if "content" in chunk:
+                completion += chunk["content"]
+
+        self.write_log(f"Completion: \n\n{completion}")
+
+    async def complete(
+        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Coroutine[Any, Any, str]:
+        args = {**self.default_args, **kwargs}
+
+        self.write_log(f"Prompt: \n\n{prompt}")
+        async with aiohttp.ClientSession(
+            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+        ) as client_session:
+            async with client_session.post(
+                f"{self.server_url}/completion",
+                json={"prompt": prompt, **self._transform_args(args)},
+                headers={"Content-Type": "application/json"},
+            ) as resp:
+                json_resp = await resp.json()
+                completion = json_resp["content"]
+                self.write_log(f"Completion: \n\n{completion}")
+                return completion
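
As a quick illustration (not part of the diff), here is roughly what `llama2_template_messages` produces for a small system + user exchange; the import path and the example messages are assumptions for this sketch:

```python
# Illustrative sketch: messages are plain dicts, matching how the template
# functions above index them ("role" / "content").
from continuedev.src.continuedev.libs.llm.llamacpp import llama2_template_messages

msgs = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a haiku about the sea."},
]

print(llama2_template_messages(msgs))
# [INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# Write a haiku about the sea. [/INST]
```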
diff --git a/continuedev/src/continuedev/plugins/steps/setup_model.py b/continuedev/src/continuedev/plugins/steps/setup_model.py
index 1c50c714..2e1fdc14 100644
--- a/continuedev/src/continuedev/plugins/steps/setup_model.py
+++ b/continuedev/src/continuedev/plugins/steps/setup_model.py
@@ -10,6 +10,7 @@ MODEL_CLASS_TO_MESSAGE = {
"Ollama": "To get started with Ollama, download the Mac app from [ollama.ai](https://ollama.ai/). Once it is downloaded, be sure to pull at least one model and use its name in the model field in config.py (e.g. `model='codellama'`).",
"GGML": "GGML models can be run locally using the `llama-cpp-python` library. To learn how to set up a local llama-cpp-python server, read [here](https://github.com/continuedev/ggml-server-example). Once it is started on port 8000, you're all set!",
"TogetherLLM": "To get started using models from Together, first obtain your Together API key from [here](https://together.ai). Paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then, on their models page, press 'start' on the model of your choice and make sure the `model=` parameter in the config file for the `TogetherLLM` class reflects the name of this model. Finally, reload the VS Code window for changes to take effect.",
+ "LlamaCpp": "To get started with this model, clone the [`llama.cpp` repo](https://github.com/ggerganov/llama.cpp) and follow the instructions to set up the server [here](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#build). Any of the parameters described in the README can be passed to the `llama_cpp_args` field in the `LlamaCpp` class in `config.py`.",
}
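
For context (not part of the diff), a minimal sketch of how the new class might be wired into `config.py`; the exact import paths and the `ContinueConfig`/`Models` structure are assumptions based on the setup message above:

```python
# Hypothetical ~/.continue/config.py snippet; adjust import paths to your setup.
from continuedev.src.continuedev.core.config import ContinueConfig
from continuedev.src.continuedev.core.models import Models
from continuedev.src.continuedev.libs.llm.llamacpp import LlamaCpp

config = ContinueConfig(
    models=Models(
        default=LlamaCpp(
            server_url="http://localhost:8080",
            max_context_length=2048,
            # Any parameter from the llama.cpp server README can be passed through here.
            llama_cpp_args={"stop": ["[INST]"], "temperature": 0.8},
        )
    )
)
```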
diff --git a/extension/react-app/src/components/ModelSelect.tsx b/extension/react-app/src/components/ModelSelect.tsx
index 49788143..9699847c 100644
--- a/extension/react-app/src/components/ModelSelect.tsx
+++ b/extension/react-app/src/components/ModelSelect.tsx
@@ -64,6 +64,11 @@ const MODEL_INFO: { title: string; class: string; args: any }[] = [
api_key: "<TOGETHER_API_KEY>",
},
},
+ {
+ title: "llama.cpp",
+ class: "LlamaCpp",
+ args: {},
+ },
];
const Select = styled.select`