summaryrefslogtreecommitdiff
path: root/server/continuedev/core/models.py
blob: c31177b963db4a81160681fa19955bb077ce3ad0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from typing import List, Optional

from pydantic import BaseModel

from ..libs.llm.base import LLM
from ..libs.llm.llamacpp import LlamaCpp
from ..libs.llm.ollama import Ollama


class ContinueSDK(BaseModel):
    # Local placeholder type used in the annotations of ``Models``. It declares
    # no fields or methods; ``Models`` relies on ``start_model`` and
    # ``config.system_message`` existing on whatever object is actually passed
    # in (NOTE(review): presumably the real SDK lives elsewhere — confirm).
    ...


# Attribute names on ``Models`` that each hold the LLM for one role;
# ``Models.all_models`` and ``Models.start`` walk them in this order.
ALL_MODEL_ROLES = ["default", "summarize", "edit", "chat"]

# Map each supported LLM class's ``__name__`` to the class object itself.
MODEL_CLASSES = {
    model_cls.__name__: model_cls for model_cls in (Ollama, LlamaCpp)
}

# Class name -> name of the ``libs.llm`` submodule that class is imported
# from at the top of this file.
MODEL_MODULE_NAMES = dict(
    Ollama="ollama",
    LlamaCpp="llamacpp",
)


class Models(BaseModel):
    """Main class that holds the current model configuration.

    Each name in ``ALL_MODEL_ROLES`` is an attribute holding the LLM for that
    role. Only ``default`` is required; ``start`` assigns the default instance
    to every unset role, so afterwards several roles may reference the same
    ``LLM`` object.
    """

    default: LLM
    summarize: Optional[LLM] = None
    edit: Optional[LLM] = None
    chat: Optional[LLM] = None

    # Extra model configurations. They are not part of ALL_MODEL_ROLES, so
    # start()/stop() below neither start nor stop them.
    saved: List[LLM] = []

    # TODO namespace these away to not confuse readers,
    # or split Models into ModelsConfig, which gets turned into Models
    sdk: Optional[ContinueSDK] = None

    def dict(self, **kwargs):
        """Serialize via ``BaseModel.dict`` and drop the ``sdk`` entry.

        Keyword arguments are forwarded unchanged to ``BaseModel.dict``.
        """
        original_dict = super().dict(**kwargs)
        # The SDK is attached at runtime by start(); keep it out of the export.
        original_dict.pop("sdk", None)
        return original_dict

    @property
    def all_models(self):
        """Return the LLM for each role in ``ALL_MODEL_ROLES``, skipping unset ones.

        The list may contain the same instance more than once (e.g. after
        ``start`` has aliased unset roles to ``default``).
        """
        models = [getattr(self, role) for role in ALL_MODEL_ROLES]
        return [model for model in models if model is not None]

    def _distinct_models(self) -> List[LLM]:
        """Return ``all_models`` with repeated instances removed, order kept."""
        # Keyed by id() so deduplication is by identity and does not depend on
        # how LLM implements __eq__/__hash__.
        return list({id(model): model for model in self.all_models}.values())

    @property
    def system_message(self) -> Optional[str]:
        """System message from ``sdk.config``, or ``None`` if no SDK is attached."""
        if self.sdk:
            return self.sdk.config.system_message
        return None

    def set_system_message(self, msg: Optional[str]):
        """Give ``msg`` to every model that has no system message of its own.

        Args:
            msg: Message to assign; may be ``None`` (then nothing changes).
        """
        for model in self._distinct_models():
            if model.system_message is None:
                model.system_message = msg

    async def start(self, sdk: "ContinueSDK"):
        """Start each configured LLM once, and fill unset roles with the default.

        Args:
            sdk: Stored on ``self.sdk``; its ``start_model`` coroutine is awaited
                once per distinct LLM instance found in the role attributes.
        """
        self.sdk = sdk

        started = set()
        for role in ALL_MODEL_ROLES:
            model = getattr(self, role)
            if model is None:
                # Unset roles share the default instance; it is started under
                # its own "default" role rather than once per alias.
                setattr(self, role, self.default)
            elif id(model) not in started:
                # The same LLM object may back several roles (or start() may be
                # called again after aliasing); start each instance only once.
                started.add(id(model))
                await sdk.start_model(model)

        self.set_system_message(self.system_message)

    async def stop(self, sdk: "ContinueSDK"):
        """Stop each distinct LLM exactly once.

        After ``start``, unset roles alias the shared ``default`` instance, so
        iterating ``all_models`` directly would stop that instance once per
        role. Deduplicating by identity stops it a single time.

        Args:
            sdk: Accepted for symmetry with ``start``; not used.
        """
        for model in self._distinct_models():
            await model.stop()