Spaces:
Running
Running
import torch | |
from diffusers import StableDiffusionPipeline | |
from torch import autocast | |
import gradio as gr | |
from huggingface_hub import hf_hub_download | |
import os | |
from pathlib import Path | |
import traceback | |
import glob | |
from PIL import Image | |
# Reuse the same load_learned_embed_in_clip and Distance_loss functions | |
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
    """Register a learned textual-inversion embedding with a tokenizer/text encoder.

    The file at ``learned_embeds_path`` is expected to hold a mapping whose
    first key is the trained token and whose value is its embedding tensor.
    The embedding is truncated or zero-padded to the width of the text
    encoder's input embedding matrix, cast to that matrix's dtype, and written
    into the row belonging to the (newly added) token.

    Args:
        learned_embeds_path: Path of the saved embedding file.
        text_encoder: Object exposing ``get_input_embeddings()`` and
            ``resize_token_embeddings(n)``.
        tokenizer: Object exposing ``add_tokens``, ``convert_tokens_to_ids``
            and ``__len__``.
        token: Optional token string to register instead of the one stored in
            the file.

    Returns:
        The token string that was registered.
    """
    # NOTE(security): torch.load unpickles the file, so only load embeddings
    # that come from a trusted source.
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    trained_token = list(loaded_learned_embeds.keys())[0]
    embeds = loaded_learned_embeds[trained_token]

    # Accept a single vector stored as a one-row matrix, shape (1, dim);
    # without this the padding branch below fails on a 2-D tensor.
    if embeds.dim() == 2 and embeds.shape[0] == 1:
        embeds = embeds[0]

    embedding_weight = text_encoder.get_input_embeddings().weight
    expected_dim = embedding_weight.shape[1]
    current_dim = embeds.shape[0]

    # Truncate or zero-pad when the stored width differs from the encoder's.
    if current_dim != expected_dim:
        print(f"Resizing embedding from {current_dim} to {expected_dim}")
        if current_dim > expected_dim:
            embeds = embeds[:expected_dim]
        else:
            # new_zeros keeps the padding on the same dtype/device as embeds.
            padding = embeds.new_zeros(expected_dim - current_dim)
            embeds = torch.cat([embeds, padding], dim=0)

    # Cast to the dtype of the text encoder's embedding matrix.
    embeds = embeds.to(embedding_weight.dtype)

    # Add the token to the tokenizer and grow the embedding matrix to match.
    token = token if token is not None else trained_token
    tokenizer.add_tokens(token)
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Look the matrix up again after resizing before writing the new row.
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
    return token
def Distance_loss(images):
    """Return a scalar measuring how far apart the first three channels are.

    The input is indexed along dimension 1, and channels 0, 1 and 2 are
    treated as red, green and blue. Values are mapped with ``x / 2 + 0.5``
    before the mean squared difference of every channel pair is summed and
    multiplied by 100.

    NOTE(review): the callback in this file passes diffusion latents here;
    whether their first three channels behave like RGB is not established by
    this function - confirm.
    """
    # A tensor that does not track gradients is swapped for a detached copy
    # that does, so the returned loss is differentiable.
    if not images.requires_grad:
        images = images.detach().requires_grad_(True)

    # Work in float32 and shift the values by the affine map x / 2 + 0.5.
    scaled = images.float() / 2 + 0.5

    # Keep the channel axis (slices, not integer indices).
    r, g, b = scaled[:, 0:1], scaled[:, 1:2], scaled[:, 2:3]

    # Mean squared difference for each of the three channel pairs.
    pair_mse = [((first - second) ** 2).mean() for first, second in ((r, g), (r, b), (g, b))]

    # Scale up the loss.
    return (pair_mse[0] + pair_mse[1] + pair_mse[2]) * 100
