File size: 17,044 Bytes
e77c6a3
 
 
 
 
 
 
 
8e1a13f
 
e77c6a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e1a13f
e77c6a3
8e1a13f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e77c6a3
8e1a13f
 
 
e77c6a3
 
8e1a13f
 
 
 
 
 
 
 
 
 
 
 
e77c6a3
 
8e1a13f
e77c6a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e1a13f
e77c6a3
 
 
 
 
8e1a13f
e77c6a3
8e1a13f
 
 
 
 
e77c6a3
8e1a13f
 
 
 
 
 
e77c6a3
8e1a13f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e77c6a3
 
 
 
 
 
 
 
 
 
 
 
 
 
8e1a13f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e77c6a3
8e1a13f
e77c6a3
 
8e1a13f
e77c6a3
8e1a13f
e77c6a3
 
 
 
8e1a13f
e77c6a3
8e1a13f
e77c6a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e1a13f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e77c6a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e1a13f
e77c6a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afabb3b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
import torch
from diffusers import StableDiffusionPipeline
from torch import autocast
import gradio as gr
from huggingface_hub import hf_hub_download
import os
from pathlib import Path
import traceback
import glob
from PIL import Image

# Reuse the same load_learned_embed_in_clip and Distance_loss functions
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
    """Load a saved embedding and register it as a token in the text encoder.

    The file must be a ``torch.save``-d dict whose first entry maps the trained
    placeholder token to its embedding vector. The vector is truncated or
    zero-padded to the encoder's embedding width. It is added to the tokenizer
    as ``token`` (or as the trained token when ``token`` is None) and written
    into the encoder's input-embedding matrix.

    Args:
        learned_embeds_path: Path to the saved embedding dict.
        text_encoder: Object exposing ``get_input_embeddings()`` (with a
            ``.weight`` matrix of shape (vocab, dim)) and
            ``resize_token_embeddings(n)``.
        tokenizer: Object exposing ``add_tokens``, ``convert_tokens_to_ids``
            and ``__len__``.
        token: Optional token string to register instead of the trained one.

    Returns:
        The token string that was registered.

    Raises:
        ValueError: If the file holds no embeddings, the first entry is not a
            tensor, or the tensor is not a single vector (1-D, or 2-D with a
            leading dimension of 1).
    """
    # NOTE(review): torch.load unpickles the file; only load trusted embedding files.
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    if not isinstance(loaded_learned_embeds, dict) or not loaded_learned_embeds:
        raise ValueError(f"No embeddings found in {learned_embeds_path}")
    trained_token = next(iter(loaded_learned_embeds))
    embeds = loaded_learned_embeds[trained_token]
    if not torch.is_tensor(embeds):
        raise ValueError(f"Unsupported embedding format in {learned_embeds_path}")

    # Some exporters store the vector with a leading axis, i.e. shape (1, dim).
    # Flatten that case so the width comparison below sees the real embedding size.
    if embeds.dim() == 2 and embeds.shape[0] == 1:
        embeds = embeds[0]
    if embeds.dim() != 1:
        raise ValueError(
            f"Expected a single embedding vector, got shape {tuple(embeds.shape)}"
        )

    # Get the expected dimension from the text encoder
    expected_dim = text_encoder.get_input_embeddings().weight.shape[1]
    current_dim = embeds.shape[0]

    # Truncate or zero-pad when the widths differ.
    # NOTE(review): this keeps loading going, but a vector trained for a
    # different encoder width is unlikely to be meaningful after resizing.
    if current_dim != expected_dim:
        print(f"Resizing embedding from {current_dim} to {expected_dim}")
        if current_dim > expected_dim:
            embeds = embeds[:expected_dim]
        else:
            # new_zeros keeps the pad on the same dtype/device as the embedding.
            embeds = torch.cat([embeds, embeds.new_zeros(expected_dim - current_dim)], dim=0)

    # Add the token in tokenizer
    token = token if token is not None else trained_token
    if tokenizer.add_tokens(token) == 0:
        # Token was already in the vocabulary; its current embedding gets replaced.
        print(f"Token {token!r} already exists in the tokenizer; overwriting its embedding")

    # Grow the embedding matrix to cover the (possibly) new token id.
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Re-fetch after resizing, since the resize may replace the embedding module.
    weight = text_encoder.get_input_embeddings().weight
    token_id = tokenizer.convert_tokens_to_ids(token)
    with torch.no_grad():
        weight[token_id] = embeds.to(device=weight.device, dtype=weight.dtype)
    return token

