from unittest import TestCase

from hypothesis import given
from hypothesis import strategies as st
from transformers import AutoTokenizer

from trlx.pipeline.offline_pipeline import DialogMessage, tokenize_dialogue


class TestTokenizeDialog(TestCase):
    """Property-based tests for ``tokenize_dialogue`` using the GPT-2 tokenizer.

    Each test builds the expected tokens independently: every turn is tokenized without
    special tokens and EOS is appended to the final turn. The result is then compared with
    ``tokenize_dialogue``'s output, both untruncated and truncated from either side.
    """

    def setUp(self):
        # Runs once per test method (not per hypothesis example). Tests that depend on the
        # truncation side set ``truncation_side`` explicitly before calling tokenize_dialogue.
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    def _encode(self, text):
        """Tokenize ``text`` without special tokens and return the ids as a tuple."""
        return tuple(self.tokenizer(text, add_special_tokens=False).input_ids)

    def _prompt_message(self):
        """Return the BOS-only user message expected in front of a prompt-less dialogue.

        Used for every test so single- and multi-turn expectations agree. For GPT-2, BOS
        and EOS happen to share an id, which is why the mismatch was previously invisible.
        """
        return DialogMessage(is_output=False, tokens=(self.tokenizer.bos_token_id,))

    @staticmethod
    def _flatten_pairs(user_response_pairs):
        """Flatten ``(user, response)`` string pairs into an alternating list of turns.

        Each string is joined with spaces between its characters so a turn spans several tokens.
        """
        return [" ".join(text) for pair in user_response_pairs for text in pair]

    def _tokenize_turns(self, flat_convo):
        """Tokenize each turn separately and append EOS to the last one (tuple of tuples)."""
        tokenized = [self._encode(turn) for turn in flat_convo]
        tokenized[-1] = tokenized[-1] + (self.tokenizer.eos_token_id,)
        return tuple(tokenized)

    def test_tokenize_dialogue_truncation_basic(self):
        dialogue = ["this will be truncated", "."]
        self.tokenizer.truncation_side = "left"

        dialog = tokenize_dialogue(dialogue, self.tokenizer, max_length=2)

        # With room for only two tokens, the prompt is dropped and replaced by the BOS
        # placeholder, and only the final EOS of the response survives left truncation.
        assert len(dialog) == 2
        user_dm, bot_dm = dialog
        assert len(user_dm.tokens) == 1
        assert len(bot_dm.tokens) == 1
        assert user_dm == self._prompt_message()
        assert bot_dm == DialogMessage(is_output=True, tokens=(self.tokenizer.eos_token_id,))

    @given(st.lists(st.text(), max_size=32))
    def test_tokenize_dialogue_single_turn(self, response_words):
        response = " ".join(response_words)  # space-separate so the response spans multiple tokens
        tokenized_response = self._encode(response) + (self.tokenizer.eos_token_id,)
        dialog = tokenize_dialogue(response, self.tokenizer)

        # A bare string is expected to become a BOS prompt followed by the full response + EOS.
        assert len(dialog) == 2
        user_dm, bot_dm = dialog

        assert user_dm == self._prompt_message()
        assert bot_dm == DialogMessage(is_output=True, tokens=tokenized_response)

    @given(st.lists(st.text(), max_size=32), st.integers(min_value=2, max_value=16))
    def test_tokenize_dialogue_single_turn_truncation_right(self, response_words, max_length):
        response = " ".join(response_words)  # space-separate so the response spans multiple tokens
        self.tokenizer.truncation_side = "right"
        tokenized_response = self._encode(response) + (self.tokenizer.eos_token_id,)
        dialog = tokenize_dialogue(response, self.tokenizer, max_length=max_length)

        assert len(dialog) == 2
        user_dm, bot_dm = dialog

        # One slot is taken by the BOS prompt, so the response keeps its first max_length - 1 tokens.
        assert user_dm == self._prompt_message()
        assert bot_dm == DialogMessage(is_output=True, tokens=tokenized_response[: max_length - 1])

        all_tokens = sum((dm.tokens for dm in dialog), ())
        assert len(all_tokens) <= max_length

    @given(st.lists(st.text(), max_size=32), st.integers(min_value=2, max_value=16))
    def test_tokenize_dialogue_single_turn_truncation_left(self, response_words, max_length):
        response = " ".join(response_words)  # space-separate so the response spans multiple tokens
        self.tokenizer.truncation_side = "left"
        tokenized_response = self._encode(response) + (self.tokenizer.eos_token_id,)
        dialog = tokenize_dialogue(response, self.tokenizer, max_length=max_length)

        # whether or not truncation has happened, user BOS prompt should be present
        assert len(dialog) == 2
        user_dm, bot_dm = dialog
        assert user_dm == self._prompt_message()

        if len(tokenized_response) < max_length:
            assert bot_dm == DialogMessage(is_output=True, tokens=tokenized_response)
        else:
            # The BOS prompt reserves one slot, so only the last max_length - 1 tokens remain.
            assert bot_dm == DialogMessage(is_output=True, tokens=tokenized_response[-max_length + 1 :])

        all_tokens = sum((dm.tokens for dm in dialog), ())
        assert len(all_tokens) <= max_length

    @given(st.lists(st.tuples(st.text(), st.text()), min_size=1, max_size=32))
    def test_tokenize_dialogue_multi_turn(self, user_response_pairs):
        flat_convo = self._flatten_pairs(user_response_pairs)
        tokenized_flat_convo = self._tokenize_turns(flat_convo)
        dialog = tokenize_dialogue(flat_convo, self.tokenizer)

        # Turns alternate user (even index) / bot (odd index); turns with no tokens are expected to be dropped.
        dm_convo = [DialogMessage(is_output=i % 2 == 1, tokens=tokens) for i, tokens in enumerate(tokenized_flat_convo)]
        nonempty_dm_convo = [dm for dm in dm_convo if dm.tokens]
        # If the dialogue would start with a bot turn, the BOS placeholder prompt is expected first.
        if nonempty_dm_convo[0].is_output:
            nonempty_dm_convo.insert(0, self._prompt_message())

        assert dialog == nonempty_dm_convo

    @given(st.lists(st.tuples(st.text(), st.text()), min_size=1, max_size=32), st.integers(min_value=2, max_value=16))
    def test_tokenize_dialogue_multi_turn_truncation_right(self, user_response_pairs, max_length):
        flat_convo = self._flatten_pairs(user_response_pairs)
        self.tokenizer.truncation_side = "right"
        tokenized_flat_convo = self._tokenize_turns(flat_convo)
        dialog = tokenize_dialogue(flat_convo, self.tokenizer, max_length=max_length)

        all_tokens = sum((dm.tokens for dm in dialog), ())
        # Right truncation keeps the leading max_length tokens of the whole conversation ...
        should_be_tokens = sum(tokenized_flat_convo, ())[:max_length]
        # ... unless a placeholder prompt was prepended, which costs one slot at the end.
        if dialog[0] == self._prompt_message():
            should_be_tokens = (self.tokenizer.bos_token_id, *should_be_tokens[: max_length - 1])

        assert all_tokens == should_be_tokens
        assert len(all_tokens) <= max_length

    @given(st.lists(st.tuples(st.text(), st.text()), min_size=1, max_size=32), st.integers(min_value=2, max_value=16))
    def test_tokenize_dialogue_multi_turn_truncation_left(self, user_response_pairs, max_length):
        flat_convo = self._flatten_pairs(user_response_pairs)
        self.tokenizer.truncation_side = "left"
        tokenized_flat_convo = self._tokenize_turns(flat_convo)
        dialog = tokenize_dialogue(flat_convo, self.tokenizer, max_length=max_length)

        all_tokens = sum((dm.tokens for dm in dialog), ())
        # Left truncation keeps the trailing max_length tokens of the whole conversation ...
        should_be_tokens = sum(tokenized_flat_convo, ())[-max_length:]
        # ... unless a placeholder prompt was prepended, which costs one slot at the front.
        if dialog[0] == self._prompt_message():
            should_be_tokens = (self.tokenizer.bos_token_id, *should_be_tokens[-max_length + 1 :])

        assert all_tokens == should_be_tokens
        assert len(all_tokens) <= max_length