from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import gradio as gr
import io
import PIL.Image

def calculate_weight_diff(base_weight, chat_weight):
    """Return the mean absolute element-wise difference of two weight tensors.

    The arithmetic is carried out in at least float32, so that low-precision
    inputs (e.g. bfloat16 or float16) do not have their difference and mean
    rounded back to the low-precision dtype before being reported.

    Args:
        base_weight: Weight tensor taken from the base model.
        chat_weight: Weight tensor taken from the chat model; must have the
            same shape as ``base_weight``.

    Returns:
        The mean of ``|base_weight - chat_weight|`` as a Python float.

    Raises:
        ValueError: If the two tensors do not have the same shape.
    """
    if base_weight.shape != chat_weight.shape:
        # Without this check, broadcastable shapes would silently produce a
        # number that does not correspond to an element-wise comparison.
        raise ValueError(
            f"Weight shapes differ: {tuple(base_weight.shape)} vs {tuple(chat_weight.shape)}"
        )

    # Never go below float32, but keep float64 inputs at float64.
    compute_dtype = torch.promote_types(
        torch.result_type(base_weight, chat_weight), torch.float32
    )

    # Only a scalar is needed; skip autograd bookkeeping for parameter inputs.
    with torch.no_grad():
        delta = base_weight.to(compute_dtype) - chat_weight.to(compute_dtype)
        return delta.abs().mean().item()

# (result key, dotted attribute path of the submodule whose ``.weight`` is compared)
_LAYER_COMPONENTS = (
    ('input_layernorm', 'input_layernorm'),
    ('mlp_down_proj', 'mlp.down_proj'),
    ('mlp_gate_proj', 'mlp.gate_proj'),
    ('mlp_up_proj', 'mlp.up_proj'),
    ('post_attention_layernorm', 'post_attention_layernorm'),
    ('self_attn_q_proj', 'self_attn.q_proj'),
    ('self_attn_k_proj', 'self_attn.k_proj'),
    ('self_attn_v_proj', 'self_attn.v_proj'),
    ('self_attn_o_proj', 'self_attn.o_proj'),
)


def _component_weight(layer, path):
    """Return the ``weight`` of the submodule reached by following dotted *path* from *layer*."""
    module = layer
    for attr in path.split('.'):
        module = getattr(module, attr)
    return module.weight


def calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=False):
    """Compute per-layer, per-component weight differences between two models.

    Layers are paired positionally from ``base_model.model.layers`` and
    ``chat_model.model.layers``. For every pair, each component listed in
    ``_LAYER_COMPONENTS`` is compared with ``calculate_weight_diff``.

    Args:
        base_model: Model exposing ``model.layers``; each layer must provide
            the submodules named in ``_LAYER_COMPONENTS``.
        chat_model: Model with the same layer layout as ``base_model``.
        load_one_at_a_time: Accepted for backward compatibility. Both code
            paths always computed the same values; the flag only used to
            rebind and delete the loop variables, which cannot release
            memory still referenced by the models, so it has no effect.

    Returns:
        A list with one dict per layer pair, mapping component name to the
        value returned by ``calculate_weight_diff``. Key order follows
        ``_LAYER_COMPONENTS``.
    """
    base_layers = base_model.model.layers
    chat_layers = chat_model.model.layers

    layer_diffs = []
    # Only scalar statistics are produced; gradients are never needed here.
    with torch.no_grad():
        # zip() stops at the shorter of the two layer sequences.
        for base_layer, chat_layer in tqdm(zip(base_layers, chat_layers), total=len(base_layers)):
            layer_diffs.append({
                key: calculate_weight_diff(
                    _component_weight(base_layer, path),
                    _component_weight(chat_layer, path),
                )
                for key, path in _LAYER_COMPONENTS
            })

    return layer_diffs

