import pytest as pytest
from grouped_sampling import GroupedSamplingPipeLine

from available_models import AVAILABLE_MODELS
from hanlde_form_submit import create_pipeline, on_form_submit


def test_on_form_submit():
    """Exercise ``on_form_submit`` with a regular prompt and an empty one.

    A non-empty prompt must yield a result that is not ``None`` and has a
    positive length; an empty prompt must raise ``ValueError``.
    """
    model, length = "gpt2", 10

    # Happy path: a real question should produce some output.
    generated = on_form_submit(
        model,
        length,
        "Answer yes or no, is the sky blue?",
    )
    assert generated is not None
    assert len(generated) > 0

    # Error path: an empty prompt is expected to be rejected.
    with pytest.raises(ValueError):
        on_form_submit(model, length, "")


def test_create_pipeline():
    """Check the pipeline that ``create_pipeline`` builds for ``"gpt2"``.

    The returned object must not be ``None``, must report the requested
    model name, and its wrapped model must have ``end_of_sentence_stop``
    set to exactly ``False``.
    """
    expected_name = "gpt2"
    created: GroupedSamplingPipeLine = create_pipeline(expected_name)

    assert created is not None
    assert created.model_name == expected_name

    wrapped = created.wrapped_model
    assert wrapped.end_of_sentence_stop is False

    # Drop the local references explicitly, as the original test did
    # (presumably to release the loaded model promptly -- confirm).
    del wrapped, created


if __name__ == "__main__":
    # pytest.main() returns the run's exit code; propagate it so that
    # executing this file as a script exits non-zero when tests fail
    # instead of always reporting success.
    raise SystemExit(pytest.main())