class StyleGenerator:
    """Holds the Stable Diffusion pipeline and the learned style tokens.

    A single shared instance is handed out by ``get_instance``. The pipeline
    itself is only loaded when ``initialize_model`` is called.
    """

    _instance = None

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        self.pipe = None
        # Tokens registered for each style, in the same order as self.styles.
        self.style_tokens = []
        # Base names of the "<style>.bin" embedding files.
        self.styles = [
            "ronaldo",
            "canna-lily-flowers102",
            "threestooges",
            "pop_art",
            "bird_style",
        ]
        # Display names, parallel to self.styles.
        self.style_names = [
            "Ronaldo",
            "Canna Lily",
            "Three Stooges",
            "Pop Art",
            "Bird Style",
        ]
        self.is_initialized = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cpu":
            print("NVIDIA GPU not found. Running on CPU (this will be slower)")

    def initialize_model(self):
        """Load the pipeline and every style embedding; no-op once initialized.

        Raises:
            FileNotFoundError: If a "<style>.bin" file is missing next to this
                script.
            Exception: Any other loading error is logged and re-raised.
        """
        if self.is_initialized:
            return
        try:
            print("Initializing Stable Diffusion model...")
            model_id = "runwayml/stable-diffusion-v1-5"
            pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                safety_checker=None,
            )
            pipe = pipe.to(self.device)

            # Load style embeddings from the directory containing this file.
            # Tokens are collected locally and published only on success, so a
            # retry after a partial failure cannot leave stale or duplicated
            # entries in self.style_tokens.
            current_dir = Path(__file__).parent
            tokens = []
            for style, style_name in zip(self.styles, self.style_names):
                style_path = current_dir / f"{style}.bin"
                if not style_path.exists():
                    raise FileNotFoundError(f"Style embedding not found: {style_path}")
                print(f"Loading style: {style_name}")
                token = load_learned_embed_in_clip(str(style_path), pipe.text_encoder, pipe.tokenizer)
                tokens.append(token)
                print(f"β Loaded style: {style_name}")

            self.pipe = pipe
            self.style_tokens = tokens
            self.is_initialized = True
            print(f"Model initialization complete! Using device: {self.device}")
        except Exception as e:
            print(f"Error during initialization: {str(e)}")
            print(traceback.format_exc())
            raise

    def _run_pipeline(self, styled_prompt, seed, **extra):
        """Run one 50-step generation with a freshly seeded generator."""
        with autocast(self.device):
            return self.pipe(
                styled_prompt,
                num_inference_steps=50,
                guidance_scale=7.5,
                generator=torch.Generator(self.device).manual_seed(seed),
                **extra,
            ).images[0]

    def generate_single_style(self, prompt, selected_style):
        """Generate the same prompt twice: plainly and with the loss callback.

        Args:
            prompt: Free-text prompt.
            selected_style: Integer position of the style in
                ``self.style_tokens``.

        Returns:
            A ``(base_image, loss_image)`` tuple.

        Raises:
            ValueError: If ``selected_style`` is not a valid position.
        """
        try:
            # Look the token up directly by position.
            try:
                style_token = self.style_tokens[selected_style]
            except (TypeError, IndexError) as err:
                raise ValueError(f"Invalid style selection: {selected_style!r}") from err
            styled_prompt = f"{prompt}, {style_token}"

            # Set seed for reproducibility.
            generator_seed = 42
            torch.manual_seed(generator_seed)
            if self.device == "cuda":
                torch.cuda.manual_seed(generator_seed)

            # Same prompt and seed for both runs; only the callback differs.
            base_image = self._run_pipeline(styled_prompt, generator_seed)
            loss_image = self._run_pipeline(
                styled_prompt,
                generator_seed,
                callback=self.callback_fn,
                callback_steps=5,
            )
            return base_image, loss_image
        except Exception as e:
            print(f"Error in generate_single_style: {e}")
            raise

    def callback_fn(self, i, t, latents):
        """Return latents nudged against the Distance_loss gradient every 5th step.

        NOTE(review): the legacy diffusers ``callback`` hook appears to ignore
        the callback's return value and to run with gradients disabled, in
        which case this adjustment never reaches the sampler - confirm against
        the installed diffusers version.
        """
        if i % 5 == 0:  # Apply loss every 5 steps
            try:
                # Work on a copy that requires gradients.
                latents_copy = latents.detach().clone()
                latents_copy.requires_grad_(True)
                loss = Distance_loss(latents_copy)
                # Differentiate only if a graph was actually recorded.
                if loss.requires_grad:
                    grads = torch.autograd.grad(
                        outputs=loss,
                        inputs=latents_copy,
                        allow_unused=True,
                        retain_graph=False,
                    )[0]
                    if grads is not None:
                        # Step the original latents against the gradient.
                        return latents - 0.1 * grads.detach()
            except Exception as e:
                # Best effort: a failing loss step must not abort generation.
                print(f"Error in callback: {e}")
        return latents
