import torch from diffusers import StableDiffusionPipeline from torch import autocast import gradio as gr from huggingface_hub import hf_hub_download import os from pathlib import Path import traceback import glob from PIL import Image # Reuse the same load_learned_embed_in_clip and Distance_loss functions def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None): loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") trained_token = list(loaded_learned_embeds.keys())[0] embeds = loaded_learned_embeds[trained_token] # Get the expected dimension from the text encoder expected_dim = text_encoder.get_input_embeddings().weight.shape[1] current_dim = embeds.shape[0] # Resize embeddings if dimensions don't match if current_dim != expected_dim: print(f"Resizing embedding from {current_dim} to {expected_dim}") # Option 1: Truncate or pad with zeros if current_dim > expected_dim: embeds = embeds[:expected_dim] else: embeds = torch.cat([embeds, torch.zeros(expected_dim - current_dim)], dim=0) # Reshape to match expected dimensions embeds = embeds.unsqueeze(0) # Add batch dimension # Cast to dtype of text_encoder dtype = text_encoder.get_input_embeddings().weight.dtype embeds = embeds.to(dtype) # Add the token in tokenizer token = token if token is not None else trained_token num_added_tokens = tokenizer.add_tokens(token) # Resize the token embeddings text_encoder.resize_token_embeddings(len(tokenizer)) # Get the id for the token and assign the embeds token_id = tokenizer.convert_tokens_to_ids(token) text_encoder.get_input_embeddings().weight.data[token_id] = embeds[0] return token def Distance_loss(images): # Ensure we're working with gradients if not images.requires_grad: images = images.detach().requires_grad_(True) # Convert to float32 and normalize images = images.float() / 2 + 0.5 # Get RGB channels red = images[:,0:1] green = images[:,1:2] blue = images[:,2:3] # Calculate color distances using L2 norm rg_distance = ((red - green) ** 2).mean() rb_distance = ((red - blue) ** 2).mean() gb_distance = ((green - blue) ** 2).mean() return (rg_distance + rb_distance + gb_distance) * 100 # Scale up the loss class StyleGenerator: _instance = None @classmethod def get_instance(cls): if cls._instance is None: cls._instance = cls() return cls._instance def __init__(self): self.pipe = None self.style_tokens = [] self.styles = [ "ronaldo", "canna-lily-flowers102", "threestooges", "pop_art", "bird_style" ] self.style_names = [ "Ronaldo", "Canna Lily", "Three Stooges", "Pop Art", "Bird Style" ] self.is_initialized = False self.device = "cuda" if torch.cuda.is_available() else "cpu" if self.device == "cpu": print("NVIDIA GPU not found. Running on CPU (this will be slower)") def initialize_model(self): if self.is_initialized: return try: print("Initializing Stable Diffusion model...") model_id = "runwayml/stable-diffusion-v1-5" self.pipe = StableDiffusionPipeline.from_pretrained( model_id, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, safety_checker=None ) self.pipe = self.pipe.to(self.device) # Load style embeddings from current directory current_dir = Path(__file__).parent for style, style_name in zip(self.styles, self.style_names): style_path = current_dir / f"{style}.bin" if not style_path.exists(): raise FileNotFoundError(f"Style embedding not found: {style_path}") print(f"Loading style: {style_name}") token = load_learned_embed_in_clip(str(style_path), self.pipe.text_encoder, self.pipe.tokenizer) self.style_tokens.append(token) print(f"✓ Loaded style: {style_name}") self.is_initialized = True print(f"Model initialization complete! Using device: {self.device}") except Exception as e: print(f"Error during initialization: {str(e)}") print(traceback.format_exc()) raise def generate_single_style(self, prompt, selected_style): try: # Find the index of the selected style style_idx = self.style_names.index(self.style_names[selected_style]) # Generate single image with selected style styled_prompt = f"{prompt}, {self.style_tokens[style_idx]}" # Set seed for reproducibility generator_seed = 42 torch.manual_seed(generator_seed) if self.device == "cuda": torch.cuda.manual_seed(generator_seed) # Generate base image with autocast(self.device): base_image = self.pipe( styled_prompt, num_inference_steps=50, guidance_scale=7.5, generator=torch.Generator(self.device).manual_seed(generator_seed) ).images[0] # Generate same image with loss with autocast(self.device): loss_image = self.pipe( styled_prompt, num_inference_steps=50, guidance_scale=7.5, callback=self.callback_fn, callback_steps=5, generator=torch.Generator(self.device).manual_seed(generator_seed) ).images[0] return base_image, loss_image except Exception as e: print(f"Error in generate_single_style: {e}") raise def callback_fn(self, i, t, latents): if i % 5 == 0: # Apply loss every 5 steps try: # Create a copy that requires gradients latents_copy = latents.detach().clone() latents_copy.requires_grad_(True) # Compute loss loss = Distance_loss(latents_copy) # Compute gradients if loss.requires_grad: grads = torch.autograd.grad( outputs=loss, inputs=latents_copy, allow_unused=True, retain_graph=False )[0] if grads is not None: # Apply gradients to original latents return latents - 0.1 * grads.detach() except Exception as e: print(f"Error in callback: {e}") return latents def generate_single_style(prompt, selected_style): try: generator = StyleGenerator.get_instance() if not generator.is_initialized: generator.initialize_model() base_image, loss_image = generator.generate_single_style(prompt, selected_style) return [ gr.update(visible=False), # error_message base_image, # original_image loss_image # loss_image ] except Exception as e: print(f"Error in generate_single_style: {e}") return [ gr.update(value=f"Error: {str(e)}", visible=True), # error_message None, # original_image None # loss_image ] # Add at the start of your script def debug_image_paths(): output_dir = Path("Outputs") enhanced_dir = output_dir / "Color_Enhanced" print(f"\nChecking image paths:") print(f"Current working directory: {Path.cwd()}") print(f"Looking for images in: {enhanced_dir.absolute()}") if enhanced_dir.exists(): print("\nFound files:") for file in enhanced_dir.glob("*.webp"): print(f"- {file.name}") else: print("\nDirectory not found!") # Call this function before creating the interface debug_image_paths() # Create a more beautiful interface with custom styling with gr.Blocks(css=""" .gradio-container { background-color: #1f2937 !important; } .dark-theme { background-color: #111827; border-radius: 10px; padding: 20px; margin: 10px; border: 1px solid #374151; color: #f3f4f6; } /* Enhanced Tab Styling */ .tabs.svelte-710i53 { margin-bottom: 0 !important; } .tab-nav.svelte-710i53 { background: transparent !important; border: none !important; padding: 12px 24px !important; margin: 0 2px !important; color: #9CA3AF !important; font-weight: 500 !important; transition: all 0.2s ease !important; border-bottom: 2px solid transparent !important; } .tab-nav.svelte-710i53.selected { background: transparent !important; color: #F3F4F6 !important; border-bottom: 2px solid #6366F1 !important; } .tab-nav.svelte-710i53:hover { color: #F3F4F6 !important; border-bottom: 2px solid #4F46E5 !important; } """) as iface: # Header section gr.Markdown( """
# 🎨 AI Style Transfer Studio ### Transform your ideas into artistic masterpieces
""" ) # Controls section with gr.Row(): with gr.Column(scale=1): gr.Markdown("## 🎯 Controls") prompt = gr.Textbox( label="What would you like to create?", placeholder="e.g., a soccer player celebrating a goal", lines=3 ) style_radio = gr.Radio( choices=[ "Ronaldo Style", "Canna Lily", "Three Stooges", "Pop Art", "Bird Style" ], label="Choose Your Style", value="Ronaldo Style", type="index" ) generate_btn = gr.Button( "🚀 Generate Artwork", variant="primary", size="lg" ) error_message = gr.Markdown(visible=False) style_description = gr.Markdown() # Generated Images with gr.Row(): with gr.Column(): original_image = gr.Image( label="Original Style", show_label=True, height=300 ) with gr.Column(): loss_image = gr.Image( label="Color Enhanced", show_label=True, height=300 ) # Example Gallery gr.Markdown( """
## 🎆 Example Gallery Compare original and enhanced versions for each style:
""" ) # Example Images with gr.Row(): try: output_dir = Path("Outputs") original_dir = output_dir enhanced_dir = output_dir / "Color_Enhanced" if enhanced_dir.exists(): original_images = { Path(f).stem.split('_example')[0]: f for f in original_dir.glob("*.webp") if '_example' in f.name } enhanced_images = { Path(f).stem.split('_example')[0]: f for f in enhanced_dir.glob("*.webp") if '_example' in f.name } styles = [ ("ronaldo", "Ronaldo Style"), ("canna_lily", "Canna Lily"), ("three_stooges", "Three Stooges"), ("pop_art", "Pop Art"), ("bird_style", "Bird Style") ] # Create a grid of all styles for style_key, style_name in styles: if style_key in original_images and style_key in enhanced_images: with gr.Row(): gr.Markdown(f"### {style_name}") with gr.Row(): with gr.Column(scale=1): gr.Image( value=str(original_images[style_key]), label="Original", show_label=True, height=180 ) with gr.Column(scale=1): gr.Image( value=str(enhanced_images[style_key]), label="Color Enhanced", show_label=True, height=180 ) # Add a small spacing between styles gr.Markdown("
") except Exception as e: print(f"Error in example gallery: {e}") gr.Markdown(f"Error loading example gallery: {str(e)}") # Info section with gr.Row(): with gr.Column(): gr.Markdown( """
## 🎨 Style Guide | Style | Best For | |-------|----------| | **Ronaldo Style** | Dynamic sports scenes, action shots, celebrations | | **Canna Lily** | Natural scenes, floral compositions, garden imagery | | **Three Stooges** | Comedy, humor, expressive character portraits | | **Pop Art** | Vibrant artwork, bold colors, stylized designs | | **Bird Style** | Wildlife, nature scenes, peaceful landscapes | *Choose the style that best matches your creative vision*
""" ) with gr.Column(): gr.Markdown( """
## 🔍 Color Enhancement Technology Our advanced color processing uses distance loss to enhance your images: ### 🌈 Color Dynamics - **Vibrancy**: Intensifies colors naturally - **Contrast**: Improves depth and definition - **Balance**: Optimizes color relationships ### 🎨 Technical Features - **Channel Separation**: RGB optimization - **Loss Function**: Mathematical color enhancement - **Real-time Processing**: Dynamic adjustments ### ✨ Benefits - Richer, more vivid colors - Clearer color boundaries - Reduced color muddiness - Enhanced artistic impact *Our color distance loss technology mathematically optimizes RGB channel relationships*
""" ) # Update style description on change def update_style_description(style_idx): descriptions = [ "Perfect for capturing dynamic sports moments and celebrations", "Ideal for creating beautiful natural and floral compositions", "Great for adding humor and expressiveness to your scenes", "Transform your ideas into vibrant pop art masterpieces", "Specialized in capturing the beauty of nature and wildlife" ] styles = ["Ronaldo Style", "Canna Lily", "Three Stooges", "Pop Art", "Bird Style"] return f"### Selected Style: {styles[style_idx]}\n{descriptions[style_idx]}" style_radio.change( fn=update_style_description, inputs=style_radio, outputs=style_description ) generate_btn.click( fn=generate_single_style, inputs=[prompt, style_radio], outputs=[error_message, original_image, loss_image] ) # Launch the app if __name__ == "__main__": iface.launch( share=True, show_error=True )