summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--continuedev/src/continuedev/libs/llm/ggml.py6
-rw-r--r--continuedev/src/continuedev/libs/llm/queued.py3
2 files changed, 7 insertions, 2 deletions
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 3fbfdeed..dd1bdec2 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -1,6 +1,6 @@
import json
import ssl
-from typing import Any, Coroutine, List, Optional
+from typing import Any, Callable, Coroutine, Dict, List, Optional
import aiohttp
@@ -19,7 +19,9 @@ class GGML(LLM):
ca_bundle_path: str = None
model: str = "ggml"
- template_messages = llama2_template_messages
+ template_messages: Optional[
+ Callable[[List[Dict[str, str]]], str]
+ ] = llama2_template_messages
prompt_templates = {
"edit": simplified_edit_prompt,
diff --git a/continuedev/src/continuedev/libs/llm/queued.py b/continuedev/src/continuedev/libs/llm/queued.py
index 9e6e0180..6dbaaa64 100644
--- a/continuedev/src/continuedev/libs/llm/queued.py
+++ b/continuedev/src/continuedev/libs/llm/queued.py
@@ -19,6 +19,9 @@ class QueuedLLM(LLM):
await self.llm.start(*args, **kwargs)
self._lock = asyncio.Lock()
self.model = self.llm.model
+ self.template_messages = self.llm.template_messages
+ self.prompt_templates = self.llm.prompt_templates
+ self.context_length = self.llm.context_length
async def stop(self):
await self.llm.stop()