def generate_single_style(prompt, selected_style):
    """Gradio handler: produce the plain and loss-guided images for one style.

    Returns a three-item list matching the outputs
    ``[error_message, original_image, loss_image]``. On failure the error
    component is made visible with the message and both images are ``None``.
    """
    try:
        studio = StyleGenerator.get_instance()
        # Load the model on the first request only.
        if not studio.is_initialized:
            studio.initialize_model()
        plain, enhanced = studio.generate_single_style(prompt, selected_style)
        hide_error = gr.update(visible=False)
        return [hide_error, plain, enhanced]
    except Exception as exc:
        print(f"Error in generate_single_style: {exc}")
        show_error = gr.update(value=f"Error: {exc}", visible=True)
        return [show_error, None, None]
# Add at the start of your script
def debug_image_paths():
    """Print where example images are looked up and which .webp files exist."""
    enhanced_dir = Path("Outputs") / "Color_Enhanced"

    print("\nChecking image paths:")
    print(f"Current working directory: {Path.cwd()}")
    print(f"Looking for images in: {enhanced_dir.absolute()}")

    # Guard clause: nothing to list when the folder is missing.
    if not enhanced_dir.exists():
        print("\nDirectory not found!")
        return

    print("\nFound files:")
    for webp_path in enhanced_dir.glob("*.webp"):
        print(f"- {webp_path.name}")