def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
    """Render per-layer weight differences as a grid of single-column heatmaps.

    One heatmap is drawn per component; each heatmap has one row per layer.
    More than six components are laid out on two rows of subplots.

    Args:
        layer_diffs: Non-empty list with one dict per layer, mapping component
            name to a numeric difference. The keys of the first entry decide
            which components are plotted, so every entry must contain them.
        base_model_name: Name shown on the left of the figure title.
        chat_model_name: Name shown on the right of the figure title.

    Returns:
        A ``PIL.Image.Image`` opened from the PNG rendering of the figure.

    Raises:
        ValueError: If ``layer_diffs`` is empty or its first entry has no
            components.
    """
    if not layer_diffs or not layer_diffs[0]:
        raise ValueError("layer_diffs must contain at least one layer with at least one component")

    components = list(layer_diffs[0])
    num_layers = len(layer_diffs)
    num_components = len(components)

    # Dynamically adjust figure size based on number of layers/components
    height = max(8, num_layers / 8)  # Minimum height of 8, scales up for more layers
    width = max(24, num_components * 3)  # Minimum width of 24, scales with components

    # Use 2 rows of subplots if there are many components
    if num_components > 6:
        nrows = 2
        ncols = (num_components + 1) // 2
        figsize = (width, height * 1.5)
    else:
        nrows = 1
        ncols = num_components
        figsize = (width, height)

    # squeeze=False always yields a 2-D array of Axes, so indexing also works
    # when there is only a single component.
    fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axs = axs.flatten()

    try:
        fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)

        # Adjust font sizes based on number of layers
        tick_font_size = max(6, min(10, 300 / num_layers))
        annot_font_size = max(6, min(10, 200 / num_layers))

        # Heatmap cell k spans [k, k + 1] on the y axis; label its centre.
        tick_positions = [layer + 0.5 for layer in range(num_layers)]

        for i, component in tqdm(enumerate(components), total=num_components):
            ax = axs[i]
            # One column, one row per layer.
            component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
            sns.heatmap(component_diffs,
                        annot=True,
                        fmt=".9f",
                        cmap="YlGnBu",
                        ax=ax,
                        cbar=False,
                        annot_kws={'size': annot_font_size})

            ax.set_title(component, fontsize=max(10, tick_font_size * 1.2))
            ax.set_xlabel("Difference", fontsize=tick_font_size)
            ax.set_ylabel("Layer", fontsize=tick_font_size)
            ax.set_xticks([])
            ax.set_yticks(tick_positions)
            ax.set_yticklabels(range(num_layers), fontsize=tick_font_size)
            # heatmap draws the first row at the top; flip so layer 0 is at the bottom.
            ax.invert_yaxis()

        # Remove any unused subplots (only possible in the 2-row layout)
        for unused_ax in axs[num_components:]:
            fig.delaxes(unused_ax)

        fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust layout to prevent overlap

        # Convert plot to image
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
    finally:
        # Close the figure even if plotting or saving failed, to free memory
        plt.close(fig)

    buf.seek(0)
    return PIL.Image.open(buf)

def gradio_interface(base_model_name, chat_model_name, hf_token, load_one_at_a_time=False):
    """Load two causal LMs, diff their layer weights and return the heatmap image.

    Args:
        base_model_name: Identifier passed to
            ``AutoModelForCausalLM.from_pretrained`` for the base model.
        chat_model_name: Identifier passed to
            ``AutoModelForCausalLM.from_pretrained`` for the chat model.
        hf_token: Hugging Face access token. A blank or ``None`` value is
            forwarded as ``None`` instead of an empty string.
        load_one_at_a_time: Forwarded unchanged to ``calculate_layer_diffs``.

    Returns:
        The image produced by ``visualize_layer_diffs``.

    Raises:
        ValueError: If either model name is empty after stripping whitespace.
    """
    # The UI textboxes are multi-line, so pasted values may carry stray
    # whitespace or newlines that are not part of the model id or token.
    base_model_name = base_model_name.strip()
    chat_model_name = chat_model_name.strip()
    if not base_model_name or not chat_model_name:
        raise ValueError("Both a base model name and a chat model name are required")

    # NOTE(review): an empty string is presumably not a valid credential;
    # passing None should let the library fall back to its default token
    # lookup -- confirm against the huggingface_hub documentation.
    token = (hf_token or "").strip() or None

    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16, token=token)
    chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16, token=token)

    layer_diffs = calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=load_one_at_a_time)
    return visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name)

if __name__ == "__main__":
    # Input widgets, listed in the positional order of gradio_interface's parameters.
    input_components = [
        gr.Textbox(label="Base Model Name", lines=2),
        gr.Textbox(label="Chat Model Name", lines=2),
        gr.Textbox(label="Hugging Face Token", type="password", lines=2),
        gr.Checkbox(label="Load one layer at a time"),
    ]
    # The callback returns a PIL image, so the output widget is configured for "pil".
    output_component = gr.Image(type="pil", label="Weight Differences Visualization")

    iface = gr.Interface(
        fn=gradio_interface,
        inputs=input_components,
        outputs=output_component,
        title="Model Weight Difference Visualizer",
        cache_examples=False,
    )

    # Start the web UI on port 7860 with sharing disabled.
    iface.launch(server_port=7860, share=False)