import json
from typing import Any, Coroutine, Dict, Generator, List, Optional, Union

import aiohttp

from ...core.main import ChatMessage
from ..llm import LLM
from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens


class Ollama(LLM):
    """LLM wrapper that streams completions from a local Ollama server."""

    model: str = "llama2"
    server_url: str = "http://localhost:11434"
    max_context_length: int = 2048

    _client_session: Optional[aiohttp.ClientSession] = None

    class Config:
        arbitrary_types_allowed = True

    async def start(self, **kwargs):
        self._client_session = aiohttp.ClientSession()

    async def stop(self):
        await self._client_session.close()

    @property
    def name(self):
        return self.model

    @property
    def context_length(self) -> int:
        return self.max_context_length

    @property
    def default_args(self):
        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}

    def count_tokens(self, text: str):
        return count_tokens(self.name, text)

    def convert_to_chat(self, msgs: List[ChatMessage]) -> str:
        """Render a list of chat messages in the Llama 2 [INST] prompt format."""
        if len(msgs) == 0:
            return ""

        prompt = ""
        has_system = msgs[0]["role"] == "system"
        if has_system:
            # Llama 2 expects the system message wrapped in <<SYS>> tags
            # inside the first [INST] block.
            system_message = f"""\
<<SYS>>
{self.system_message}
<</SYS>>

"""
            if len(msgs) > 1:
                prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"
            else:
                prompt += f"[INST] {system_message} [/INST]"
                return prompt

        # The system message and the first user message (if present) were
        # consumed above, so start after them.
        for i in range(2 if has_system else 0, len(msgs)):
            if msgs[i]["role"] == "user":
                prompt += f"[INST] {msgs[i]['content']} [/INST]"
            else:
                prompt += msgs[i]["content"]

        return prompt

    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
        args = {**self.default_args, **kwargs}
        messages = compile_chat_messages(
            self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
        prompt = self.convert_to_chat(messages)

        async with self._client_session.post(f"{self.server_url}/api/generate", json={
            "prompt": prompt,
            "model": self.model,
        }) as resp:
            async for line in resp.content.iter_any():
                if line:
                    try:
                        # Each chunk may contain several newline-delimited
                        # JSON objects.
                        json_chunk = line.decode("utf-8")
                        chunks = json_chunk.split("\n")
                        for chunk in chunks:
                            if chunk.strip() != "":
                                j = json.loads(chunk)
                                if "response" in j:
                                    yield j["response"]
                    except Exception:
                        raise Exception(str(line))

    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
        args = {**self.default_args, **kwargs}
        messages = compile_chat_messages(
            self.name, messages, self.context_length, args["max_tokens"], None, functions=None, system_message=self.system_message)
        prompt = self.convert_to_chat(messages)

        async with self._client_session.post(f"{self.server_url}/api/generate", json={
            "prompt": prompt,
            "model": self.model,
        }) as resp:
            # The server streams newline-delimited application/json instead
            # of text/event-stream.
            async for line in resp.content.iter_chunks():
                if line[1]:
                    try:
                        json_chunk = line[0].decode("utf-8")
                        chunks = json_chunk.split("\n")
                        for chunk in chunks:
                            if chunk.strip() != "":
                                j = json.loads(chunk)
                                if "response" in j:
                                    yield {
                                        "role": "assistant",
                                        "content": j["response"]
                                    }
                    except Exception:
                        raise Exception(str(line[0]))

    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
        completion = ""
        async with self._client_session.post(f"{self.server_url}/api/generate", json={
            "prompt": prompt,
            "model": self.model,
        }) as resp:
            async for line in resp.content.iter_any():
                if line:
                    try:
                        json_chunk = line.decode("utf-8")
                        chunks = json_chunk.split("\n")
                        for chunk in chunks:
                            if chunk.strip() != "":
                                j = json.loads(chunk)
                                if "response" in j:
                                    completion += j["response"]
                    except Exception:
                        raise Exception(str(line))
        return completion
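
# A minimal usage sketch, kept as a comment since this module uses relative
# imports and must be imported from within its package. It assumes an Ollama
# server is already running at the default address; the event-loop setup and
# the example prompt are illustrative, not part of this module's API:
#
#     import asyncio
#
#     async def demo():
#         llm = Ollama(model="llama2")
#         await llm.start()
#         async for chunk in llm.stream_complete("Write a haiku about Python."):
#             print(chunk, end="", flush=True)
#         await llm.stop()
#
#     asyncio.run(demo())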