File size: 5,183 Bytes
4c1e086
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import streamlit as st
import torch
from torch import nn
from diffusers import DDPMScheduler, UNet2DModel
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Reuse your existing model code
class ClassConditionedUnet(nn.Module):
    """Class-conditional UNet for DDPM sampling.

    A learned per-class embedding vector is broadcast over the spatial grid
    and concatenated to the noisy image as extra input channels, so the
    underlying UNet2DModel sees (3 + class_emb_size) channels.
    """

    def __init__(self, num_classes=3, class_emb_size=12):
        super().__init__()
        # One learnable conditioning vector per class.
        self.class_emb = nn.Embedding(num_classes, class_emb_size)
        self.model = UNet2DModel(
            sample_size=64,
            in_channels=3 + class_emb_size,  # RGB + conditioning channels
            out_channels=3,
            layers_per_block=2,
            block_out_channels=(64, 128, 256, 512),
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
            ),
            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
        )

    def forward(self, x, t, class_labels):
        """Predict the noise residual for batch `x` at timestep `t`,
        conditioned on integer `class_labels` of shape (batch,)."""
        batch, _, dim2, dim3 = x.shape
        emb = self.class_emb(class_labels)  # (batch, class_emb_size)
        # Broadcast each embedding over the spatial grid, then stack it onto
        # the image along the channel axis.
        emb = emb[:, :, None, None].expand(-1, -1, dim2, dim3)
        return self.model(torch.cat((x, emb), dim=1), t).sample

@st.cache_resource
def load_model(model_path):
    """Load the trained diffusion model and scheduler (cached by Streamlit).

    Args:
        model_path: Path to a torch checkpoint containing a
            'model_state_dict' entry for ClassConditionedUnet.

    Returns:
        (net, noise_scheduler): the conditional UNet on CPU, in eval mode,
        and a DDPM scheduler configured to match training.
    """
    device = 'cpu'  # For deployment, we'll use CPU
    net = ClassConditionedUnet().to(device)
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
    # SECURITY: torch.load uses pickle under the hood; weights_only=True
    # restricts deserialization to tensors/containers and blocks arbitrary
    # code execution from a tampered checkpoint file. The state dict loaded
    # here only needs tensors, so this is safe.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=True)
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()  # inference-only app; disable dropout/batchnorm training mode
    return net, noise_scheduler

def generate_mixed_faces(net, noise_scheduler, mix_weights, num_images=1):
    """Generate faces whose class conditioning is a weighted blend of classes.

    Args:
        net: ClassConditionedUnet exposing a `class_emb` embedding layer.
        noise_scheduler: DDPMScheduler driving the reverse diffusion loop.
        mix_weights: sequence of 3 weights [asian, indian, european];
            callers normalize them to sum to 1.
        num_images: number of samples to draw in one batch.

    Returns:
        Tensor of shape (num_images, 3, 64, 64) with values in [0, 1].
    """
    device = next(net.parameters()).device
    net.eval()
    with torch.no_grad():
        x = torch.randn(num_images, 3, 64, 64).to(device)

        # Embeddings for the three classes (0=Asian, 1=Indian, 2=European).
        class_embs = [
            net.class_emb(torch.full((num_images,), c, dtype=torch.long, device=device))
            for c in range(3)
        ]
        # FIX: the blended embedding is loop-invariant — compute it once
        # instead of on every denoising step as before.
        mixed_emb = sum(w * emb for w, emb in zip(mix_weights, class_embs))

        progress_bar = st.progress(0)

        # Temporarily make the embedding layer return the blended vector
        # regardless of the labels passed in. FIX: patch once and restore in
        # a finally block — previously an exception mid-loop would leave the
        # embedding layer permanently monkey-patched.
        original_forward = net.class_emb.forward
        net.class_emb.forward = lambda _labels: mixed_emb
        try:
            dummy_labels = torch.zeros(num_images, dtype=torch.long, device=device)
            n_steps = len(noise_scheduler.timesteps)
            for idx, t in enumerate(noise_scheduler.timesteps):
                progress_bar.progress(idx / n_steps)
                residual = net(x, t, dummy_labels)
                x = noise_scheduler.step(residual, t, x).prev_sample
        finally:
            net.class_emb.forward = original_forward

        progress_bar.progress(1.0)

    # Map model output from [-1, 1] to [0, 1] for display.
    x = (x.clamp(-1, 1) + 1) / 2
    return x

def main():
    """Streamlit entry point: ethnicity-mix sliders plus a generate button."""
    st.title("AI Face Generator with Ethnic Features Mixing")

    # Load the cached model; surface a visible error and stop on failure.
    try:
        net, noise_scheduler = load_model('final_model/final_diffusion_model.pt')
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return

    # Three side-by-side sliders, one per ethnicity.
    st.subheader("Adjust Ethnicity Mix")
    columns = st.columns(3)
    slider_specs = (
        ("Asian Features %", 33),
        ("Indian Features %", 33),
        ("European Features %", 34),
    )
    percentages = []
    for column, (label, default) in zip(columns, slider_specs):
        with column:
            percentages.append(st.slider(label, 0, 100, default, 1))

    total = sum(percentages)
    if total == 0:
        st.warning("Total percentage cannot be 0%. Please adjust the sliders.")
        return

    # Normalize so the three weights sum to 1 regardless of slider totals.
    weights = [pct / total for pct in percentages]

    st.write("Current mix (normalized):")
    st.write(f"Asian: {weights[0]:.2%}, Indian: {weights[1]:.2%}, European: {weights[2]:.2%}")

    if st.button("Generate Face"):
        try:
            with st.spinner("Generating face..."):
                generated_images = generate_mixed_faces(net, noise_scheduler, weights)
                # CHW tensor -> HWC numpy for display.
                img = generated_images[0].permute(1, 2, 0).cpu().numpy()
                st.image(img, caption="Generated Face", use_column_width=True)
        except Exception as e:
            st.error(f"Error generating image: {str(e)}")


if __name__ == "__main__":
    main()