import json
import ssl
import traceback
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    Dict,
    List,
    Optional,
)

import aiohttp
import certifi

from ...core.main import ChatMessage
from ..llm import LLM
from ..util.count_tokens import (
    DEFAULT_ARGS,
    compile_chat_messages,
    count_tokens,
    format_chat_messages,
)
from ..util.telemetry import posthog_logger

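# Use certifi's CA bundle so HTTPS requests verify correctly even where the
# system certificate store is missing or outdated.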
ca_bundle_path = certifi.where()
ssl_context = ssl.create_default_context(cafile=ca_bundle_path)

# SERVER_URL = "http://127.0.0.1:8080"
SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app"

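# Context window sizes, in tokens, for the models the proxy supports.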
MAX_TOKENS_FOR_MODEL = {
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-0613": 4096,
    "gpt-3.5-turbo-16k": 16384,
    "gpt-4": 8192,
}


class ProxyServer(LLM):
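    """LLM that forwards completion and chat requests to a hosted proxy server
    instead of calling a model provider directly."""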
    model: str
    system_message: Optional[str]

    unique_id: Optional[str] = None
    write_log: Optional[Callable[[str], None]] = None
    _client_session: aiohttp.ClientSession

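    # Signal to the caller that start() must be passed unique_id and write_log.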
    requires_unique_id = True
    requires_write_log = True

    class Config:
        arbitrary_types_allowed = True

    async def start(
        self,
        *,
        api_key: Optional[str] = None,
        write_log: Callable[[str], None],
        unique_id: str,
        **kwargs,
    ):
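        """Open the aiohttp session used for all requests to the proxy server."""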
        self._client_session = aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(ssl=ssl_context)
        )
        self.write_log = write_log
        self.unique_id = unique_id

    async def stop(self):
        await self._client_session.close()

    @property
    def name(self):
        return self.model

    @property
    def context_length(self):
        # Fall back to a conservative 4096-token window for unrecognized models.
        return MAX_TOKENS_FOR_MODEL.get(self.model, 4096)

    @property
    def default_args(self):
        return {**DEFAULT_ARGS, "model": self.model}

    def count_tokens(self, text: str):
        return count_tokens(self.model, text)

    def get_headers(self):
        # Identify the client to the proxy via its unique id.
        return {"unique_id": self.unique_id}

    async def complete(
        self, prompt: str, with_history: Optional[List[ChatMessage]] = None, **kwargs
    ) -> str:
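        """Compile the prompt into chat messages and request a full completion."""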
        args = {**self.default_args, **kwargs}

        messages = compile_chat_messages(
            args["model"],
            with_history,
            self.context_length,
            args["max_tokens"],
            prompt,
            functions=None,
            system_message=self.system_message,
        )
        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
        async with self._client_session.post(
            f"{SERVER_URL}/complete",
            json={"messages": messages, **args},
            headers=self.get_headers(),
        ) as resp:
            if resp.status != 200:
                raise Exception(await resp.text())

            response_text = await resp.text()
            self.write_log(f"Completion: \n\n{response_text}")
            return response_text

    async def stream_chat(
        self, messages: Optional[List[ChatMessage]] = None, **kwargs
    ) -> AsyncGenerator[Dict[str, Any], None]:
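        """Stream chat responses from the proxy, yielding parsed JSON chunks."""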
        args = {**self.default_args, **kwargs}
        messages = compile_chat_messages(
            args["model"],
            messages,
            self.context_length,
            args["max_tokens"],
            None,
            functions=args.get("functions", None),
            system_message=self.system_message,
        )
        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")

        async with self._client_session.post(
            f"{SERVER_URL}/stream_chat",
            json={"messages": messages, **args},
            headers=self.get_headers(),
        ) as resp:
            # The server streams newline-delimited JSON instead of text/event-stream.
            completion = ""
            if resp.status != 200:
                raise Exception(await resp.text())
            # iter_chunks() yields (bytes, end_of_http_chunk) pairs; each complete
            # HTTP chunk may contain several newline-delimited JSON objects.
            async for chunk_bytes, end_of_http_chunk in resp.content.iter_chunks():
                if not end_of_http_chunk:
                    # Partial chunks can't be parsed as complete JSON lines; stop.
                    break
                try:
                    json_chunk = chunk_bytes.decode("utf-8")
                    for chunk in json_chunk.split("\n"):
                        if chunk.strip() != "":
                            loaded_chunk = json.loads(chunk)
                            yield loaded_chunk
                            if "content" in loaded_chunk:
                                completion += loaded_chunk["content"]
                except Exception:
                    posthog_logger.capture_event(
                        "proxy_server_parse_error",
                        {
                            "error_title": "Proxy server stream_chat parsing failed",
                            "error_message": traceback.format_exc(),
                        },
                    )

            self.write_log(f"Completion: \n\n{completion}")

    async def stream_complete(
        self, prompt: str, with_history: Optional[List[ChatMessage]] = None, **kwargs
    ) -> AsyncGenerator[str, None]:
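        """Stream a completion from the proxy, yielding raw text as it arrives."""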
        args = {**self.default_args, **kwargs}
        messages = compile_chat_messages(
            args["model"],
            with_history,
            self.context_length,
            args["max_tokens"],
            prompt,
            functions=args.get("functions", None),
            system_message=self.system_message,
        )
        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")

        async with self._client_session.post(
            f"{SERVER_URL}/stream_complete",
            json={"messages": messages, **args},
            headers=self.get_headers(),
        ) as resp:
            completion = ""
            if resp.status != 200:
                raise Exception(await resp.text())
            async for line in resp.content.iter_any():
                if line:
                    try:
                        decoded_line = line.decode("utf-8")
                        yield decoded_line
                        completion += decoded_line
                    except UnicodeDecodeError as e:
                        # Surface the undecodable bytes rather than dropping them.
                        raise Exception(str(line)) from e
            self.write_log(f"Completion: \n\n{completion}")