# NOTE: stray numeric artifacts (line-number residue from file extraction) removed here.
import json
import traceback
from typing import List
import aiohttp
from ...core.main import ChatMessage
from ..util.telemetry import posthog_logger
from .base import LLM
# Local development server; swap in for SERVER_URL when running the proxy locally.
# SERVER_URL = "http://127.0.0.1:8080"
SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app"
# Maximum context window (prompt + completion tokens) per supported model name.
MAX_TOKENS_FOR_MODEL: dict[str, int] = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-16k": 16384,
"gpt-4": 8192,
}
class ProxyServer(LLM):
    """LLM implementation that forwards requests to the hosted proxy server.

    All completion/chat traffic is POSTed to ``SERVER_URL``; streaming
    endpoints return newline-delimited JSON over the response body.
    """

    # aiohttp session shared by all requests; opened in start(), closed in stop().
    _client_session: aiohttp.ClientSession

    class Config:
        # Required so pydantic accepts the non-pydantic aiohttp.ClientSession field.
        arbitrary_types_allowed = True

    async def start(
        self,
        **kwargs,
    ):
        """Open the HTTP session and record the model's context length."""
        await super().start(**kwargs)
        self._client_session = self.create_client_session()
        # Fall back to the smallest known context window for models not in the
        # table instead of raising KeyError at startup.
        self.context_length = MAX_TOKENS_FOR_MODEL.get(self.model, 4096)

    async def stop(self):
        """Close the HTTP session opened by start()."""
        await self._client_session.close()

    def get_headers(self):
        # unique_id identifies this client installation to the proxy server.
        return {"unique_id": self.unique_id}

    async def _complete(self, prompt: str, options):
        """Send a single-prompt completion request and return the raw response text.

        Raises:
            Exception: with the server's response body on any non-200 status.
        """
        args = self.collect_args(options)
        async with self._client_session.post(
            f"{SERVER_URL}/complete",
            json={"messages": [{"role": "user", "content": prompt}], **args},
            headers=self.get_headers(),
            proxy=self.proxy,
        ) as resp:
            resp_text = await resp.text()
            if resp.status != 200:
                raise Exception(resp_text)
            return resp_text

    async def _stream_chat(self, messages: List[ChatMessage], options):
        """Stream chat deltas from the proxy server, yielding one dict per JSON line.

        The server emits newline-delimited JSON. A single JSON line may be
        split across transport chunks, so partial lines are buffered until
        their terminating newline arrives; parse failures are reported to
        telemetry and the stream continues.

        Raises:
            Exception: with the server's response body on any non-200 status.
        """
        args = self.collect_args(options)
        async with self._client_session.post(
            f"{SERVER_URL}/stream_chat",
            json={"messages": messages, **args},
            headers=self.get_headers(),
            proxy=self.proxy,
        ) as resp:
            if resp.status != 200:
                raise Exception(await resp.text())
            buffer = ""
            async for data in resp.content.iter_any():
                if not data:
                    continue
                buffer += data.decode("utf-8")
                # Everything before the last newline is complete; the trailing
                # partial segment stays buffered for the next chunk.
                *complete_lines, buffer = buffer.split("\n")
                for line in complete_lines:
                    if line.strip() == "":
                        continue
                    try:
                        yield json.loads(line)
                    except Exception as e:
                        posthog_logger.capture_event(
                            "proxy_server_parse_error",
                            {
                                "error_title": "Proxy server stream_chat parsing failed",
                                "error_message": "\n".join(
                                    traceback.format_exception(e)
                                ),
                            },
                        )
            # Flush any final line the server sent without a trailing newline.
            if buffer.strip() != "":
                try:
                    yield json.loads(buffer)
                except Exception as e:
                    posthog_logger.capture_event(
                        "proxy_server_parse_error",
                        {
                            "error_title": "Proxy server stream_chat parsing failed",
                            "error_message": "\n".join(
                                traceback.format_exception(e)
                            ),
                        },
                    )

    async def _stream_complete(self, prompt, options):
        """Stream a plain-text completion, yielding decoded chunks as they arrive.

        Raises:
            Exception: with the server's response body on any non-200 status.
        """
        args = self.collect_args(options)
        async with self._client_session.post(
            f"{SERVER_URL}/stream_complete",
            json={"messages": [{"role": "user", "content": prompt}], **args},
            headers=self.get_headers(),
            proxy=self.proxy,
        ) as resp:
            if resp.status != 200:
                raise Exception(await resp.text())
            async for data in resp.content.iter_any():
                if data:
                    yield data.decode("utf-8")