diff options
Diffstat (limited to 'server/continuedev/libs/llm/ollama.py')
-rw-r--r-- | server/continuedev/libs/llm/ollama.py | 106 |
1 files changed, 106 insertions, 0 deletions
diff --git a/server/continuedev/libs/llm/ollama.py b/server/continuedev/libs/llm/ollama.py new file mode 100644 index 00000000..82cbc852 --- /dev/null +++ b/server/continuedev/libs/llm/ollama.py @@ -0,0 +1,106 @@ +import json +from typing import Callable + +import aiohttp +from pydantic import Field + +from ...core.main import ContinueCustomException +from ..util.logging import logger +from .base import LLM +from .prompts.chat import llama2_template_messages +from .prompts.edit import simplified_edit_prompt + + +class Ollama(LLM): + """ + [Ollama](https://ollama.ai/) is an application for Mac and Linux that makes it easy to locally run open-source models, including Llama-2. Download the app from the website, and it will walk you through setup in a couple of minutes. You can also read more in their [README](https://github.com/jmorganca/ollama). Continue can then be configured to use the `Ollama` LLM class: + + ```python title="~/.continue/config.py" + from continuedev.libs.llm.ollama import Ollama + + config = ContinueConfig( + ... + models=Models( + default=Ollama(model="llama2") + ) + ) + ``` + """ + + model: str = "llama2" + server_url: str = Field( + "http://localhost:11434", description="URL of the Ollama server" + ) + + _client_session: aiohttp.ClientSession = None + + template_messages: Callable = llama2_template_messages + + prompt_templates = { + "edit": simplified_edit_prompt, + } + + class Config: + arbitrary_types_allowed = True + + async def start(self, **kwargs): + await super().start(**kwargs) + self._client_session = self.create_client_session() + try: + async with self._client_session.post( + f"{self.server_url}/api/generate", + proxy=self.proxy, + json={ + "prompt": "", + "model": self.model, + }, + ) as _: + pass + except Exception as e: + logger.warning(f"Error pre-loading Ollama model: {e}") + + async def stop(self): + await self._client_session.close() + + async def get_downloaded_models(self): + async with self._client_session.get( + f"{self.server_url}/api/tags", + proxy=self.proxy, + ) as resp: + js_data = await resp.json() + return list(map(lambda x: x["name"], js_data["models"])) + + async def _stream_complete(self, prompt, options): + async with self._client_session.post( + f"{self.server_url}/api/generate", + json={ + "template": prompt, + "model": self.model, + "system": self.system_message, + "options": {"temperature": options.temperature}, + }, + proxy=self.proxy, + ) as resp: + if resp.status == 400: + txt = await resp.text() + extra_msg = "" + if "no such file" in txt: + extra_msg = f"\n\nThis means that the model '{self.model}' is not downloaded.\n\nYou have the following models downloaded: {', '.join(await self.get_downloaded_models())}.\n\nTo download this model, run `ollama run {self.model}` in your terminal." + raise ContinueCustomException( + f"Ollama returned an error: {txt}{extra_msg}", + "Invalid request to Ollama", + ) + elif resp.status != 200: + raise ContinueCustomException( + f"Ollama returned an error: {await resp.text()}", + "Invalid request to Ollama", + ) + async for line in resp.content.iter_any(): + if line: + json_chunk = line.decode("utf-8") + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + j = json.loads(chunk) + if "response" in j: + yield j["response"] |