Diffstat (limited to 'server/continuedev/libs/llm/ollama.py')
-rw-r--r--  server/continuedev/libs/llm/ollama.py  106
1 files changed, 106 insertions, 0 deletions
diff --git a/server/continuedev/libs/llm/ollama.py b/server/continuedev/libs/llm/ollama.py
new file mode 100644
index 00000000..82cbc852
--- /dev/null
+++ b/server/continuedev/libs/llm/ollama.py
@@ -0,0 +1,106 @@
+import json
+from typing import Callable, Optional
+
+import aiohttp
+from pydantic import Field
+
+from ...core.main import ContinueCustomException
+from ..util.logging import logger
+from .base import LLM
+from .prompts.chat import llama2_template_messages
+from .prompts.edit import simplified_edit_prompt
+
+
+class Ollama(LLM):
+    """
+    [Ollama](https://ollama.ai/) is an application for Mac and Linux that makes it easy to locally run open-source models, including Llama-2. Download the app from the website, and it will walk you through setup in a couple of minutes. You can also read more in their [README](https://github.com/jmorganca/ollama). Continue can then be configured to use the `Ollama` LLM class:
+
+    ```python title="~/.continue/config.py"
+    from continuedev.libs.llm.ollama import Ollama
+
+    config = ContinueConfig(
+        ...
+        models=Models(
+            default=Ollama(model="llama2")
+        )
+    )
+    ```
+    """
+
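+    # The Ollama model tag to run and the URL of the local Ollama server.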
+    model: str = "llama2"
+    server_url: str = Field(
+        "http://localhost:11434", description="URL of the Ollama server"
+    )
+
+    _client_session: Optional[aiohttp.ClientSession] = None
+
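+    # Prompt formatting: chat messages are wrapped with the Llama 2 template, and edit requests use a simplified edit prompt.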
+    template_messages: Callable = llama2_template_messages
+
+    prompt_templates = {
+        "edit": simplified_edit_prompt,
+    }
+
+    class Config:
+        arbitrary_types_allowed = True
+
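+    # Open the HTTP session and send an empty prompt so Ollama pre-loads the model before the first real request.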
+    async def start(self, **kwargs):
+        await super().start(**kwargs)
+        self._client_session = self.create_client_session()
+        try:
+            async with self._client_session.post(
+                f"{self.server_url}/api/generate",
+                proxy=self.proxy,
+                json={
+                    "prompt": "",
+                    "model": self.model,
+                },
+            ) as _:
+                pass
+        except Exception as e:
+            logger.warning(f"Error pre-loading Ollama model: {e}")
+
+    async def stop(self):
+        if self._client_session is not None:
+            await self._client_session.close()
+
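+    # List the models already downloaded locally by querying Ollama's /api/tags endpoint.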
+    async def get_downloaded_models(self):
+        async with self._client_session.get(
+            f"{self.server_url}/api/tags",
+            proxy=self.proxy,
+        ) as resp:
+            js_data = await resp.json()
+            return [model["name"] for model in js_data["models"]]
+
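+    # Stream a completion from /api/generate. Ollama replies with newline-delimited JSON objects whose "response" field carries the next chunk of generated text.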
+    async def _stream_complete(self, prompt, options):
+        async with self._client_session.post(
+            f"{self.server_url}/api/generate",
+            json={
+                "template": prompt,
+                "model": self.model,
+                "system": self.system_message,
+                "options": {"temperature": options.temperature},
+            },
+            proxy=self.proxy,
+        ) as resp:
+            if resp.status == 400:
+                txt = await resp.text()
+                extra_msg = ""
+                if "no such file" in txt:
+                    extra_msg = f"\n\nThis means that the model '{self.model}' is not downloaded.\n\nYou have the following models downloaded: {', '.join(await self.get_downloaded_models())}.\n\nTo download this model, run `ollama run {self.model}` in your terminal."
+                raise ContinueCustomException(
+                    f"Ollama returned an error: {txt}{extra_msg}",
+                    "Invalid request to Ollama",
+                )
+            elif resp.status != 200:
+                raise ContinueCustomException(
+                    f"Ollama returned an error: {await resp.text()}",
+                    "Invalid request to Ollama",
+                )
+            async for line in resp.content.iter_any():
+                if line:
+                    json_chunk = line.decode("utf-8")
+                    chunks = json_chunk.split("\n")
+                    for chunk in chunks:
+                        if chunk.strip() != "":
+                            j = json.loads(chunk)
+                            if "response" in j:
+                                yield j["response"]