def Distance_loss(images):
    """Return 100x the summed mean-squared differences between channels 0, 1 and 2.

    The input is first mapped with ``x / 2 + 0.5`` (in float32). The pairs
    compared are (0, 1), (0, 2) and (1, 2), with channels on axis 1. This
    presumably means a (batch, channels, ...) layout; confirm against callers.

    When ``images`` does not already require gradients, a detached copy that
    does is used instead, so the result can be differentiated while autograd
    is enabled.
    """
    src = images if images.requires_grad else images.detach().requires_grad_(True)

    scaled = src.float() / 2 + 0.5
    # Slice (k:k+1) rather than index so each channel keeps its axis.
    chans = [scaled[:, k:k + 1] for k in range(3)]

    total = 0
    for first, second in ((0, 1), (0, 2), (1, 2)):
        total = total + ((chans[first] - chans[second]) ** 2).mean()
    # Scale up the loss
    return total * 100

class StyleGenerator:
    """Lazily initialised Stable Diffusion pipeline plus learned style tokens.

    ``get_instance`` hands out a single shared instance. The model and the
    style embeddings are loaded only when ``initialize_model`` is called.
    """

    _instance = None

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it on first use (not thread-safe)."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        self.pipe = None
        # Registered tokens, index-aligned with self.styles once initialised.
        self.style_tokens = []
        # Embedding file stems; each is loaded from "<stem>.bin" next to this file.
        self.styles = [
            "ronaldo",
            "canna-lily-flowers102",
            "threestooges",
            "pop_art",
            "bird_style"
        ]
        # Human-readable names, index-aligned with self.styles.
        self.style_names = [
            "Ronaldo",
            "Canna Lily",
            "Three Stooges",
            "Pop Art",
            "Bird Style"
        ]
        self.is_initialized = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cpu":
            print("NVIDIA GPU not found. Running on CPU (this will be slower)")

    def initialize_model(self):
        """Load the pipeline and every style embedding; a no-op once initialised.

        Raises:
            FileNotFoundError: If a ``<style>.bin`` file is missing.
            Exception: Any loading error is printed with its traceback and re-raised.
        """
        if self.is_initialized:
            return

        try:
            print("Initializing Stable Diffusion model...")
            model_id = "runwayml/stable-diffusion-v1-5"
            pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                safety_checker=None
            )
            pipe = pipe.to(self.device)

            # Load style embeddings from the directory containing this file
            current_dir = Path(__file__).parent

            # Collect into a local list and publish only on success, so a failure
            # half-way through cannot leave stale tokens that a retry appends to.
            tokens = []
            for style, style_name in zip(self.styles, self.style_names):
                style_path = current_dir / f"{style}.bin"
                if not style_path.exists():
                    raise FileNotFoundError(f"Style embedding not found: {style_path}")

                print(f"Loading style: {style_name}")
                tokens.append(
                    load_learned_embed_in_clip(str(style_path), pipe.text_encoder, pipe.tokenizer)
                )
                print(f"✓ Loaded style: {style_name}")

            self.pipe = pipe
            self.style_tokens = tokens
            self.is_initialized = True
            print(f"Model initialization complete! Using device: {self.device}")

        except Exception as e:
            print(f"Error during initialization: {str(e)}")
            print(traceback.format_exc())
            raise

    def _run_pipeline(self, prompt, seed, **extra_kwargs):
        """Run the pipeline once with a freshly seeded generator; return the first image."""
        with autocast(self.device):
            return self.pipe(
                prompt,
                num_inference_steps=50,
                guidance_scale=7.5,
                generator=torch.Generator(self.device).manual_seed(seed),
                **extra_kwargs
            ).images[0]

    def generate_single_style(self, prompt, selected_style):
        """Render ``prompt`` in one style, once plainly and once with the loss callback.

        Args:
            prompt: Free-text prompt; the selected style token is appended to it.
            selected_style: Position in ``self.style_tokens`` (list-index semantics).

        Returns:
            ``(base_image, loss_image)``, each the first image of its pipeline run.

        Raises:
            ValueError: If no style is selected (``selected_style`` is None).
            IndexError: If ``selected_style`` is out of range (e.g. not initialised).
        """
        try:
            if selected_style is None:
                raise ValueError("Please select a style")
            styled_prompt = f"{prompt}, {self.style_tokens[selected_style]}"

            # Both runs use the same seed so the only difference is the callback.
            # The explicit generator makes global torch seeding unnecessary.
            generator_seed = 42
            base_image = self._run_pipeline(styled_prompt, generator_seed)
            loss_image = self._run_pipeline(
                styled_prompt,
                generator_seed,
                callback=self.callback_fn,
                callback_steps=5
            )
            return base_image, loss_image

        except Exception as e:
            print(f"Error in generate_single_style: {e}")
            raise

    def callback_fn(self, i, t, latents):
        """Nudge ``latents`` in place by one gradient-descent step on ``Distance_loss``.

        This is passed as the pipeline's legacy ``callback(step, timestep, latents)``.
        Diffusers ignores that callback's return value, so the update is applied
        to ``latents`` in place. The same tensor is also returned.

        NOTE(review): the loss is computed on the pipeline latents, not on RGB
        pixels. Stepping against its gradient reduces the differences between
        channels 0-2. Confirm this is the intended "colour enhancement" effect.
        """
        if i % 5 != 0:  # Apply loss every 5 steps
            return latents
        try:
            # Diffusers runs the denoising loop under torch.no_grad(); re-enable
            # autograd locally, otherwise the loss has no graph to differentiate.
            with torch.enable_grad():
                latents_copy = latents.detach().clone().requires_grad_(True)
                loss = Distance_loss(latents_copy)
                grads = torch.autograd.grad(
                    outputs=loss,
                    inputs=latents_copy,
                    allow_unused=True,
                    retain_graph=False
                )[0]

            if grads is not None:
                # In-place so the pipeline's own latents tensor sees the update.
                with torch.no_grad():
                    latents.sub_(0.1 * grads)

        except Exception as e:
            # Best effort: a failed nudge should not abort the whole generation.
            print(f"Error in callback: {e}")

        return latents

