import streamlit as st
import torch
from torch import nn
from diffusers import DDPMScheduler, UNet2DModel
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Model definition: must match the architecture that produced the checkpoint loaded in load_model().
class ClassConditionedUnet(nn.Module):
    """UNet denoiser that receives the class label as extra input channels.

    Each label is looked up in a learned embedding table; the resulting
    vector is repeated over the spatial grid and concatenated to the image
    channels before the wrapped ``UNet2DModel`` is called.

    Args:
        num_classes: Number of rows in the class embedding table.
        class_emb_size: Length of each embedding vector, which is also the
            number of extra input channels handed to the UNet.
    """

    def __init__(self, num_classes=3, class_emb_size=12):
        super().__init__()

        # The attribute names ``class_emb`` and ``model`` become the prefixes
        # of the state_dict keys, so they must not be renamed.
        self.class_emb = nn.Embedding(num_classes, class_emb_size)

        image_channels = 3
        # The last two down blocks and the first two up blocks are the
        # attention ("Attn*") variants.
        down_blocks = (
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        )
        up_blocks = (
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        )
        self.model = UNet2DModel(
            sample_size=64,
            # Image channels plus one channel per embedding dimension.
            in_channels=image_channels + class_emb_size,
            out_channels=image_channels,
            layers_per_block=2,
            block_out_channels=(64, 128, 256, 512),
            down_block_types=down_blocks,
            up_block_types=up_blocks,
        )

    def forward(self, x, t, class_labels):
        """Run the UNet on ``x`` at timestep ``t`` conditioned on ``class_labels``.

        Args:
            x: 4-D tensor whose first two dimensions are batch and channels.
            t: Timestep forwarded unchanged to the wrapped UNet.
            class_labels: Indices into the class embedding table
                (presumably one per batch element -- confirm with callers).

        Returns:
            The ``.sample`` attribute of the wrapped UNet's output.
        """
        batch, _, dim2, dim3 = x.shape
        label_vectors = self.class_emb(class_labels)
        emb_channels = label_vectors.shape[1]

        # (batch, emb) -> (batch, emb, 1, 1) -> repeated over the spatial grid.
        cond_map = label_vectors.view(batch, emb_channels, 1, 1)
        cond_map = cond_map.expand(batch, emb_channels, dim2, dim3)

        # Conditioning goes in as additional channels next to the image.
        unet_input = torch.cat((x, cond_map), 1)
        return self.model(unet_input, t).sample

@st.cache_resource
def load_model(model_path):
    """Build the network and noise scheduler and restore the trained weights.

    Wrapped in ``st.cache_resource`` so the result is cached by Streamlit
    rather than rebuilt on every call.

    Args:
        model_path: Path of a checkpoint readable by ``torch.load`` that
            holds the network weights under the ``'model_state_dict'`` key.

    Returns:
        A ``(net, noise_scheduler)`` tuple; ``net`` is on the CPU.
    """
    # Everything is kept on the CPU for deployment.
    target = 'cpu'
    net = ClassConditionedUnet().to(target)

    scheduler_options = {
        "num_train_timesteps": 1000,
        "beta_schedule": 'squaredcos_cap_v2',
    }
    noise_scheduler = DDPMScheduler(**scheduler_options)

    # map_location keeps tensors on the CPU regardless of where they were saved.
    saved = torch.load(model_path, map_location='cpu')
    net.load_state_dict(saved['model_state_dict'])

    return net, noise_scheduler

