path: root/server/continuedev/libs/llm/text_gen_interface.py
import json
from typing import Any, Callable, Dict, List, Union

import websockets
from pydantic import Field

from ...core.main import ChatMessage
from .base import LLM
from .prompts.chat import llama2_template_messages
from .prompts.edit import simplest_edit_prompt


class TextGenUI(LLM):
    """
    TextGenUI is a comprehensive, open-source language model UI and local server. You can set it up with an OpenAI-compatible server plugin, but if that doesn't work, you can use this class directly:

    ```python title="~/.continue/config.py"
    from continuedev.libs.llm.text_gen_interface import TextGenUI

    config = ContinueConfig(
        ...
        models=Models(
            default=TextGenUI(
                model="<MODEL_NAME>",
            )
        )
    )
    ```
    """

    model: str = "text-gen-ui"
    server_url: str = Field(
        "http://localhost:5000", description="URL of your TextGenUI server"
    )
    streaming_url: str = Field(
        "http://localhost:5005",
        description="URL of your TextGenUI streaming server (separate from main server URL)",
    )

    prompt_templates = {
        "edit": simplest_edit_prompt,
    }

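    # Chat messages are flattened into a single prompt string; the Llama 2
    # chat template is the default and can be overridden.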
    template_messages: Union[
        Callable[[List[Dict[str, str]]], str], None
    ] = llama2_template_messages

    class Config:
        arbitrary_types_allowed = True

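    # TextGenUI expects "max_new_tokens", so rename the OpenAI-style
    # "max_tokens" collected by the base class.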
    def collect_args(self, options) -> Any:
        args = super().collect_args(options)
        args = {**args, "max_new_tokens": options.max_tokens}
        args.pop("max_tokens", None)
        return args

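    # Completions stream over a websocket: send the prompt once, then read
    # JSON events until a "stream_end" event arrives.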
    async def _stream_complete(self, prompt, options):
        args = self.collect_args(options)

        ws_url = f"{self.streaming_url.replace('http://', 'ws://').replace('https://', 'wss://')}"
        payload = json.dumps({"prompt": prompt, "stream": True, **args})
        async with websockets.connect(
            f"{ws_url}/api/v1/stream", ping_interval=None
        ) as websocket:
            await websocket.send(payload)

            while True:
                incoming_data = await websocket.recv()
                incoming_data = json.loads(incoming_data)

                match incoming_data["event"]:
                    case "text_stream":
                        yield incoming_data["text"]
                    case "stream_end":
                        break

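    # Chat uses the separate /api/v1/chat-stream endpoint, which takes the
    # latest user message plus prior history and re-sends the full "visible"
    # transcript on every event.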
    async def _stream_chat(self, messages: List[ChatMessage], options):
        args = self.collect_args(options)

        async def generator():
            ws_url = f"{self.streaming_url.replace('http://', 'ws://').replace('https://', 'wss://')}"
            history = list(map(lambda x: x["content"], messages))
            payload = json.dumps(
                {
                    "user_input": messages[-1]["content"],
                    "history": {"internal": [history], "visible": [history]},
                    "stream": True,
                    **args,
                }
            )
            async with websockets.connect(
                f"{ws_url}/api/v1/chat-stream", ping_interval=None
            ) as websocket:
                await websocket.send(payload)

                prev = ""
                while True:
                    incoming_data = await websocket.recv()
                    incoming_data = json.loads(incoming_data)

                    match incoming_data["event"]:
                        case "text_stream":
                            visible = incoming_data["history"]["visible"][-1]
                            if len(visible) > 0:
                                # The server resends the whole reply so far on
                                # each event; yield only the newly added suffix
                                # (prefix slicing, not replace(), so repeated
                                # text is not stripped mid-string).
                                yield {
                                    "role": "assistant",
                                    "content": visible[-1][len(prev):],
                                }
                                prev = visible[-1]
                        case "stream_end":
                            break

        async for chunk in generator():
            yield chunk