def generate_single_style(prompt, selected_style):
    """Gradio click handler producing ``[error_message, original_image, loss_image]``.

    Lazily initialises the shared ``StyleGenerator`` on first use. On success
    the error box is hidden and both images are returned. On any failure the
    error text is shown and both image slots are cleared with ``None``.
    """
    try:
        style_gen = StyleGenerator.get_instance()
        if not style_gen.is_initialized:
            style_gen.initialize_model()

        base_image, loss_image = style_gen.generate_single_style(prompt, selected_style)
        hidden_error = gr.update(visible=False)
        return [hidden_error, base_image, loss_image]
    except Exception as e:
        print(f"Error in generate_single_style: {e}")
        shown_error = gr.update(value=f"Error: {str(e)}", visible=True)
        # Clear both image outputs so stale results are not left on screen.
        return [shown_error, None, None]

def debug_image_paths():
    """Print the working directory and the ``.webp`` files under ``Outputs/Color_Enhanced``.

    The path is relative to the current working directory. Prints
    "Directory not found!" when that folder does not exist.
    """
    enhanced_dir = Path("Outputs").joinpath("Color_Enhanced")
    print("\nChecking image paths:")
    print(f"Current working directory: {Path.cwd()}")
    print(f"Looking for images in: {enhanced_dir.absolute()}")

    if not enhanced_dir.exists():
        print("\nDirectory not found!")
        return

    print("\nFound files:")
    for webp_file in enhanced_dir.glob("*.webp"):
        print(f"- {webp_file.name}")


