summaryrefslogtreecommitdiff
path: root/server/continuedev/libs/llm/proxy_server.py
blob: 7c3462ebdce036dc6aa69a7f24d52ce9444bf546 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import codecs
import json
import traceback
from typing import List

import aiohttp

from ...core.main import ChatMessage
from ..util.telemetry import posthog_logger
from .base import LLM

# Base URL that ProxyServer POSTs to. For local development this can be
# switched to "http://127.0.0.1:8080".
SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app"

# Context length (max tokens) for every model name ProxyServer supports.
# ProxyServer.start() looks the configured model up here, so a model missing
# from this table makes start() raise KeyError.
MAX_TOKENS_FOR_MODEL = {
    # gpt-3.5 family
    "gpt-3.5-turbo": 4_096,
    "gpt-3.5-turbo-0613": 4_096,
    "gpt-3.5-turbo-16k": 16_384,
    # gpt-4
    "gpt-4": 8_192,
}


class ProxyServer(LLM):
    """LLM backend that forwards requests to the proxy server at ``SERVER_URL``.

    Every request is a JSON POST containing a ``messages`` list plus the
    arguments produced by ``collect_args``; ``unique_id`` is sent as a header.
    ``model`` must be one of the keys of ``MAX_TOKENS_FOR_MODEL``.
    """

    # Opened in ``start`` and closed in ``stop``.
    _client_session: aiohttp.ClientSession

    class Config:
        # Presumably pydantic config allowing a non-pydantic attribute type
        # such as aiohttp.ClientSession on the model.
        arbitrary_types_allowed = True

    async def start(
        self,
        **kwargs,
    ):
        """Run base-class startup, set ``context_length`` and open the HTTP session.

        Raises:
            KeyError: if ``self.model`` is not a key of ``MAX_TOKENS_FOR_MODEL``.
        """
        await super().start(**kwargs)

        # Resolve the context length before opening the session so that an
        # unsupported model does not leave an unclosed ClientSession behind.
        self.context_length = MAX_TOKENS_FOR_MODEL[self.model]
        self._client_session = self.create_client_session()

    async def stop(self):
        """Close the HTTP session opened in ``start``."""
        await self._client_session.close()

    def get_headers(self):
        """Return the headers attached to every request sent to the proxy."""
        return {"unique_id": self.unique_id}

    async def _complete(self, prompt: str, options):
        """Return the full completion text for ``prompt`` from ``/complete``.

        The prompt is sent as a single ``user`` message.

        Raises:
            Exception: carrying the response body if the status is not 200.
        """
        args = self.collect_args(options)

        async with self._client_session.post(
            f"{SERVER_URL}/complete",
            json={"messages": [{"role": "user", "content": prompt}], **args},
            headers=self.get_headers(),
            proxy=self.proxy,
        ) as resp:
            resp_text = await resp.text()
            if resp.status != 200:
                raise Exception(resp_text)

            return resp_text

    async def _stream_chat(self, messages: List[ChatMessage], options):
        """Stream chat output from ``/stream_chat`` as parsed JSON objects.

        The response body is treated as newline-delimited JSON: every
        non-blank line is decoded and yielded as a parsed object. A line that
        fails to decode or parse is reported to telemetry and skipped, so one
        bad line does not abort the whole stream.

        Raises:
            Exception: carrying the response body if the status is not 200.
        """
        args = self.collect_args(options)
        async with self._client_session.post(
            f"{SERVER_URL}/stream_chat",
            json={"messages": messages, **args},
            headers=self.get_headers(),
            proxy=self.proxy,
        ) as resp:
            if resp.status != 200:
                raise Exception(await resp.text())

            # Iterating the StreamReader yields whole lines, so a JSON object
            # split across network reads / HTTP chunks is reassembled before
            # parsing. (iter_chunks() gives no such guarantee, and its
            # end-of-HTTP-chunk flag is always False for non-chunked bodies.)
            async for raw_line in resp.content:
                if not raw_line.strip():
                    continue
                try:
                    loaded_chunk = json.loads(raw_line.decode("utf-8"))
                except ValueError as e:
                    # ValueError covers both UnicodeDecodeError and
                    # json.JSONDecodeError; report and keep streaming.
                    posthog_logger.capture_event(
                        "proxy_server_parse_error",
                        {
                            "error_title": "Proxy server stream_chat parsing failed",
                            "error_message": "".join(
                                traceback.format_exception(
                                    type(e), e, e.__traceback__
                                )
                            ),
                        },
                    )
                    continue
                # Yield outside the try so exceptions thrown into this
                # generator by the consumer are not mistaken for parse errors.
                yield loaded_chunk

    async def _stream_complete(self, prompt, options):
        """Stream completion text for ``prompt`` from ``/stream_complete``.

        The prompt is sent as a single ``user`` message; decoded text
        fragments are yielded as they arrive.

        Raises:
            Exception: carrying the response body if the status is not 200.
            UnicodeDecodeError: if the body is not valid UTF-8.
        """
        args = self.collect_args(options)

        async with self._client_session.post(
            f"{SERVER_URL}/stream_complete",
            json={"messages": [{"role": "user", "content": prompt}], **args},
            headers=self.get_headers(),
            proxy=self.proxy,
        ) as resp:
            if resp.status != 200:
                raise Exception(await resp.text())

            # A network read can end in the middle of a multi-byte UTF-8
            # character; the incremental decoder buffers the incomplete bytes
            # until the rest arrives instead of raising UnicodeDecodeError.
            decoder = codecs.getincrementaldecoder("utf-8")()
            async for data in resp.content.iter_any():
                text = decoder.decode(data)
                if text:
                    yield text

            # Flush any buffered bytes; raises if the body ended mid-character.
            tail = decoder.decode(b"", final=True)
            if tail:
                yield tail