import asyncio
from typing import List, Optional

import replicate

from ..util.count_tokens import DEFAULT_ARGS, count_tokens
from ...core.main import ChatMessage
from . import LLM


class ReplicateLLM(LLM):
    """LLM implementation for models hosted on Replicate."""

    api_key: str
    model: str = "nateraw/stablecode-completion-alpha-3b-4k:e82ebe958f0a5be6846d1a82041925767edb1d1f162596c643e48fbea332b1bb"
    max_context_length: int = 2048

    _client: Optional[replicate.Client] = None

    @property
    def name(self):
        return self.model

    @property
    def context_length(self):
        return self.max_context_length

    @property
    def default_args(self):
        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}

    def count_tokens(self, text: str):
        return count_tokens(self.name, text)

    async def start(self):
        # Create the Replicate client once the API key is available.
        self._client = replicate.Client(api_token=self.api_key)

    async def stop(self):
        pass

    async def complete(self, prompt: str, with_history: Optional[List[ChatMessage]] = None, **kwargs):
        # with_history is accepted for interface compatibility but is not
        # forwarded, since this is a plain completion endpoint.
        def helper():
            # replicate.Client.run blocks while streaming the prediction output.
            output = self._client.run(self.model, input={"message": prompt})
            return "".join(output)

        # Run the blocking call in the default executor and await it, so the
        # event loop stays free (calling future.result() directly would block it).
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, helper)

    async def stream_complete(self, prompt: str, with_history: Optional[List[ChatMessage]] = None, **kwargs):
        # Note: replicate's output iterator is synchronous, so fetching each
        # chunk briefly blocks the event loop.
        for item in self._client.run(self.model, input={"message": prompt}):
            yield item

    async def stream_chat(self, messages: Optional[List[ChatMessage]] = None, **kwargs):
        # Completion models on Replicate take a single prompt, so only the most
        # recent message is forwarded; earlier history is dropped.
        for item in self._client.run(self.model, input={"message": messages[-1].content}):
            yield {"content": item, "role": "assistant"}