# Log the image lookup once at import time, before the interface is built.
debug_image_paths()

# Radio labels and their blurbs. These are index-aligned with StyleGenerator.styles:
# the radio below uses type="index", so the selected position picks the style token.
STYLE_CHOICES = [
    "Ronaldo Style",
    "Canna Lily",
    "Three Stooges",
    "Pop Art",
    "Bird Style",
]
STYLE_DESCRIPTIONS = [
    "Perfect for capturing dynamic sports moments and celebrations",
    "Ideal for creating beautiful natural and floral compositions",
    "Great for adding humor and expressiveness to your scenes",
    "Transform your ideas into vibrant pop art masterpieces",
    "Specialized in capturing the beauty of nature and wildlife",
]
# Gallery file-name keys (the part before "_example"), aligned with STYLE_CHOICES.
GALLERY_STYLE_KEYS = ["ronaldo", "canna_lily", "three_stooges", "pop_art", "bird_style"]


def update_style_description(style_idx):
    """Return the Markdown blurb for the style at ``style_idx`` ("" when none is selected)."""
    if style_idx is None:
        return ""
    return f"### Selected Style: {STYLE_CHOICES[style_idx]}\n{STYLE_DESCRIPTIONS[style_idx]}"


def _collect_examples(directory):
    """Map key -> path for ``.webp`` files in ``directory`` whose name contains "_example".

    The key is the part of the file stem before "_example".
    """
    return {
        f.stem.split('_example')[0]: f
        for f in directory.glob("*.webp")
        if '_example' in f.name
    }


