import einops
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torch.backends.cuda import SDPBackend

from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer


def benchmark_attn():
    """Benchmark and profile ``F.scaled_dot_product_attention`` per SDP backend.

    Random float16 query/key/value tensors of shape
    ``(batch_size, num_heads, max_sequence_len, embed_dimension)`` are created
    on CUDA when available (otherwise on CPU). The default dispatch, the math
    backend, the flash-attention backend and the memory-efficient backend are
    each timed with ``torch.utils.benchmark`` and then profiled over 25 calls.
    All results are printed; nothing is returned.

    A flash / memory-efficient backend that raises ``RuntimeError`` while being
    timed is reported as unsupported and its profiling pass is skipped.
    """
    # NOTE(review): newer PyTorch releases appear to deprecate ``sdp_kernel`` in
    # favour of ``torch.nn.attention.sdpa_kernel`` -- confirm against the
    # pinned torch version before migrating.
    from torch.backends.cuda import SDPBackend, sdp_kernel
    from torch.profiler import ProfilerActivity, profile, record_function

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Timing helper taken from the PyTorch SDPA tutorial:
    # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6

    # Hyper-parameters of the benchmark input.
    batch_size = 32
    max_sequence_len = 1024
    num_heads = 32
    embed_dimension = 32

    dtype = torch.float16

    shape = (batch_size, num_heads, max_sequence_len, embed_dimension)
    query = torch.rand(*shape, device=device, dtype=dtype)
    key = torch.rand(*shape, device=device, dtype=dtype)
    value = torch.rand(*shape, device=device, dtype=dtype)

    print("q/k/v shape:", query.shape, key.shape, value.shape)

    # Maps each backend to the ``sdp_kernel`` flags that enable only that backend.
    backend_map = {
        SDPBackend.MATH: {
            "enable_math": True,
            "enable_flash": False,
            "enable_mem_efficient": False,
        },
        SDPBackend.FLASH_ATTENTION: {
            "enable_math": False,
            "enable_flash": True,
            "enable_mem_efficient": False,
        },
        SDPBackend.EFFICIENT_ATTENTION: {
            "enable_math": False,
            "enable_flash": False,
            "enable_mem_efficient": True,
        },
    }

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    def time_attention():
        """Return the mean attention runtime in microseconds for q/k/v."""
        return benchmark_torch_function_in_microseconds(
            F.scaled_dot_product_attention, query, key, value
        )

    def profile_attention(label):
        """Profile 25 attention calls under ``label`` and print the summary."""
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function(label):
                for _ in range(25):
                    F.scaled_dot_product_attention(query, key, value)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # Default dispatch: PyTorch picks the backend itself.
    print(f"The default implementation runs in {time_attention():.3f} microseconds")
    profile_attention("Default detailed stats")

    with sdp_kernel(**backend_map[SDPBackend.MATH]):
        # The timing must happen inside the context manager; outside of it the
        # default dispatch would be measured instead of the math backend.
        print(f"The math implementation runs in {time_attention():.3f} microseconds")
        profile_attention("Math implmentation stats")

    # Backends that may be unavailable for the current device/dtype/shape:
    # (backend, name used in the timing message, profiler label, fallback message)
    optional_backends = [
        (
            SDPBackend.FLASH_ATTENTION,
            "flash attention",
            "FlashAttention stats",
            "FlashAttention is not supported. See warnings for reasons.",
        ),
        (
            SDPBackend.EFFICIENT_ATTENTION,
            "memory efficient",
            "EfficientAttention stats",
            "EfficientAttention is not supported. See warnings for reasons.",
        ),
    ]
    for backend, description, label, unsupported_message in optional_backends:
        with sdp_kernel(**backend_map[backend]):
            try:
                elapsed = time_attention()
            except RuntimeError:
                # Profiling would hit the same error, so skip this backend.
                print(unsupported_message)
                continue
            print(
                f"The {description} implementation runs in {elapsed:.3f} microseconds"
            )
            profile_attention(label)


def run_model(model, x, context):
    """Call ``model`` with ``x`` and ``context`` positionally and return the result."""
    output = model(x, context)
    return output


