import torch
from Bio.PDB.Structure import Structure

from hexviz.attention import (
    ModelType,
    get_attention,
    get_sequences,
    get_structure,
    unidirectional_avg_filtered,
)
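
# These tests exercise the hexviz.attention pipeline end to end: fetching PDB
# structures, extracting chain sequences and computing attention tensors with
# ZymCTRL, TAPE BERT and ProtBERT, so they may require network access and
# locally cached model weights.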


def test_get_structure():
    pdb_id = "2I62"
    structure = get_structure(pdb_id)

    assert structure is not None
    assert isinstance(structure, Structure)


def test_get_sequences():
    pdb_id = "1AKE"
    structure = get_structure(pdb_id)

    sequences = get_sequences(structure)

    assert sequences is not None
    assert len(sequences) == 2

    A, B = sequences
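    # 1AKE contains two copies of adenylate kinase (chains A and B);
    # the chain sequence begins Met-Arg-Ile.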
    assert A[:3] == ["M", "R", "I"]


def test_get_attention_zymctrl():

    result = get_attention("GGG", model_type=ModelType.ZymCTRL)

    assert result is not None
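    # Expected layout is (layers, heads, seq_len, seq_len):
    # 36 layers and 16 heads for ZymCTRL, 3 residues in "GGG".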
    assert result.shape == torch.Size([36, 16, 3, 3])


def test_get_attention_zymctrl_long_chain():
    structure = get_structure(pdb_code="6A5J")  # 13 residues long

    sequences = get_sequences(structure)

    result = get_attention(sequences[0], model_type=ModelType.ZymCTRL)

    assert result is not None
    assert result.shape == torch.Size([36, 16, 13, 13])


def test_get_attention_tape():
    structure = get_structure(pdb_code="6A5J")  # 13 residues long
    sequences = get_sequences(structure)

    result = get_attention(sequences[0], model_type=ModelType.TAPE_BERT)

    assert result is not None
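    # 12 layers and 12 heads for the TAPE BERT model, 13 residues in the chain.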
    assert result.shape == torch.Size([12, 12, 13, 13])


def test_get_attention_prot_bert():

    result = get_attention("GGG", model_type=ModelType.PROT_BERT)

    assert result is not None
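    # 30 layers and 16 heads for ProtBERT, 3 residues in "GGG".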
    assert result.shape == torch.Size([30, 16, 3, 3])


def test_get_unidirectional_avg_filtered():
    # Attention tensor with 1 layer, 1 head and 4 residues
    attention = torch.tensor(
        [[[[1, 2, 3, 4], [2, 5, 6, 7], [3, 6, 8, 9], [4, 7, 9, 11]]]], dtype=torch.float32
    )

    result = unidirectional_avg_filtered(attention, 0, 0, 0)

    assert result is not None
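    # Averaging the two directions of each residue pair leaves
    # n * (n + 1) / 2 values: 4 * 5 / 2 = 10 pairs for 4 residues.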
    assert len(result) == 10

    # Same check with a 1 x 1 x 3 x 3 attention tensor (3 residues)
    attention = torch.tensor([[[[1, 2, 3], [2, 5, 6], [4, 7, 91]]]], dtype=torch.float32)

    result = unidirectional_avg_filtered(attention, 0, 0, 0)

    assert len(result) == 6  # 3 * 4 / 2 = 6 pairs for 3 residues