import os
import torch
import gradio as gr
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms as tfms
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler # Import DPMSolver
# 1. Device and dtype: use float16 on CUDA for speed and memory, float32 on CPU.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch_device == "cuda" else torch.float32
print(f"Using device: {torch_device}, dtype: {torch_dtype}") # Helpful for debugging
# 2. Model Path and Loading: Use a more efficient scheduler and reduce memory usage.
model_path = "CompVis/stable-diffusion-v1-4"
# Use DPMSolverMultistepScheduler for faster and higher-quality sampling
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler")
sd_pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    torch_dtype=torch_dtype,
    scheduler=scheduler,  # Use the DPM scheduler
    # low_cpu_mem_usage reduces peak RAM while the weights are loaded;
    # it matters most when running on CPU-only hardware.
    low_cpu_mem_usage=True if torch_device == "cpu" else False,
    # The safety checker is removed to avoid false positives blocking image generation.
    safety_checker=None,
    requires_safety_checker=False,
).to(torch_device)
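# Optional extra memory savers (a sketch, not enabled by default): diffusers also offers
# model CPU offload and VAE slicing, which can help on small GPUs at some speed cost.
# Note: enable_model_cpu_offload() is meant to replace the .to(torch_device) call above
# and assumes the `accelerate` package is installed. Uncomment only if you hit OOM errors.
# sd_pipeline.enable_model_cpu_offload()
# sd_pipeline.enable_vae_slicing()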
# Optimize attention for memory efficiency (if using CUDA).
# xformers / attention slicing trade a small amount of speed for a large VRAM saving.
if torch_device == "cuda":
    try:
        sd_pipeline.enable_xformers_memory_efficient_attention()  # Use xformers if installed
    except Exception:
        # Fall back to attention slicing: built-in, slightly less effective.
        sd_pipeline.enable_attention_slicing()
# 3. Textual Inversion Loading: Load *only* the necessary concepts. Load them one by one.
# This is *much* more memory efficient than loading all at once.
style_token_dict = {
    "Illustration Style": '<illustration-style>',
    "Line Art": '<line-art>',
    "Hitokomoru Style": '<hitokomoru-style-nao>',
    "Marc Allante": '<Marc_Allante>',
    "Midjourney": '<midjourney-style>',
    "Hanfu Anime": '<hanfu-anime-style>',
    "Birb Style": '<birb-style>'
}
# Load inversions individually. This is crucial for managing memory.
def load_inversion(concept_name, token):
    try:
        sd_pipeline.load_textual_inversion(f"sd-concepts-library/{concept_name}", token=token)
        print(f"Loaded textual inversion: {concept_name}")
    except Exception as e:
        print(f"Error loading {concept_name}: {e}")
# Load each style individually.
load_inversion("illustration-style", style_token_dict["Illustration Style"])
load_inversion("line-art", style_token_dict["Line Art"])
load_inversion("hitokomoru-style-nao", style_token_dict["Hitokomoru Style"])
load_inversion("style-of-marc-allante", style_token_dict["Marc Allante"])
load_inversion("midjourney-style", style_token_dict["Midjourney"])
load_inversion("hanfu-anime-style", style_token_dict["Hanfu Anime"])
load_inversion("birb-style", style_token_dict["Birb Style"])
# 4. Guidance Function: Optimized for speed and clarity.
def apply_guidance(image, guidance_method, loss_scale):
    img_tensor = tfms.ToTensor()(image).unsqueeze(0).to(torch_device)
    loss_scale = loss_scale / 10000.0  # Normalise the slider value to a small blend factor
    if guidance_method == 'Grayscale':
        gray = tfms.Grayscale(num_output_channels=3)(img_tensor)  # keep 3 channels
        guided = img_tensor + (gray - img_tensor) * loss_scale
    elif guidance_method == 'Bright':
        guided = torch.clamp(img_tensor * (1 + loss_scale), 0, 1)  # Direct brightness adjustment
    elif guidance_method == 'Contrast':
        mean = img_tensor.mean()
        guided = torch.clamp((img_tensor - mean) * (1 + loss_scale) + mean, 0, 1)  # Contrast adjustment
    elif guidance_method == 'Symmetry':
        flipped = torch.flip(img_tensor, [3])  # mirror along the width dimension
        guided = img_tensor + (flipped - img_tensor) * loss_scale
    elif guidance_method == 'Saturation':
        # Use torchvision's functional API for efficiency.
        guided = tfms.functional.adjust_saturation(img_tensor, 1 + loss_scale)
        guided = torch.clamp(guided, 0, 1)
    else:
        return image
    # Convert back to a PIL Image
    guided = tfms.ToPILImage()(guided.squeeze(0).cpu())
    return guided
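# Illustrative usage (commented out so nothing runs at import time): a loss_scale of 2000
# maps to a 0.2 blend towards the mirrored image when the 'Symmetry' method is chosen.
# preview = apply_guidance(Image.new("RGB", (64, 64), "orange"), "Symmetry", 2000)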
# 5. Inference Function: Use the pipeline efficiently.
def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale, image_size):
    prompt = f"{text} {style_token_dict[style]}"
    width, height = map(int, image_size.split('x'))
    # Cast slider values to int: manual_seed and the step count require integers.
    generator = torch.Generator(device=torch_device).manual_seed(int(seed))
    # Generate the base image with the pipeline
    image_pipeline = sd_pipeline(
        prompt,
        num_inference_steps=int(inference_step),
        guidance_scale=guidance_scale,
        generator=generator,
        height=height,
        width=width,
    ).images[0]
    image_guide = apply_guidance(image_pipeline, guidance_method, loss_scale)
    return image_pipeline, image_guide
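# Example call (commented out: it runs a full generation, which is slow on CPU-only Spaces):
# base, guided = inference("A castle in the clouds", "Line Art", 20, 7.5, 42,
#                          "Contrast", 200, "256x256")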
# 6. Gradio Interface: page styling (CSS) and header HTML.
css_and_html = """
<style>
/* Your CSS here - mostly unchanged, but I've added a few tweaks */
body {
background: linear-gradient(135deg, #1a1c2c, #4a4e69, #9a8c98);
font-family: 'Arial', sans-serif;
color: #f2e9e4;
margin: 0;
padding: 0;
min-height: 100vh;
}
/* ... (Rest of your CSS) ... */
.gr-box {
background-color: rgba(255, 255, 255, 0.1) !important;
border: 1px solid rgba(255, 255, 255, 0.2) !important;
border-radius: 0.5em !important; /* Add border-radius */
}
.gr-input, .gr-button, .gr-dropdown, .gr-slider {
background-color: rgba(255, 255, 255, 0.1) !important;
color: #f2e9e4 !important;
border: 1px solid rgba(255, 255, 255, 0.2) !important;
border-radius: 0.5em !important; /* Add border-radius */
}
/* ... (Rest of your CSS) ... */
</style>
<div id="app-header">
<div class="artifact large"></div>
<div class="artifact medium"></div>
<div class="artifact small"></div>
<h1>Dreamscape Creator</h1>
<p>Unleash your imagination with AI-powered generative art</p>
<div class="concept-container">
<div class="concept"><div class="concept-emoji">π¨</div><div class="concept-description">Illustration Style</div></div>
<div class="concept"><div class="concept-emoji">βοΈ</div><div class="concept-description">Line Art</div></div>
<div class="concept"><div class="concept-emoji">π</div><div class="concept-description">Midjourney Style</div></div>
<div class="concept"><div class="concept-emoji">π</div><div class="concept-description">Hanfu Anime</div></div>
</div>
</div>
"""
with gr.Blocks(css=css_and_html) as demo:
    gr.HTML(css_and_html)
    with gr.Row():
        text = gr.Textbox(label="Prompt", placeholder="Describe your dreamscape...")
        style = gr.Dropdown(label="Style", choices=list(style_token_dict.keys()), value="Illustration Style")
    with gr.Row():
        inference_step = gr.Slider(1, 50, 20, step=1, label="Inference steps")
        guidance_scale = gr.Slider(1, 10, 7.5, step=0.1, label="Guidance scale")
        seed = gr.Slider(0, 10000, 42, step=1, label="Seed", randomize=True)  # Pick a random seed on each page load
    with gr.Row():
        guidance_method = gr.Dropdown(label="Guidance method", choices=['Grayscale', 'Bright', 'Contrast', 'Symmetry', 'Saturation'], value="Grayscale")
        loss_scale = gr.Slider(100, 10000, 200, step=100, label="Loss scale")
    with gr.Row():
        image_size = gr.Radio(["256x256", "512x512"], label="Image Size", value="256x256")
    with gr.Row():
        generate_button = gr.Button("Create Dreamscape", variant="primary")
    with gr.Row():
        output_image = gr.Image(label="Your Dreamscape", interactive=False)  # Display-only output
        output_image_guided = gr.Image(label="Guided Dreamscape", interactive=False)  # Display-only output
    generate_button.click(
        inference,
        inputs=[text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale, image_size],
        outputs=[output_image, output_image_guided]
    )
    gr.Examples(
        examples=[
            ["Magical Forest with Glowing Trees", 'Birb Style', 40, 7.5, 42, 'Grayscale', 200, "256x256"],
            ["Ancient Temple Ruins at Sunset", 'Midjourney', 30, 8.0, 123, 'Bright', 5678, "256x256"],
            ["Japanese garden with cherry blossoms", 'Hitokomoru Style', 40, 7.0, 789, 'Contrast', 250, "256x256"],
        ],
        inputs=[text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale, image_size],
        outputs=[output_image, output_image_guided],
        fn=inference,
        # Example caching can be problematic on Spaces with limited RAM, so it stays disabled.
        cache_examples=False,
        examples_per_page=5
    )
if __name__ == "__main__":
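    # Optional: demo.queue() queues incoming requests, which can help avoid out-of-memory
    # errors when several users hit the Space at once. Uncomment to enable.
    # demo.queue()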
    demo.launch()