Diffstat (limited to 'continuedev/src')

 -rw-r--r--   continuedev/src/continuedev/libs/llm/llamacpp.py            219
 -rw-r--r--   continuedev/src/continuedev/plugins/steps/setup_model.py      1

2 files changed, 220 insertions, 0 deletions
diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py
new file mode 100644
index 00000000..bdcf8612
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/llamacpp.py
@@ -0,0 +1,219 @@
+import json
+from textwrap import dedent
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
+
+import aiohttp
+
+from ...core.main import ChatMessage
+from ..llm import LLM
+from ..util.count_tokens import (
+    DEFAULT_ARGS,
+    compile_chat_messages,
+    count_tokens,
+    format_chat_messages,
+)
+
+
+def llama2_template_messages(msgs: List[ChatMessage]) -> str:
+    if len(msgs) == 0:
+        return ""
+
+    prompt = ""
+    has_system = msgs[0]["role"] == "system"
+    if has_system:
+        system_message = dedent(
+            f"""\
+                <<SYS>>
+                {msgs[0]["content"]}
+                <</SYS>>
+
+                """
+        )
+        if len(msgs) > 1:
+            prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"
+        else:
+            prompt += f"[INST] {system_message} [/INST]"
+            return prompt
+
+    for i in range(2 if has_system else 0, len(msgs)):
+        if msgs[i]["role"] == "user":
+            prompt += f"[INST] {msgs[i]['content']} [/INST]"
+        else:
+            prompt += msgs[i]["content"]
+
+    return prompt
+
+
+def code_llama_template_messages(msgs: List[ChatMessage]) -> str:
+    return f"[INST] {msgs[-1]['content']} [/INST]"
+
+
+def code_llama_python_template_messages(msgs: List[ChatMessage]) -> str:
+    return dedent(
+        f"""\
+        [INST]
+        You are an expert Python programmer and personal assistant, here is your task: {msgs[-1]['content']}
+        Your answer should start with a [PYTHON] tag and end with a [/PYTHON] tag.
+        [/INST]"""
+    )
+
+
+class LlamaCpp(LLM):
+    max_context_length: int = 2048
+    server_url: str = "http://localhost:8080"
+    verify_ssl: Optional[bool] = None
+
+    template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages
+    llama_cpp_args: Dict[str, Any] = {"stop": ["[INST]"]}
+
+    requires_write_log = True
+    write_log: Optional[Callable[[str], None]] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def dict(self, **kwargs):
+        d = super().dict(**kwargs)
+        d.pop("template_messages")
+        return d
+
+    async def start(self, write_log: Callable[[str], None], **kwargs):
+        self.write_log = write_log
+
+    async def stop(self):
+        await self._client_session.close()
+
+    @property
+    def name(self):
+        return "llamacpp"
+
+    @property
+    def context_length(self):
+        return self.max_context_length
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.name, text)
+
+    def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
+        args = args.copy()
+        if "max_tokens" in args:
+            args["n_predict"] = args["max_tokens"]
+            del args["max_tokens"]
+        if "frequency_penalty" in args:
+            del args["frequency_penalty"]
+        if "presence_penalty" in args:
+            del args["presence_penalty"]
+
+        for k, v in self.llama_cpp_args.items():
+            if k not in args:
+                args[k] = v
+
+        return args
+
+    async def stream_complete(
+        self, prompt, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
+        args = self.default_args.copy()
+        args.update(kwargs)
+        args["stream"] = True
+
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name,
+            with_history,
+            self.context_length,
+            args["max_tokens"],
+            prompt,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
+
+        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+        completion = ""
+        async with aiohttp.ClientSession(
+            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+        ) as client_session:
+            async with client_session.post(
+                f"{self.server_url}/completion",
+                json={
+                    "prompt": self.convert_to_chat(messages),
+                    **self._transform_args(args),
+                },
+                headers={"Content-Type": "application/json"},
+            ) as resp:
+                async for line in resp.content.iter_any():
+                    if line:
+                        chunk = line.decode("utf-8")
+                        yield chunk
+                        completion += chunk
+
+        self.write_log(f"Completion: \n\n{completion}")
+
+    async def stream_chat(
+        self, messages: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name,
+            messages,
+            self.context_length,
+            args["max_tokens"],
+            None,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
+        args["stream"] = True
+
+        prompt = self.template_messages(messages)
+        headers = {"Content-Type": "application/json"}
+
+        async def generator():
+            async with aiohttp.ClientSession(
+                connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+            ) as client_session:
+                async with client_session.post(
+                    f"{self.server_url}/completion",
+                    json={"prompt": prompt, **self._transform_args(args)},
+                    headers=headers,
+                ) as resp:
+                    async for line in resp.content:
+                        content = line.decode("utf-8")
+                        if content.strip() == "":
+                            continue
+                        yield {
+                            "content": json.loads(content[6:])["content"],
+                            "role": "assistant",
+                        }
+
+        # Because quite often the first attempt fails, and it works thereafter
+        self.write_log(f"Prompt: \n\n{prompt}")
+        completion = ""
+        async for chunk in generator():
+            yield chunk
+            if "content" in chunk:
+                completion += chunk["content"]
+
+        self.write_log(f"Completion: \n\n{completion}")
+
+    async def complete(
+        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Coroutine[Any, Any, str]:
+        args = {**self.default_args, **kwargs}
+
+        self.write_log(f"Prompt: \n\n{prompt}")
+        async with aiohttp.ClientSession(
+            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+        ) as client_session:
+            async with client_session.post(
+                f"{self.server_url}/completion",
+                json={"prompt": prompt, **self._transform_args(args)},
+                headers={"Content-Type": "application/json"},
+            ) as resp:
+                json_resp = await resp.json()
+                completion = json_resp["content"]
+                self.write_log(f"Completion: \n\n{completion}")
+                return completion
diff --git a/continuedev/src/continuedev/plugins/steps/setup_model.py b/continuedev/src/continuedev/plugins/steps/setup_model.py
index 1c50c714..2e1fdc14 100644
--- a/continuedev/src/continuedev/plugins/steps/setup_model.py
+++ b/continuedev/src/continuedev/plugins/steps/setup_model.py
@@ -10,6 +10,7 @@ MODEL_CLASS_TO_MESSAGE = {
     "Ollama": "To get started with Ollama, download the Mac app from [ollama.ai](https://ollama.ai/). Once it is downloaded, be sure to pull at least one model and use its name in the model field in config.py (e.g. `model='codellama'`).",
     "GGML": "GGML models can be run locally using the `llama-cpp-python` library. To learn how to set up a local llama-cpp-python server, read [here](https://github.com/continuedev/ggml-server-example). Once it is started on port 8000, you're all set!",
     "TogetherLLM": "To get started using models from Together, first obtain your Together API key from [here](https://together.ai). Paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then, on their models page, press 'start' on the model of your choice and make sure the `model=` parameter in the config file for the `TogetherLLM` class reflects the name of this model. Finally, reload the VS Code window for changes to take effect.",
+    "LlamaCpp": "To get started with this model, clone the [`llama.cpp` repo](https://github.com/ggerganov/llama.cpp) and follow the instructions to set up the server [here](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#build). Any of the parameters described in the README can be passed to the `llama_cpp_args` field in the `LlamaCpp` class in `config.py`.",
 }
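
For context, a minimal sketch of how the new class might be referenced from `config.py`, following the `config.models.default` layout mentioned in the setup messages above. The `ContinueConfig`/`Models` wrappers and their import paths are assumptions about the surrounding Continue config API rather than part of this diff; the `LlamaCpp` fields shown (`server_url`, `max_context_length`, `llama_cpp_args`) are the ones defined in `llamacpp.py`.

    # Hypothetical config.py snippet (illustrative only; everything outside the
    # LlamaCpp(...) call is an assumption about the surrounding config API).
    from continuedev.src.continuedev.core.config import ContinueConfig
    from continuedev.src.continuedev.core.models import Models
    from continuedev.src.continuedev.libs.llm.llamacpp import LlamaCpp

    config = ContinueConfig(
        models=Models(
            default=LlamaCpp(
                server_url="http://localhost:8080",  # where the llama.cpp server listens
                max_context_length=2048,
                # Extra request parameters merged into each /completion call by _transform_args
                llama_cpp_args={"stop": ["[INST]"], "temperature": 0.8},
            )
        )
    )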

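A quick, hypothetical check of the default `llama2_template_messages` prompt format (not part of the commit); the message dicts mirror the role/content shape that the template functions in `llamacpp.py` index into.

    from continuedev.src.continuedev.libs.llm.llamacpp import llama2_template_messages

    msgs = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Explain what a context length is."},
    ]
    # Produces a single prompt string of roughly the form:
    # "[INST] <<SYS>>\n<system message>\n<</SYS>>\n\n<user message> [/INST]"
    print(llama2_template_messages(msgs))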