-rw-r--r--   continuedev/src/continuedev/libs/llm/ggml.py     6
-rw-r--r--   continuedev/src/continuedev/libs/llm/queued.py   3
2 files changed, 7 insertions, 2 deletions
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 3fbfdeed..dd1bdec2 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -1,6 +1,6 @@
 import json
 import ssl
-from typing import Any, Coroutine, List, Optional
+from typing import Any, Callable, Coroutine, Dict, List, Optional
 
 import aiohttp
 
@@ -19,7 +19,9 @@ class GGML(LLM):
     ca_bundle_path: str = None
     model: str = "ggml"
 
-    template_messages = llama2_template_messages
+    template_messages: Optional[
+        Callable[[List[Dict[str, str]]], str]
+    ] = llama2_template_messages
 
     prompt_templates = {
         "edit": simplified_edit_prompt,
diff --git a/continuedev/src/continuedev/libs/llm/queued.py b/continuedev/src/continuedev/libs/llm/queued.py
index 9e6e0180..6dbaaa64 100644
--- a/continuedev/src/continuedev/libs/llm/queued.py
+++ b/continuedev/src/continuedev/libs/llm/queued.py
@@ -19,6 +19,9 @@ class QueuedLLM(LLM):
         await self.llm.start(*args, **kwargs)
         self._lock = asyncio.Lock()
         self.model = self.llm.model
+        self.template_messages = self.llm.template_messages
+        self.prompt_templates = self.llm.prompt_templates
+        self.context_length = self.llm.context_length
 
     async def stop(self):
         await self.llm.stop()
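
For context, the queued.py hunk matters because QueuedLLM is a delegating wrapper: callers read attributes such as template_messages, prompt_templates, and context_length from the wrapper rather than from the wrapped model, so start() has to copy them across after the inner model is started. The sketch below illustrates that delegation pattern in isolation; BaseLLM and QueuedWrapper are hypothetical stand-ins written for this example, not the actual Continue classes.

    import asyncio
    from typing import Callable, Dict, List, Optional


    class BaseLLM:
        """Stand-in for an LLM base class (hypothetical, for illustration only)."""

        model: str = "unknown"
        context_length: int = 2048
        template_messages: Optional[Callable[[List[Dict[str, str]]], str]] = None
        prompt_templates: Dict[str, str] = {}

        async def start(self, *args, **kwargs) -> None:
            pass

        async def complete(self, prompt: str) -> str:
            return f"echo: {prompt}"


    class QueuedWrapper(BaseLLM):
        """Serializes requests to an underlying LLM behind an asyncio.Lock."""

        def __init__(self, llm: BaseLLM) -> None:
            self.llm = llm
            self._lock: Optional[asyncio.Lock] = None

        async def start(self, *args, **kwargs) -> None:
            await self.llm.start(*args, **kwargs)
            self._lock = asyncio.Lock()
            # Mirror the wrapped model's public configuration so templating and
            # context-window logic keep working when they inspect the wrapper;
            # this is the gap the queued.py hunk above closes.
            self.model = self.llm.model
            self.template_messages = self.llm.template_messages
            self.prompt_templates = self.llm.prompt_templates
            self.context_length = self.llm.context_length

        async def complete(self, prompt: str) -> str:
            # One request at a time against the underlying model.
            # start() must have been called before this point.
            async with self._lock:
                return await self.llm.complete(prompt)


    async def _demo() -> None:
        wrapper = QueuedWrapper(BaseLLM())
        await wrapper.start()
        print(wrapper.context_length)          # forwarded from the wrapped model
        print(await wrapper.complete("hello"))


    if __name__ == "__main__":
        asyncio.run(_demo())

Without the forwarding in start(), the wrapper would fall back to the class-level defaults (for example template_messages = None), silently dropping the wrapped model's chat templating and context-length settings.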