diff options
| author | Nate Sesti <sestinj@gmail.com> | 2023-09-07 00:30:02 -0700 | 
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-09-07 00:30:02 -0700 | 
| commit | 7daf1e4ae8c5de8732d49233e4efcaf25bec91d7 (patch) | |
| tree | 62d65117ac6ce22fb6084bff37cffc4324de1b22 /continuedev | |
| parent | bc8a559563f4aab88fa6d7d8ecd3109656be9f30 (diff) | |
| download | sncontinue-7daf1e4ae8c5de8732d49233e4efcaf25bec91d7.tar.gz sncontinue-7daf1e4ae8c5de8732d49233e4efcaf25bec91d7.tar.bz2 sncontinue-7daf1e4ae8c5de8732d49233e4efcaf25bec91d7.zip | |
refactor: :art: template_messages for GGML
Diffstat (limited to 'continuedev')
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py | 34 | 
1 files changed, 23 insertions, 11 deletions
```diff
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index a183e643..3fbfdeed 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -9,6 +9,7 @@ from ..llm import LLM
 from ..util.logging import logger
 from . import CompletionOptions
 from .openai import CHAT_MODELS
+from .prompts.chat import llama2_template_messages
 from .prompts.edit import simplified_edit_prompt
 
 
@@ -18,6 +19,8 @@ class GGML(LLM):
     ca_bundle_path: str = None
     model: str = "ggml"
 
+    template_messages = llama2_template_messages
+
     prompt_templates = {
         "edit": simplified_edit_prompt,
     }
@@ -63,17 +66,26 @@ class GGML(LLM):
             ) as resp:
                 async for line in resp.content.iter_any():
                     if line:
-                        chunk = line.decode("utf-8")
-                        if chunk.startswith(": ping - ") or chunk.startswith(
-                            "data: [DONE]"
-                        ):
-                            continue
-                        elif chunk.startswith("data: "):
-                            chunk = chunk[6:]
-
-                        j = json.loads(chunk)
-                        if "choices" in j:
-                            yield j["choices"][0]["text"]
+                        chunks = line.decode("utf-8")
+                        for chunk in chunks.split("\n"):
+                            if (
+                                chunk.startswith(": ping - ")
+                                or chunk.startswith("data: [DONE]")
+                                or chunk.strip() == ""
+                            ):
+                                continue
+                            elif chunk.startswith("data: "):
+                                chunk = chunk[6:]
+                            try:
+                                j = json.loads(chunk)
+                            except Exception:
+                                continue
+                            if (
+                                "choices" in j
+                                and len(j["choices"]) > 0
+                                and "text" in j["choices"][0]
+                            ):
+                                yield j["choices"][0]["text"]
 
     async def _stream_chat(self, messages: List[ChatMessage], options):
         args = self.collect_args(options)
```