def generate_mixed_faces(net, noise_scheduler, mix_weights, num_images=1):
    """Generate images conditioned on a weighted blend of class embeddings.

    Starts from Gaussian noise and runs one denoising step for every entry
    of ``noise_scheduler.timesteps``. Every step is conditioned on the same
    vector: the weighted sum of the rows of ``net.class_emb``. A Streamlit
    progress bar is advanced once per step.

    The blended conditioning is concatenated to the sample and passed
    straight to ``net.model`` (the same computation ``net.forward`` performs
    after its embedding lookup). ``net.class_emb.forward`` is deliberately
    NOT replaced: the network returned by ``load_model`` is cached with
    ``st.cache_resource`` and therefore shared, so patching it would race
    with concurrent sessions and would leave the cached model broken if a
    step raised before the original ``forward`` was restored.

    Args:
        net: Module exposing ``class_emb`` (an ``nn.Embedding``) and
            ``model`` (the UNet), e.g. ``ClassConditionedUnet``.
        noise_scheduler: Scheduler providing ``timesteps`` and
            ``step(model_output, t, sample).prev_sample``.
        mix_weights: Sequence of weights, matched to class indices in order
            (weight ``i`` scales the embedding of class ``i``). Classes
            without a weight do not contribute.
        num_images: Number of images to generate in one batch.

    Returns:
        Tensor of shape ``(num_images, 3, 64, 64)`` with values in ``[0, 1]``.

    Raises:
        ValueError: If ``mix_weights`` is empty or has more entries than the
            embedding table has classes.
    """
    device = next(net.parameters()).device
    class_table = net.class_emb.weight  # (num_classes, emb_size)

    num_weights = len(mix_weights)
    num_classes = class_table.shape[0]
    if not 0 < num_weights <= num_classes:
        raise ValueError(
            f"mix_weights must contain between 1 and {num_classes} values, "
            f"got {num_weights}"
        )

    net.eval()
    with torch.no_grad():
        # 3x64x64 matches out_channels / sample_size of ClassConditionedUnet.
        x = torch.randn(num_images, 3, 64, 64, device=device)

        # The blend does not depend on the timestep, so compute it once.
        # zip() pairs weight i with embedding row i.
        mixed_emb = sum(w * row for w, row in zip(mix_weights, class_table))

        # (emb,) -> (1, emb, 1, 1) -> repeated for every image and pixel.
        class_cond = mixed_emb.reshape(1, -1, 1, 1).expand(
            num_images, -1, x.shape[2], x.shape[3]
        )

        timesteps = noise_scheduler.timesteps
        num_steps = len(timesteps)
        progress_bar = st.progress(0)
        for idx, t in enumerate(timesteps):
            progress_bar.progress(idx / num_steps)

            # Same input layout as ClassConditionedUnet.forward: image
            # channels followed by the conditioning channels.
            net_input = torch.cat((x, class_cond), 1)
            residual = net.model(net_input, t).sample
            x = noise_scheduler.step(residual, t, x).prev_sample

        progress_bar.progress(1.0)

    # Map the model's [-1, 1] output range to [0, 1] for display.
    x = (x.clamp(-1, 1) + 1) / 2
    return x

def main():
    """Render the Streamlit page: mix sliders, normalised weights, generate button.

    Loads the cached model, reads three percentage sliders, normalises them
    so they sum to 1 and, when the button is pressed, generates one image
    and displays it. Load and generation failures are reported with
    ``st.error`` instead of propagating.
    """
    st.title("AI Face Generator with Ethnic Features Mixing")

    # Without a model nothing else on the page is usable, so stop early.
    try:
        net, noise_scheduler = load_model('final_model/final_diffusion_model.pt')
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return

    st.subheader("Adjust Ethnicity Mix")

    # (label, default value) per slider, in class-index order; one column each.
    slider_specs = (
        ("Asian Features %", 33),
        ("Indian Features %", 33),
        ("European Features %", 34),
    )
    percentages = []
    for column, (label, default) in zip(st.columns(3), slider_specs):
        with column:
            percentages.append(st.slider(label, 0, 100, default, 1))

    # All sliders at zero would mean dividing by zero below.
    total = sum(percentages)
    if total == 0:
        st.warning("Total percentage cannot be 0%. Please adjust the sliders.")
        return

    # Rescale so the weights sum to 1 whatever the slider total is.
    weights = [pct / total for pct in percentages]

    st.write("Current mix (normalized):")
    st.write(f"Asian: {weights[0]:.2%}, Indian: {weights[1]:.2%}, European: {weights[2]:.2%}")

    if not st.button("Generate Face"):
        return

    try:
        with st.spinner("Generating face..."):
            generated_images = generate_mixed_faces(net, noise_scheduler, weights)

            # Channels-first tensor -> channels-last numpy array for st.image.
            first_image = generated_images[0]
            img = first_image.permute(1, 2, 0).cpu().numpy()
            st.image(img, caption="Generated Face", use_column_width=True)

    except Exception as e:
        st.error(f"Error generating image: {str(e)}")

# Build the page only when this file is executed as the main module, not when it is imported.
if __name__ == "__main__":
    main()