import torch from diffusers import StableDiffusionPipeline from torch import autocast import gradio as gr from huggingface_hub import hf_hub_download import os from pathlib import Path import traceback # Reuse the same load_learned_embed_in_clip and Distance_loss functions def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None): loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") trained_token = list(loaded_learned_embeds.keys())[0] embeds = loaded_learned_embeds[trained_token] # Get the expected dimension from the text encoder expected_dim = text_encoder.get_input_embeddings().weight.shape[1] current_dim = embeds.shape[0] # Resize embeddings if dimensions don't match if current_dim != expected_dim: print(f"Resizing embedding from {current_dim} to {expected_dim}") # Option 1: Truncate or pad with zeros if current_dim > expected_dim: embeds = embeds[:expected_dim] else: embeds = torch.cat([embeds, torch.zeros(expected_dim - current_dim)], dim=0) # Reshape to match expected dimensions embeds = embeds.unsqueeze(0) # Add batch dimension # Cast to dtype of text_encoder dtype = text_encoder.get_input_embeddings().weight.dtype embeds = embeds.to(dtype) # Add the token in tokenizer token = token if token is not None else trained_token num_added_tokens = tokenizer.add_tokens(token) # Resize the token embeddings text_encoder.resize_token_embeddings(len(tokenizer)) # Get the id for the token and assign the embeds token_id = tokenizer.convert_tokens_to_ids(token) text_encoder.get_input_embeddings().weight.data[token_id] = embeds[0] return token def Distance_loss(images): # Ensure we're working with gradients if not images.requires_grad: images = images.detach().requires_grad_(True) # Convert to float32 and normalize images = images.float() / 2 + 0.5 # Get RGB channels red = images[:,0:1] green = images[:,1:2] blue = images[:,2:3] # Calculate color distances using L2 norm rg_distance = ((red - green) ** 2).mean() rb_distance = ((red - blue) ** 2).mean() gb_distance = ((green - blue) ** 2).mean() return (rg_distance + rb_distance + gb_distance) * 100 # Scale up the loss class StyleGenerator: _instance = None @classmethod def get_instance(cls): if cls._instance is None: cls._instance = cls() return cls._instance def __init__(self): self.pipe = None self.style_tokens = [] self.styles = [ "ronaldo", "canna-lily-flowers102", "threestooges", "pop_art", "bird_style" ] self.style_names = [ "Ronaldo", "Canna Lily", "Three Stooges", "Pop Art", "Bird Style" ] self.is_initialized = False def initialize_model(self): if self.is_initialized: return try: print("Initializing Stable Diffusion model...") model_id = "runwayml/stable-diffusion-v1-5" self.pipe = StableDiffusionPipeline.from_pretrained( model_id, torch_dtype=torch.float16, safety_checker=None ) self.pipe = self.pipe.to("cuda") # Load style embeddings from current directory current_dir = Path(__file__).parent for style, style_name in zip(self.styles, self.style_names): style_path = current_dir / f"{style}.bin" if not style_path.exists(): raise FileNotFoundError(f"Style embedding not found: {style_path}") print(f"Loading style: {style_name}") token = load_learned_embed_in_clip(str(style_path), self.pipe.text_encoder, self.pipe.tokenizer) self.style_tokens.append(token) print(f"✓ Loaded style: {style_name}") self.is_initialized = True print("Model initialization complete!") except Exception as e: print(f"Error during initialization: {str(e)}") print(traceback.format_exc()) raise def generate_images(self, prompt, apply_loss=False, num_inference_steps=50, guidance_scale=7.5): if not self.is_initialized: self.initialize_model() images = [] style_names = [] try: def callback_fn(i, t, latents): if i % 5 == 0 and apply_loss: try: # Ensure latents are in the correct format and require gradients latents = latents.float() latents.requires_grad_(True) # Compute loss loss = Distance_loss(latents) # Compute gradients manually grads = torch.autograd.grad( outputs=loss, inputs=latents, create_graph=False, retain_graph=False, only_inputs=True )[0] # Update latents with torch.no_grad(): latents = latents - 0.1 * grads except Exception as e: print(f"Error in callback: {e}") return latents return latents for style_token, style_name in zip(self.style_tokens, self.style_names): styled_prompt = f"{prompt}, {style_token}" style_names.append(style_name) # Disable autocast for better gradient computation image = self.pipe( styled_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, callback=callback_fn if apply_loss else None, callback_steps=5 ).images[0] images.append(image) return images, style_names except Exception as e: print(f"Error during image generation: {str(e)}") print(traceback.format_exc()) raise def callback_fn(self, i, t, latents): if i % 5 == 0: # Apply loss every 5 steps try: # Create a copy that requires gradients latents_copy = latents.detach().clone() latents_copy.requires_grad_(True) # Compute loss loss = Distance_loss(latents_copy) # Compute gradients if loss.requires_grad: grads = torch.autograd.grad( outputs=loss, inputs=latents_copy, allow_unused=True, retain_graph=False )[0] if grads is not None: # Apply gradients to original latents return latents - 0.1 * grads.detach() except Exception as e: print(f"Error in callback: {e}") return latents def generate_all_variations(prompt): try: generator = StyleGenerator.get_instance() if not generator.is_initialized: generator.initialize_model() # Generate images without loss regular_images, style_names = generator.generate_images(prompt, apply_loss=False) # Generate images with loss loss_images, _ = generator.generate_images(prompt, apply_loss=True) return regular_images, loss_images, style_names except Exception as e: print(f"Error in generate_all_variations: {str(e)}") print(traceback.format_exc()) raise def gradio_interface(prompt): try: regular_images, loss_images, style_names = generate_all_variations(prompt) return ( regular_images, # Just return the images directly loss_images # Just return the images directly ) except Exception as e: print(f"Error in interface: {str(e)}") print(traceback.format_exc()) # Return empty lists in case of error return [], [] # Create a more beautiful interface with custom styling with gr.Blocks(css=""" .gradio-container { background-color: #1f2937 !important; } .dark-theme { background-color: #111827; border-radius: 10px; padding: 20px; margin: 10px; border: 1px solid #374151; color: #f3f4f6; } """) as iface: # Header section with dark theme gr.Markdown( """