import json
from typing import Any, Callable, Dict, List, Union

import websockets
from pydantic import Field

from ...core.main import ChatMessage
from .base import LLM
from .prompts.chat import llama2_template_messages
from .prompts.edit import simplest_edit_prompt


class TextGenUI(LLM):
"""
TextGenUI is a comprehensive, open-source language model UI and local server. You can set it up with an OpenAI-compatible server plugin, but if for some reason that doesn't work, you can use this class like so:
```python title="~/.continue/config.py"
from continuedev.libs.llm.text_gen_interface import TextGenUI
config = ContinueConfig(
...
models=Models(
default=TextGenUI(
model="<MODEL_NAME>",
)
)
)
```
"""

    model: str = "text-gen-ui"

    server_url: str = Field(
        "http://localhost:5000", description="URL of your TextGenUI server"
    )
    streaming_url: str = Field(
        "http://localhost:5005",
        description="URL of your TextGenUI streaming server (separate from main server URL)",
    )

    prompt_templates = {
        "edit": simplest_edit_prompt,
    }

    template_messages: Union[
        Callable[[List[Dict[str, str]]], str], None
    ] = llama2_template_messages

    class Config:
        arbitrary_types_allowed = True

    def collect_args(self, options) -> Any:
        # TextGenUI expects `max_new_tokens` instead of the OpenAI-style
        # `max_tokens`, so rename the key before sending.
        args = super().collect_args(options)
        args = {**args, "max_new_tokens": options.max_tokens}
        args.pop("max_tokens", None)
        return args
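
    # E.g. (illustrative values, not defaults from this codebase):
    # {"max_tokens": 512, "temperature": 0.7} -> {"max_new_tokens": 512, "temperature": 0.7}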

    async def _stream_complete(self, prompt, options):
        args = self.collect_args(options)

        # The streaming endpoints are served over websockets, so swap the scheme.
        ws_url = self.streaming_url.replace("http://", "ws://").replace(
            "https://", "wss://"
        )
        payload = json.dumps({"prompt": prompt, "stream": True, **args})
        async with websockets.connect(
            f"{ws_url}/api/v1/stream", ping_interval=None
        ) as websocket:
            await websocket.send(payload)

            while True:
                incoming_data = json.loads(await websocket.recv())

                match incoming_data["event"]:
                    case "text_stream":
                        yield incoming_data["text"]
                    case "stream_end":
                        break
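
    # Frames on /api/v1/stream appear to be JSON objects keyed by "event"; based
    # on the handling above (an inference from this code, not an official schema),
    # the two relevant shapes are:
    #   {"event": "text_stream", "text": "<next chunk>"}
    #   {"event": "stream_end"}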

    async def _stream_chat(self, messages: List[ChatMessage], options):
        args = self.collect_args(options)

        async def generator():
            ws_url = self.streaming_url.replace("http://", "ws://").replace(
                "https://", "wss://"
            )
            # TextGenUI's chat endpoint takes the latest user input separately
            # from the running history.
            history = [message["content"] for message in messages]
            payload = json.dumps(
                {
                    "user_input": messages[-1]["content"],
                    "history": {"internal": [history], "visible": [history]},
                    "stream": True,
                    **args,
                }
            )
            async with websockets.connect(
                f"{ws_url}/api/v1/chat-stream", ping_interval=None
            ) as websocket:
                await websocket.send(payload)

                prev = ""
                while True:
                    incoming_data = json.loads(await websocket.recv())

                    match incoming_data["event"]:
                        case "text_stream":
                            visible = incoming_data["history"]["visible"][-1]
                            if len(visible) > 0:
                                # The server resends the full visible reply on
                                # each frame; yield only the new suffix. Slicing
                                # off len(prev) characters avoids the bug where
                                # str.replace would also remove non-prefix
                                # occurrences of the previous text.
                                yield {
                                    "role": "assistant",
                                    "content": visible[-1][len(prev):],
                                }
                                prev = visible[-1]
                        case "stream_end":
                            break

        async for chunk in generator():
            yield chunk
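

# A minimal smoke test for the raw streaming protocol this class targets. It is
# a sketch, assuming a TextGenUI server with the websocket streaming API enabled
# at the default streaming URL; the prompt and max_new_tokens values are
# illustrative, while the endpoint path and event names are taken from the
# methods above.
if __name__ == "__main__":
    import asyncio

    async def _smoke_test():
        url = "ws://localhost:5005/api/v1/stream"
        payload = json.dumps(
            {"prompt": "Once upon a time, ", "stream": True, "max_new_tokens": 32}
        )
        async with websockets.connect(url, ping_interval=None) as websocket:
            await websocket.send(payload)
            while True:
                event = json.loads(await websocket.recv())
                match event["event"]:
                    case "text_stream":
                        print(event["text"], end="", flush=True)
                    case "stream_end":
                        break

    asyncio.run(_smoke_test())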