summaryrefslogtreecommitdiff
path: root/server/continuedev/libs/llm/ollama.py
blob: 82cbc8527661acd802e701d5ace66f20afcfcb54 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import json
from typing import Callable

import aiohttp
from pydantic import Field

from ...core.main import ContinueCustomException
from ..util.logging import logger
from .base import LLM
from .prompts.chat import llama2_template_messages
from .prompts.edit import simplified_edit_prompt


class Ollama(LLM):
    """
    [Ollama](https://ollama.ai/) is an application for Mac and Linux that makes it easy to locally run open-source models, including Llama-2. Download the app from the website, and it will walk you through setup in a couple of minutes. You can also read more in their [README](https://github.com/jmorganca/ollama). Continue can then be configured to use the `Ollama` LLM class:

    ```python title="~/.continue/config.py"
    from continuedev.libs.llm.ollama import Ollama

    config = ContinueConfig(
        ...
        models=Models(
            default=Ollama(model="llama2")
        )
    )
    ```
    """

    model: str = "llama2"
    server_url: str = Field(
        "http://localhost:11434", description="URL of the Ollama server"
    )

    _client_session: aiohttp.ClientSession = None

    template_messages: Callable = llama2_template_messages

    prompt_templates = {
        "edit": simplified_edit_prompt,
    }

    class Config:
        arbitrary_types_allowed = True

    async def start(self, **kwargs):
        await super().start(**kwargs)
        self._client_session = self.create_client_session()
        try:
            async with self._client_session.post(
                f"{self.server_url}/api/generate",
                proxy=self.proxy,
                json={
                    "prompt": "",
                    "model": self.model,
                },
            ) as _:
                pass
        except Exception as e:
            logger.warning(f"Error pre-loading Ollama model: {e}")

    async def stop(self):
        await self._client_session.close()

    async def get_downloaded_models(self):
        async with self._client_session.get(
            f"{self.server_url}/api/tags",
            proxy=self.proxy,
        ) as resp:
            js_data = await resp.json()
            return list(map(lambda x: x["name"], js_data["models"]))

    async def _stream_complete(self, prompt, options):
        async with self._client_session.post(
            f"{self.server_url}/api/generate",
            json={
                "template": prompt,
                "model": self.model,
                "system": self.system_message,
                "options": {"temperature": options.temperature},
            },
            proxy=self.proxy,
        ) as resp:
            if resp.status == 400:
                txt = await resp.text()
                extra_msg = ""
                if "no such file" in txt:
                    extra_msg = f"\n\nThis means that the model '{self.model}' is not downloaded.\n\nYou have the following models downloaded: {', '.join(await self.get_downloaded_models())}.\n\nTo download this model, run `ollama run {self.model}` in your terminal."
                raise ContinueCustomException(
                    f"Ollama returned an error: {txt}{extra_msg}",
                    "Invalid request to Ollama",
                )
            elif resp.status != 200:
                raise ContinueCustomException(
                    f"Ollama returned an error: {await resp.text()}",
                    "Invalid request to Ollama",
                )
            async for line in resp.content.iter_any():
                if line:
                    json_chunk = line.decode("utf-8")
                    chunks = json_chunk.split("\n")
                    for chunk in chunks:
                        if chunk.strip() != "":
                            j = json.loads(chunk)
                            if "response" in j:
                                yield j["response"]