import torch
from diffusers import StableDiffusionPipeline
from torch import autocast
import gradio as gr
from huggingface_hub import hf_hub_download
import os
from pathlib import Path
import traceback
import glob
from PIL import Image


# Reuse the same load_learned_embed_in_clip and Distance_loss functions
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    trained_token = list(loaded_learned_embeds.keys())[0]
    embeds = loaded_learned_embeds[trained_token]

    # Get the expected dimension from the text encoder
    expected_dim = text_encoder.get_input_embeddings().weight.shape[1]
    current_dim = embeds.shape[0]

    # Resize the embedding if its dimension doesn't match the text encoder's
    if current_dim != expected_dim:
        print(f"Resizing embedding from {current_dim} to {expected_dim}")
        # Truncate, or pad with zeros, to the expected dimension
        if current_dim > expected_dim:
            embeds = embeds[:expected_dim]
        else:
            embeds = torch.cat([embeds, torch.zeros(expected_dim - current_dim)], dim=0)

    # Add a batch dimension and cast to the text encoder's dtype
    embeds = embeds.unsqueeze(0)
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)

    # Add the token to the tokenizer
    token = token if token is not None else trained_token
    num_added_tokens = tokenizer.add_tokens(token)

    # Resize the token embeddings, then assign the learned embedding to the new token
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds[0]

    return token


def Distance_loss(images):
    # Ensure we're working with gradients
    if not images.requires_grad:
        images = images.detach().requires_grad_(True)

    # Convert to float32 and map from [-1, 1] to [0, 1]
    images = images.float() / 2 + 0.5

    # Split the first three channels (treated as R, G, B)
    red = images[:, 0:1]
    green = images[:, 1:2]
    blue = images[:, 2:3]

    # Mean squared distance between each pair of channels
    rg_distance = ((red - green) ** 2).mean()
    rb_distance = ((red - blue) ** 2).mean()
    gb_distance = ((green - blue) ** 2).mean()

    return (rg_distance + rb_distance + gb_distance) * 100  # Scale up the loss
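
# A quick sanity check for Distance_loss (an illustrative sketch, not part of
# the original app; the helper name is made up). A uniform gray batch has
# identical channels and should score ~0, while a saturated red batch
# maximizes the pairwise channel differences.
def _demo_distance_loss():
    gray = torch.zeros(1, 3, 8, 8)  # identical channels -> loss ~ 0
    red = torch.zeros(1, 3, 8, 8)
    red[:, 0] = 1.0  # only the red channel is lit -> large loss
    print(f"gray: {Distance_loss(gray).item():.2f}")  # 0.00
    print(f"red:  {Distance_loss(red).item():.2f}")   # 50.00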

class StyleGenerator:
    _instance = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        self.pipe = None
        self.style_tokens = []
        self.styles = [
            "ronaldo",
            "canna-lily-flowers102",
            "threestooges",
            "pop_art",
            "bird_style",
        ]
        self.style_names = [
            "Ronaldo",
            "Canna Lily",
            "Three Stooges",
            "Pop Art",
            "Bird Style",
        ]
        self.is_initialized = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cpu":
            print("NVIDIA GPU not found. Running on CPU (this will be slower)")

    def initialize_model(self):
        if self.is_initialized:
            return
        try:
            print("Initializing Stable Diffusion model...")
            model_id = "runwayml/stable-diffusion-v1-5"
            self.pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                safety_checker=None,
            )
            self.pipe = self.pipe.to(self.device)

            # Load style embeddings from the current directory
            current_dir = Path(__file__).parent
            for style, style_name in zip(self.styles, self.style_names):
                style_path = current_dir / f"{style}.bin"
                if not style_path.exists():
                    raise FileNotFoundError(f"Style embedding not found: {style_path}")
                print(f"Loading style: {style_name}")
                token = load_learned_embed_in_clip(
                    str(style_path), self.pipe.text_encoder, self.pipe.tokenizer
                )
                self.style_tokens.append(token)
                print(f"✓ Loaded style: {style_name}")

            self.is_initialized = True
            print(f"Model initialization complete! Using device: {self.device}")
        except Exception as e:
            print(f"Error during initialization: {str(e)}")
            print(traceback.format_exc())
            raise

    def generate_single_style(self, prompt, selected_style):
        try:
            # selected_style is already the index of the chosen style
            style_idx = selected_style

            # Build the prompt with the selected style token
            styled_prompt = f"{prompt}, {self.style_tokens[style_idx]}"

            # Set seed for reproducibility
            generator_seed = 42
            torch.manual_seed(generator_seed)
            if self.device == "cuda":
                torch.cuda.manual_seed(generator_seed)

            # Generate the base image
            with autocast(self.device):
                base_image = self.pipe(
                    styled_prompt,
                    num_inference_steps=50,
                    guidance_scale=7.5,
                    generator=torch.Generator(self.device).manual_seed(generator_seed),
                ).images[0]

            # Generate the same image with the color-distance loss applied
            with autocast(self.device):
                loss_image = self.pipe(
                    styled_prompt,
                    num_inference_steps=50,
                    guidance_scale=7.5,
                    callback=self.callback_fn,
                    callback_steps=5,
                    generator=torch.Generator(self.device).manual_seed(generator_seed),
                ).images[0]

            return base_image, loss_image
        except Exception as e:
            print(f"Error in generate_single_style: {e}")
            raise

    def callback_fn(self, i, t, latents):
        if i % 5 == 0:  # Apply the loss every 5 steps
            try:
                # The pipeline runs under torch.no_grad(), so gradient
                # tracking has to be re-enabled for the loss computation
                with torch.enable_grad():
                    # Create a copy that requires gradients
                    latents_copy = latents.detach().clone()
                    latents_copy.requires_grad_(True)

                    # Compute the loss (the first three latent channels are
                    # treated as if they were RGB)
                    loss = Distance_loss(latents_copy)

                    # Compute gradients
                    grads = torch.autograd.grad(
                        outputs=loss,
                        inputs=latents_copy,
                        allow_unused=True,
                        retain_graph=False,
                    )[0]

                if grads is not None:
                    # Nudge the latents against the gradient. The pipeline
                    # ignores the callback's return value, so the update is
                    # written into the latents tensor in place.
                    latents.copy_(latents - 0.1 * grads.detach())
            except Exception as e:
                print(f"Error in callback: {e}")
        return latents


# Module-level wrapper used by the Gradio interface
def generate_single_style(prompt, selected_style):
    try:
        generator = StyleGenerator.get_instance()
        if not generator.is_initialized:
            generator.initialize_model()

        base_image, loss_image = generator.generate_single_style(prompt, selected_style)
        return [
            gr.update(visible=False),  # error_message
            base_image,  # original_image
            loss_image,  # loss_image
        ]
    except Exception as e:
        print(f"Error in generate_single_style: {e}")
        return [
            gr.update(value=f"Error: {str(e)}", visible=True),  # error_message
            None,  # original_image
            None,  # loss_image
        ]


# Add at the start of your script
def debug_image_paths():
    output_dir = Path("Outputs")
    enhanced_dir = output_dir / "Color_Enhanced"
    print("\nChecking image paths:")
    print(f"Current working directory: {Path.cwd()}")
    print(f"Looking for images in: {enhanced_dir.absolute()}")
    if enhanced_dir.exists():
        print("\nFound files:")
        for file in enhanced_dir.glob("*.webp"):
            print(f"- {file.name}")
    else:
        print("\nDirectory not found!")


# Call this function before creating the interface
debug_image_paths()
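
# Minimal smoke test (a sketch, not part of the original app; the function
# name and output filenames are assumptions). Call it from a REPL to verify
# that the pipeline and the .bin style embeddings load and that a base/loss
# image pair comes back, before wiring everything into the Gradio UI.
def smoke_test(prompt="a portrait of a dog", style_index=0):
    generator = StyleGenerator.get_instance()
    generator.initialize_model()
    base_image, loss_image = generator.generate_single_style(prompt, style_index)
    base_image.save("smoke_base.png")  # plain styled generation
    loss_image.save("smoke_loss.png")  # generation guided by Distance_loss
    return base_image, loss_image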

# Create a more beautiful interface with custom styling
with gr.Blocks(css="""
    .gradio-container {
        background-color: #1f2937 !important;
    }
    .dark-theme {
        background-color: #111827;
        border-radius: 10px;
        padding: 20px;
        margin: 10px;
        border: 1px solid #374151;
        color: #f3f4f6;
    }
    /* Enhanced Tab Styling */
    .tabs.svelte-710i53 {
        margin-bottom: 0 !important;
    }
    .tab-nav.svelte-710i53 {
        background: transparent !important;
        border: none !important;
        padding: 12px 24px !important;
        margin: 0 2px !important;
        color: #9CA3AF !important;
        font-weight: 500 !important;
        transition: all 0.2s ease !important;
        border-bottom: 2px solid transparent !important;
    }
    .tab-nav.svelte-710i53.selected {
        background: transparent !important;
        color: #F3F4F6 !important;
        border-bottom: 2px solid #6366F1 !important;
    }
    .tab-nav.svelte-710i53:hover {
        color: #F3F4F6 !important;
        border-bottom: 2px solid #4F46E5 !important;
    }
""") as iface:
    # Header section
    gr.Markdown(
        """