-rw-r--r-- | continuedev/src/continuedev/libs/llm/ollama.py | 116
1 file changed, 116 insertions, 0 deletions
diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py
new file mode 100644
index 00000000..86da4115
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/ollama.py
@@ -0,0 +1,116 @@
+from functools import cached_property
+import json
+from typing import Any, Coroutine, Dict, Generator, List, Union
+
+import aiohttp
+from ...core.main import ChatMessage
+from ..llm import LLM
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
+
+
+class Ollama(LLM):
+    model: str = "llama2"
+    server_url: str = "http://localhost:11434"
+    max_context_length: int
+
+    @cached_property
+    def name(self):
+        return self.model
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.name, text)
+
+    def convert_to_chat(self, msgs: List[ChatMessage]) -> str:
+        if len(msgs) == 0:
+            return ""
+
+        prompt = ""
+        has_system = msgs[0].role == "system"
+        if has_system:
+            system_message = f"""\
+                <<SYS>>
+                {self.system_message}
+                <</SYS>>
+
+                """
+            if len(msgs) > 1:
+                prompt += f"[INST] {system_message}{msgs[1].content} [/INST]"
+            else:
+                prompt += f"[INST] {system_message} [/INST]"
+                return prompt
+
+        for i in range(2 if has_system else 0, len(msgs)):
+            if msgs[i].role == "user":
+                prompt += f"[INST] {msgs[i].content} [/INST]"
+            else:
+                prompt += msgs[i].content
+
+        return prompt
+
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        messages = compile_chat_messages(
+            self.name, with_history, self.max_context_length, prompt, system_message=self.system_message)
+        prompt = self.convert_to_chat(messages)
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f"{self.server_url}/api/generate", json={
+                "prompt": prompt,
+                "model": self.model,
+            }) as resp:
+                async for line in resp.content.iter_any():
+                    if line:
+                        try:
+                            j = json.loads(line.decode("utf-8"))
+                            yield j["response"]
+                            if j["done"]:
+                                break
+                        except:
+                            raise Exception(str(line))
+
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        messages = compile_chat_messages(
+            self.name, messages, self.max_context_length, None, system_message=self.system_message)
+        prompt = self.convert_to_chat(messages)
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f"{self.server_url}/api/generate", json={
+                "prompt": prompt,
+                "model": self.model,
+            }) as resp:
+                # This is streaming application/json instead of text/event-stream
+                async for line in resp.content.iter_chunks():
+                    if line[0]:
+                        try:
+                            j = json.loads(line[0].decode("utf-8"))
+                            yield {
+                                "role": "assistant",
+                                "content": j["response"]
+                            }
+                            if j["done"]:
+                                break
+                        except:
+                            raise Exception(str(line[0]))
+
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+        completion = ""
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f"{self.server_url}/api/generate", json={
+                "prompt": prompt,
+                "model": self.model,
+            }) as resp:
+                async for line in resp.content.iter_any():
+                    if line:
+                        try:
+                            j = json.loads(line.decode("utf-8"))
+                            completion += j["response"]
+                            if j["done"]:
+                                break
+                        except:
+                            raise Exception(str(line))
+
+        return completion
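
The /api/generate endpoint that this class talks to streams newline-delimited JSON objects rather than server-sent events, which is why each chunk is parsed with json.loads and the loop stops when the object's "done" flag is true. Below is a minimal standalone sketch of that streaming contract, assuming a locally running Ollama server at http://localhost:11434 with the llama2 model already pulled; the prompt string and model name are illustrative, not part of this change.

import asyncio
import json

import aiohttp


async def main():
    # Illustrative client; assumes Ollama is serving on localhost:11434 with llama2 pulled.
    async with aiohttp.ClientSession() as session:
        async with session.post("http://localhost:11434/api/generate", json={
            "prompt": "[INST] Why is the sky blue? [/INST]",
            "model": "llama2",
        }) as resp:
            # The body is a stream of JSON objects, one per line, each carrying a
            # "response" text fragment and a "done" flag that is true on the last object.
            async for raw_line in resp.content:
                line = raw_line.strip()
                if not line:
                    continue
                chunk = json.loads(line)
                print(chunk["response"], end="", flush=True)
                if chunk["done"]:
                    break


asyncio.run(main())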