# Call this function before creating the interface
debug_image_paths()
# Create a more beautiful interface with custom styling
# NOTE(review): the source this was recovered from had its indentation
# flattened; the nesting of rows/columns below is a reconstruction - confirm
# the layout against the running app.
with gr.Blocks(css="""
    .gradio-container {
    background-color: #1f2937 !important;
    }
    .dark-theme {
    background-color: #111827;
    border-radius: 10px;
    padding: 20px;
    margin: 10px;
    border: 1px solid #374151;
    color: #f3f4f6;
    }
    /* Enhanced Tab Styling */
    .tabs.svelte-710i53 {
    margin-bottom: 0 !important;
    }
    .tab-nav.svelte-710i53 {
    background: transparent !important;
    border: none !important;
    padding: 12px 24px !important;
    margin: 0 2px !important;
    color: #9CA3AF !important;
    font-weight: 500 !important;
    transition: all 0.2s ease !important;
    border-bottom: 2px solid transparent !important;
    }
    .tab-nav.svelte-710i53.selected {
    background: transparent !important;
    color: #F3F4F6 !important;
    border-bottom: 2px solid #6366F1 !important;
    }
    .tab-nav.svelte-710i53:hover {
    color: #F3F4F6 !important;
    border-bottom: 2px solid #4F46E5 !important;
    }
    """) as iface:
    # Header section
    gr.Markdown(
        """
        <div class="dark-theme" style="text-align: center;">
        # π¨ AI Style Transfer Studio
        ### Transform your ideas into artistic masterpieces
        </div>
        """
    )
    # Controls section
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## π― Controls")
            prompt = gr.Textbox(
                label="What would you like to create?",
                placeholder="e.g., a soccer player celebrating a goal",
                lines=3
            )
            # type="index": handlers below index their lists with this value,
            # so the order here must match those lists.
            style_radio = gr.Radio(
                choices=[
                    "Ronaldo Style",
                    "Canna Lily",
                    "Three Stooges",
                    "Pop Art",
                    "Bird Style"
                ],
                label="Choose Your Style",
                value="Ronaldo Style",
                type="index"
            )
            generate_btn = gr.Button(
                "π Generate Artwork",
                variant="primary",
                size="lg"
            )
            # Hidden until generate_single_style reports a failure.
            error_message = gr.Markdown(visible=False)
            # Filled by update_style_description when the radio changes.
            style_description = gr.Markdown()
    # Generated Images
    with gr.Row():
        with gr.Column():
            original_image = gr.Image(
                label="Original Style",
                show_label=True,
                height=300
            )
        with gr.Column():
            loss_image = gr.Image(
                label="Color Enhanced",
                show_label=True,
                height=300
            )
    # Example Gallery
    gr.Markdown(
        """
        <div class="dark-theme">
        ## π Example Gallery
        Compare original and enhanced versions for each style:
        </div>
        """
    )
    # Example Images
    with gr.Row():
        try:
            output_dir = Path("Outputs")
            original_dir = output_dir
            enhanced_dir = output_dir / "Color_Enhanced"
            # The gallery is skipped entirely when the enhanced folder is absent.
            if enhanced_dir.exists():
                # Map "<key>" -> path for files named "<key>_example*.webp".
                original_images = {
                    Path(f).stem.split('_example')[0]: f
                    for f in original_dir.glob("*.webp")
                    if '_example' in f.name
                }
                enhanced_images = {
                    Path(f).stem.split('_example')[0]: f
                    for f in enhanced_dir.glob("*.webp")
                    if '_example' in f.name
                }
                # (file-name key, display name) pairs.
                # NOTE(review): these keys differ from the embedding file names
                # used by StyleGenerator (e.g. "canna_lily" vs
                # "canna-lily-flowers102") - presumably the example images use
                # their own naming; verify against the Outputs folder.
                styles = [
                    ("ronaldo", "Ronaldo Style"),
                    ("canna_lily", "Canna Lily"),
                    ("three_stooges", "Three Stooges"),
                    ("pop_art", "Pop Art"),
                    ("bird_style", "Bird Style")
                ]
                # Create a grid of all styles
                for style_key, style_name in styles:
                    # Only show a style when both versions of its example exist.
                    if style_key in original_images and style_key in enhanced_images:
                        with gr.Row():
                            gr.Markdown(f"### {style_name}")
                        with gr.Row():
                            with gr.Column(scale=1):
                                gr.Image(
                                    value=str(original_images[style_key]),
                                    label="Original",
                                    show_label=True,
                                    height=180
                                )
                            with gr.Column(scale=1):
                                gr.Image(
                                    value=str(enhanced_images[style_key]),
                                    label="Color Enhanced",
                                    show_label=True,
                                    height=180
                                )
                        # Add a small spacing between styles
                        gr.Markdown("<div style='margin: 10px 0;'></div>")
        except Exception as e:
            # A broken gallery is reported in the page instead of aborting the UI.
            print(f"Error in example gallery: {e}")
            gr.Markdown(f"Error loading example gallery: {str(e)}")
    # Info section
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                <div class="dark-theme">
                ## π¨ Style Guide
                | Style | Best For |
                |-------|----------|
                | **Ronaldo Style** | Dynamic sports scenes, action shots, celebrations |
                | **Canna Lily** | Natural scenes, floral compositions, garden imagery |
                | **Three Stooges** | Comedy, humor, expressive character portraits |
                | **Pop Art** | Vibrant artwork, bold colors, stylized designs |
                | **Bird Style** | Wildlife, nature scenes, peaceful landscapes |
                *Choose the style that best matches your creative vision*
                </div>
                """
            )
        with gr.Column():
            gr.Markdown(
                """
                <div class="dark-theme">
                ## π Color Enhancement Technology
                Our advanced color processing uses distance loss to enhance your images:
                ### π Color Dynamics
                - **Vibrancy**: Intensifies colors naturally
                - **Contrast**: Improves depth and definition
                - **Balance**: Optimizes color relationships
                ### π¨ Technical Features
                - **Channel Separation**: RGB optimization
                - **Loss Function**: Mathematical color enhancement
                - **Real-time Processing**: Dynamic adjustments
                ### β¨ Benefits
                - Richer, more vivid colors
                - Clearer color boundaries
                - Reduced color muddiness
                - Enhanced artistic impact
                <small>*Our color distance loss technology mathematically optimizes RGB channel relationships*</small>
                </div>
                """
            )
    # Update style description on change
    def update_style_description(style_idx):
        # Returns the markdown shown under the controls for the style at
        # position style_idx; both lists follow the radio's choice order.
        descriptions = [
            "Perfect for capturing dynamic sports moments and celebrations",
            "Ideal for creating beautiful natural and floral compositions",
            "Great for adding humor and expressiveness to your scenes",
            "Transform your ideas into vibrant pop art masterpieces",
            "Specialized in capturing the beauty of nature and wildlife"
        ]
        styles = ["Ronaldo Style", "Canna Lily", "Three Stooges", "Pop Art", "Bird Style"]
        return f"### Selected Style: {styles[style_idx]}\n{descriptions[style_idx]}"
    # Refresh the description whenever a different style is picked.
    style_radio.change(
        fn=update_style_description,
        inputs=style_radio,
        outputs=style_description
    )
    # Run generation; outputs line up with the list generate_single_style returns.
    generate_btn.click(
        fn=generate_single_style,
        inputs=[prompt, style_radio],
        outputs=[error_message, original_image, loss_image]
    )
# Launch the app
if __name__ == "__main__":
    # Only start the server when this file is executed as a script.
    launch_options = {"share": True, "show_error": True}
    iface.launch(**launch_options)