def benchmark_transformer_blocks():
    """Compare a native-attention and an xformers ``SpatialTransformer``.

    Two ``SpatialTransformer`` models are built with the same dimensions: one
    with ``attn_type="softmax"`` and the flash-attention SDP backend, one with
    ``attn_type="softmax-xformers"``. Under ``torch.autocast("cuda")`` each
    model's forward pass is timed with ``torch.utils.benchmark``, then profiled
    over 25 calls, and the peak CUDA memory allocated is printed.
    """
    import torch.utils.benchmark as benchmark
    from torch.profiler import ProfilerActivity, profile, record_function

    device = "cuda" if torch.cuda.is_available() else "cpu"

    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        timer = benchmark.Timer(
            stmt="f(*args, **kwargs)",
            globals={"f": f, "args": args, "kwargs": kwargs},
        )
        measurement = timer.blocked_autorange()
        return measurement.mean * 1e6

    # Benchmark switches.
    use_checkpoint = True
    use_compile = False

    # Input / model dimensions.
    batch_size = 32
    h, w = 64, 64
    context_len = 77
    embed_dimension = 1024
    context_dim = 1024
    d_head = 64
    transformer_depth = 4
    n_heads = embed_dimension // d_head
    dtype = torch.float16

    # Constructor arguments common to both variants.
    shared_kwargs = dict(
        context_dim=context_dim,
        use_linear=True,
        use_checkpoint=use_checkpoint,
        depth=transformer_depth,
    )
    model_native = SpatialTransformer(
        embed_dimension,
        n_heads,
        d_head,
        attn_type="softmax",
        sdp_backend=SDPBackend.FLASH_ATTENTION,
        **shared_kwargs,
    ).to(device)
    model_efficient_attn = SpatialTransformer(
        embed_dimension,
        n_heads,
        d_head,
        attn_type="softmax-xformers",
        **shared_kwargs,
    ).to(device)

    # Compilation is only attempted when checkpointing is switched off.
    if use_compile and not use_checkpoint:
        print("compiling models")
        model_native = torch.compile(model_native)
        model_efficient_attn = torch.compile(model_efficient_attn)

    x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
    c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    def profile_forward(model, banner, label, memory_note):
        """Profile 25 forward passes of ``model`` and print time and peak memory."""
        rule = "+" * 75
        print(rule)
        print(banner)
        print(rule)
        torch.cuda.reset_peak_memory_stats()
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function(label):
                for _ in range(25):
                    model(x, c)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        print(torch.cuda.max_memory_allocated() * 1e-9, memory_note)

    with torch.autocast("cuda"):
        native_us = benchmark_torch_function_in_microseconds(
            model_native.forward, x, c
        )
        print(f"The native model runs in {native_us:.3f} microseconds")
        efficient_us = benchmark_torch_function_in_microseconds(
            model_efficient_attn.forward, x, c
        )
        print(f"The efficientattn model runs in {efficient_us:.3f} microseconds")

        profile_forward(
            model_native, "NATIVE", "NativeAttention stats", "GB used by native block"
        )
        profile_forward(
            model_efficient_attn,
            "Xformers",
            "xformers stats",
            "GB used by xformers block",
        )


def test01():
    """Run a 1x1 convolution and a linear layer with shared weights on CUDA.

    Prints the parameter count of each layer, the weight shapes, and the mean
    and shape of each layer's output for the same random input, so the two can
    be compared by eye.
    """
    # conv1x1 vs linear
    from sgm.util import count_params

    conv_layer = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()
    print(count_params(conv_layer))
    linear_layer = torch.nn.Linear(3, 32).cuda()
    print(count_params(linear_layer))

    print(conv_layer.weight.shape)

    # Reuse the conv initialization: dropping the two trailing kernel
    # dimensions turns the conv kernel into a linear weight matrix.
    flat_kernel = conv_layer.weight.squeeze(-1).squeeze(-1)
    linear_layer.weight = torch.nn.Parameter(flat_kernel)
    linear_layer.bias = torch.nn.Parameter(conv_layer.bias)

    print(linear_layer.weight.shape)

    image = torch.randn(11, 3, 64, 64).cuda()

    # The linear layer acts on the last axis, so move channels there and
    # flatten the spatial grid into a token axis.
    tokens = einops.rearrange(image, "b c h w -> b (h w) c").contiguous()
    print(tokens.shape)

    for layer, layer_input in ((linear_layer, tokens), (conv_layer, image)):
        result = layer(layer_input)
        print(result.mean(), result.shape)
    print("done with test01.\n")


def test02():
    """Time one forward pass of ``BasicTransformerBlock`` for two attention modes.

    Enables TF32 and cuDNN benchmarking, then runs a block with
    ``attn_mode="softmax"`` followed by one with ``attn_mode="flash-cosine"``
    on a random CUDA input of shape ``(BS, SEQLEN, DIM)``, printing the output
    shape and the elapsed wall-clock seconds for each.

    A ``RuntimeError`` from the softmax run (likely out-of-memory at this
    sequence length) is printed and does not stop the flash-cosine run.
    """
    # try cosine flash attention
    import time

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    print("testing cosine flash attention...")
    DIM = 1024
    SEQLEN = 4096
    BS = 16

    def build_block(attn_mode):
        """Create a CUDA ``BasicTransformerBlock`` using ``attn_mode``."""
        return BasicTransformerBlock(
            dim=DIM,
            n_heads=16,
            d_head=64,
            dropout=0.0,
            context_dim=None,
            attn_mode=attn_mode,
        ).cuda()

    def timed_forward(block):
        """Run ``block`` once on a fresh random input and print shape and time."""
        x = torch.randn(BS, SEQLEN, DIM).cuda()
        # CUDA kernels are launched asynchronously: wait for pending work
        # before starting the clock, and for the forward pass to finish before
        # stopping it. Otherwise only the launch overhead is measured.
        torch.cuda.synchronize()
        tic = time.time()
        y = block(x)
        torch.cuda.synchronize()
        toc = time.time()
        print(y.shape, toc - tic)

    print(" softmax (vanilla) first...")
    model = build_block("softmax")
    try:
        timed_forward(model)
    except RuntimeError as e:
        # likely oom
        print(str(e))

    print("\n now flash-cosine...")
    model = build_block("flash-cosine")
    timed_forward(model)
    print("done with test02.\n")


if __name__ == "__main__":
    # Manual entry point: uncomment the routine(s) to run. Only the
    # transformer-block benchmark is currently enabled.
    # test01()
    # test02()
    # NOTE(review): no test03 is defined in the visible part of this file --
    # confirm that it exists before enabling the call below.
    # test03()

    # benchmark_attn()
    benchmark_transformer_blocks()

    print("done.")