import pytest

from tests.utils import wrap_test_forked, get_llama
from src.enums import DocumentSubset


@wrap_test_forked
def test_cli(monkeypatch):
    """Run one non-looping CLI generation with the ``gptj`` base model.

    ``builtins.input`` is patched so the CLI receives a fixed question. The
    test expects exactly one generation that contains a known phrase.
    """
    question = "What is the Earth?"
    monkeypatch.setattr('builtins.input', lambda _: question)

    from src.gen import main
    generations, _sources = main(base_model='gptj', cli=True, cli_loop=False, score_model='None')

    assert len(generations) == 1
    answer = generations[0]
    assert "The Earth is a planet in our solar system" in answer


@pytest.mark.parametrize("base_model", ['gptj', 'gpt4all_llama'])
@wrap_test_forked
def test_cli_langchain(base_model, monkeypatch):
    """Run one CLI generation over the ``UserData`` collection for each base model.

    The test collection comes from ``make_user_path_test``. The question is
    injected through a patched ``builtins.input``. The single generation must
    contain at least one of several accepted phrasings.
    """
    from tests.utils import make_user_path_test
    user_path = make_user_path_test()

    question = "What is the cat doing?"
    monkeypatch.setattr('builtins.input', lambda _: question)

    from src.gen import main
    generations, _sources = main(
        base_model=base_model,
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='UserData',
        user_path=user_path,
        langchain_modes=['UserData', 'MyData'],
        document_subset=DocumentSubset.Relevant.name,
        verbose=True,
    )

    print(generations)
    assert len(generations) == 1

    # Source file names are no longer embedded in the generated text, so only
    # the answer wording is checked. Model output varies, so any one of these
    # known phrasings is accepted.
    accepted_phrases = (
        "looking out the window",
        "staring out the window at the city skyline",
        "what the cat is doing",
        "question about a cat",
        "The prompt asks for an answer to a question",
        "The prompt asks what the cat in the scenario is doing",
        "The prompt asks why H2O.ai",
        "cat is sitting on a window",
        "cat is sitting",
    )
    answer = generations[0]
    assert any(phrase in answer for phrase in accepted_phrases)


@pytest.mark.need_tokens
@wrap_test_forked
def test_cli_langchain_llamacpp(monkeypatch):
    """Run one CLI generation over ``UserData`` with a llama model from ``get_llama``.

    Checks that the returned sources reference the sample cat image and that
    the single generation uses one of the accepted phrasings.
    """
    prompt_type, full_path = get_llama()

    from tests.utils import make_user_path_test
    user_path = make_user_path_test()

    question = "What is the cat doing?"
    monkeypatch.setattr('builtins.input', lambda _: question)

    from src.gen import main
    generations, sources = main(
        base_model='llama',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='UserData',
        model_path_llama=full_path,
        prompt_type=prompt_type,
        user_path=user_path,
        langchain_modes=['UserData', 'MyData'],
        document_subset=DocumentSubset.Relevant.name,
        verbose=True,
    )

    print(generations)
    assert len(generations) == 1
    assert "pexels-evg-kowalievska-1170986_small.jpg" in str(sources[0])

    # Model output varies; any one of these phrasings is accepted.
    accepted_phrases = (
        "the cat is sitting",
        "staring out the window at the city skyline",
        "The cat is likely relaxing and enjoying",
        "cat in the image is",
        "cat is sitting on a window sill",
    )
    answer = generations[0]
    assert any(phrase in answer for phrase in accepted_phrases)


@pytest.mark.need_tokens
@wrap_test_forked
def test_cli_llamacpp(monkeypatch):
    """Run one CLI generation with a llama model and langchain ``Disabled``.

    No ``user_path`` is supplied. The single generation must contain one of
    several accepted self-introduction phrasings.
    """
    prompt_type, full_path = get_llama()

    question = "Who are you?"
    monkeypatch.setattr('builtins.input', lambda _: question)

    from src.gen import main
    mode = 'Disabled'
    generations, _sources = main(
        base_model='llama',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode=mode,
        prompt_type=prompt_type,
        model_path_llama=full_path,
        user_path=None,
        langchain_modes=[mode],
        document_subset=DocumentSubset.Relevant.name,
        verbose=True,
    )

    print(generations)
    assert len(generations) == 1

    # Model output varies; any one of these phrasings is accepted.
    accepted_phrases = (
        "I'm a software engineer with a passion for building scalable",
        "how can I assist",
        "am a virtual assistant",
        "My name is John.",
        "I am a student",
        "I'm LLaMA",
        "Hello! I'm just an AI assistant",
    )
    answer = generations[0]
    assert any(phrase in answer for phrase in accepted_phrases)


@wrap_test_forked
def test_cli_h2ogpt(monkeypatch):
    """Run one CLI generation with the h2oGPT 6.9B OIG/OASST1 model.

    The match is case-insensitive. The generation and each accepted phrase
    are lowercased before the substring check.
    """
    question = "What is the Earth?"
    monkeypatch.setattr('builtins.input', lambda _: question)

    from src.gen import main
    generations, _sources = main(base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b', cli=True, cli_loop=False,
                                 score_model='None')

    assert len(generations) == 1
    lowered_answer = generations[0].lower()
    accepted_phrases = (
        "The Earth is a planet in the Solar System",
        "The Earth is the third planet",
    )
    assert any(phrase.lower() in lowered_answer for phrase in accepted_phrases)


@wrap_test_forked
def test_cli_langchain_h2ogpt(monkeypatch):
    """Run one CLI generation over ``UserData`` with the h2oGPT 6.9B model.

    The question is fed through a patched ``builtins.input``. The single
    generation must contain one of the accepted phrasings.
    """
    from tests.utils import make_user_path_test
    user_path = make_user_path_test()

    question = "What is the cat doing?"
    monkeypatch.setattr('builtins.input', lambda _: question)

    from src.gen import main
    generations, _sources = main(
        base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='UserData',
        user_path=user_path,
        langchain_modes=['UserData', 'MyData'],
        document_subset=DocumentSubset.Relevant.name,
        verbose=True,
    )

    print(generations)
    assert len(generations) == 1

    # Model output varies; any one of these phrasings is accepted.
    accepted_phrases = (
        "looking out the window",
        "staring out the window at the city skyline",
        "cat is sitting",
    )
    answer = generations[0]
    assert any(phrase in answer for phrase in accepted_phrases)