# Create a more beautiful interface with custom styling.
# NOTE(review): the .tabs/.tab-nav rules target Svelte-generated class hashes
# (svelte-710i53) tied to one Gradio build, and this layout defines no tabs, so
# those rules look unused. Confirm before relying on them.
with gr.Blocks(css="""
    .gradio-container {
        background-color: #1f2937 !important;
    }
    .dark-theme {
        background-color: #111827;
        border-radius: 10px;
        padding: 20px;
        margin: 10px;
        border: 1px solid #374151;
        color: #f3f4f6;
    }
    /* Enhanced Tab Styling */
    .tabs.svelte-710i53 {
        margin-bottom: 0 !important;
    }
    .tab-nav.svelte-710i53 {
        background: transparent !important;
        border: none !important;
        padding: 12px 24px !important;
        margin: 0 2px !important;
        color: #9CA3AF !important;
        font-weight: 500 !important;
        transition: all 0.2s ease !important;
        border-bottom: 2px solid transparent !important;
    }
    .tab-nav.svelte-710i53.selected {
        background: transparent !important;
        color: #F3F4F6 !important;
        border-bottom: 2px solid #6366F1 !important;
    }
    .tab-nav.svelte-710i53:hover {
        color: #F3F4F6 !important;
        border-bottom: 2px solid #4F46E5 !important;
    }
""") as iface:
    # Header section. In CommonMark an HTML block runs until the next blank line,
    # so the Markdown content is separated from the <div> tags by empty lines;
    # otherwise the headings would render as literal "#" text.
    gr.Markdown(
        """
        <div class="dark-theme" style="text-align: center;">

        # 🎨 AI Style Transfer Studio
        ### Transform your ideas into artistic masterpieces

        </div>
        """
    )

    # Controls section
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 🎯 Controls")

            prompt = gr.Textbox(
                label="What would you like to create?",
                placeholder="e.g., a soccer player celebrating a goal",
                lines=3,
            )

            style_radio = gr.Radio(
                choices=STYLE_CHOICES,
                label="Choose Your Style",
                value=STYLE_CHOICES[0],
                type="index",
            )

            generate_btn = gr.Button(
                "🚀 Generate Artwork",
                variant="primary",
                size="lg",
            )

            error_message = gr.Markdown(visible=False)
            # Seed with the default selection; .change() only fires on later changes.
            style_description = gr.Markdown(update_style_description(0))

    # Generated Images
    with gr.Row():
        with gr.Column():
            original_image = gr.Image(
                label="Original Style",
                show_label=True,
                height=300,
            )
        with gr.Column():
            loss_image = gr.Image(
                label="Color Enhanced",
                show_label=True,
                height=300,
            )

    # Example Gallery
    gr.Markdown(
        """
        <div class="dark-theme">

        ## 🎆 Example Gallery
        Compare original and enhanced versions for each style:

        </div>
        """
    )

    # Example Images: a title row plus an image row per style, stacked vertically
    # (a Column, so the per-style rows do not get laid out side by side).
    with gr.Column():
        try:
            output_dir = Path("Outputs")
            enhanced_dir = output_dir / "Color_Enhanced"

            if enhanced_dir.exists():
                # Originals sit directly in Outputs/, enhanced copies in Outputs/Color_Enhanced/.
                original_images = _collect_examples(output_dir)
                enhanced_images = _collect_examples(enhanced_dir)

                for style_key, style_name in zip(GALLERY_STYLE_KEYS, STYLE_CHOICES):
                    # Only show styles that have both versions available.
                    if style_key not in original_images or style_key not in enhanced_images:
                        continue
                    with gr.Row():
                        gr.Markdown(f"### {style_name}")
                    with gr.Row():
                        with gr.Column(scale=1):
                            gr.Image(
                                value=str(original_images[style_key]),
                                label="Original",
                                show_label=True,
                                height=180,
                            )
                        with gr.Column(scale=1):
                            gr.Image(
                                value=str(enhanced_images[style_key]),
                                label="Color Enhanced",
                                show_label=True,
                                height=180,
                            )
                    # Add a small spacing between styles
                    gr.Markdown("<div style='margin: 10px 0;'></div>")

        except Exception as e:
            print(f"Error in example gallery: {e}")
            gr.Markdown(f"Error loading example gallery: {str(e)}")

    # Info section
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                <div class="dark-theme">

                ## 🎨 Style Guide

                | Style | Best For |
                |-------|----------|
                | **Ronaldo Style** | Dynamic sports scenes, action shots, celebrations |
                | **Canna Lily** | Natural scenes, floral compositions, garden imagery |
                | **Three Stooges** | Comedy, humor, expressive character portraits |
                | **Pop Art** | Vibrant artwork, bold colors, stylized designs |
                | **Bird Style** | Wildlife, nature scenes, peaceful landscapes |

                *Choose the style that best matches your creative vision*

                </div>
                """
            )
        with gr.Column():
            # NOTE(review): StyleGenerator.callback_fn applies Distance_loss to the
            # diffusion latents, not RGB pixels; this copy may overstate the effect.
            gr.Markdown(
                """
                <div class="dark-theme">

                ## 🔍 Color Enhancement Technology

                Our advanced color processing uses distance loss to enhance your images:

                ### 🌈 Color Dynamics
                - **Vibrancy**: Intensifies colors naturally
                - **Contrast**: Improves depth and definition
                - **Balance**: Optimizes color relationships

                ### 🎨 Technical Features
                - **Channel Separation**: RGB optimization
                - **Loss Function**: Mathematical color enhancement
                - **Real-time Processing**: Dynamic adjustments

                ### ✨ Benefits
                - Richer, more vivid colors
                - Clearer color boundaries
                - Reduced color muddiness
                - Enhanced artistic impact

                <small>*Our color distance loss technology mathematically optimizes RGB channel relationships*</small>

                </div>
                """
            )

    # Update style description on change
    style_radio.change(
        fn=update_style_description,
        inputs=style_radio,
        outputs=style_description,
    )

    generate_btn.click(
        fn=generate_single_style,
        inputs=[prompt, style_radio],
        outputs=[error_message, original_image, loss_image],
    )

# Launch the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    # share=True asks Gradio for a public share link; show_error=True surfaces
    # handler exceptions in the browser UI.
    launch_options = {"share": True, "show_error": True}
    iface.launch(**launch_options)