|
import os |
|
import gc |
|
import re |
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
import json |
|
import spaces |
|
import config |
|
import utils |
|
import logging |
|
import time |
|
from datetime import datetime |
|
from typing import List, Dict, Tuple, Optional |
|
from PIL import Image, PngImagePlugin |
|
from diffusers.models import AutoencoderKL |
|
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline |
|
from transformers import pipeline as translation_pipeline |
|
|
|
from config import ( |
|
MODEL, |
|
MIN_IMAGE_SIZE, |
|
MAX_IMAGE_SIZE, |
|
USE_TORCH_COMPILE, |
|
ENABLE_CPU_OFFLOAD, |
|
OUTPUT_DIR, |
|
DEFAULT_NEGATIVE_PROMPT, |
|
DEFAULT_ASPECT_RATIO, |
|
examples, |
|
sampler_list, |
|
aspect_ratios, |
|
style_list, |
|
) |
|
|
|
|
|
# Configure logging once at import time: INFO level, with a
# "timestamp - logger name - level - message" layout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Runtime flags derived from the environment.
# IS_COLAB is true when utils reports Google Colab or IS_COLAB=1 is exported.
IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
# Read here; not referenced again in the visible part of this file.
HF_TOKEN = os.getenv("HF_TOKEN")
# Example caching requires BOTH a CUDA device and CACHE_EXAMPLES=1.
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"

# Process-wide torch backend switches (deterministic cuDNN, no autotuning
# benchmark, TF32 allowed for CUDA matmul).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Korean -> English translation pipeline, built eagerly at import time and
# used by translate_if_korean().
translator = translation_pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
|
|
|
class GenerationError(Exception):
    """Signal that a generation request is invalid or cannot be fulfilled.

    Raised by the validation helpers in this module; ``generate`` catches it
    and re-raises it as a ``gr.Error`` so the message is shown to the user.
    """
|
|
|
# Matches any Hangul character: compatibility jamo consonants
# (U+3131-U+314E), compatibility jamo vowels (U+314F-U+3163) and precomposed
# syllables (U+AC00-U+D7A3). Written with escapes so the pattern survives
# files or tools that mishandle non-ASCII source text.
HANGUL_PATTERN = re.compile(r'[\u3131-\u314e\u314f-\u3163\uac00-\ud7a3]')


def translate_if_korean(prompt: str) -> str:
    """Translate the prompt to English if it contains Korean (Hangul) text.

    Args:
        prompt: The user-supplied prompt.

    Returns:
        The translated text when Hangul is detected and translation
        succeeds; otherwise the prompt unchanged.
    """
    if not HANGUL_PATTERN.search(prompt):
        return prompt

    logger.info("Korean detected in prompt. Translating to English...")
    try:
        translation = translator(prompt)[0]['translation_text']
        logger.info(f"Translation result: {translation}")
        return translation
    except Exception as e:
        # Translation is best-effort: log the failure and fall back to the
        # original prompt instead of aborting the generation request.
        logger.error(f"Translation error: {e}")
        return prompt
|
|
|
def validate_prompt(prompt: str) -> str:
    """Validate and clean up the input prompt.

    Args:
        prompt: Raw prompt text from the caller.

    Returns:
        The prompt with ``"!,"`` rewritten as ``"! ,"``, translated to
        English when it contains Korean, and stripped of surrounding
        whitespace.

    Raises:
        GenerationError: If the prompt is not a string, cannot be encoded
            as UTF-8, or is empty / whitespace-only.
    """
    if not isinstance(prompt, str):
        raise GenerationError("Prompt must be a string")

    try:
        # Round-trip through UTF-8 to reject text that cannot be encoded
        # (for example lone surrogate code points).
        prompt = prompt.encode('utf-8').decode('utf-8')
    except UnicodeError as err:
        # Keep the original error as the cause for debugging.
        raise GenerationError("Invalid characters in prompt") from err

    # Separate "!" from a directly following comma.
    prompt = prompt.replace("!,", "! ,")

    if not prompt or prompt.isspace():
        raise GenerationError("Prompt cannot be empty")

    prompt = translate_if_korean(prompt)
    return prompt.strip()
|
|
|
def validate_dimensions(width: int, height: int) -> None:
    """Check that both image dimensions lie within the configured bounds.

    Width is checked before height, so an invalid width is reported first.

    Raises:
        GenerationError: If either value falls outside
            ``[MIN_IMAGE_SIZE, MAX_IMAGE_SIZE]``.
    """
    for label, value in (("Width", width), ("Height", height)):
        in_range = MIN_IMAGE_SIZE <= value <= MAX_IMAGE_SIZE
        if not in_range:
            raise GenerationError(
                f"{label} must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}"
            )
|
|
|
@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
    seed: int = 0,
    custom_width: int = 1024,
    custom_height: int = 1024,
    guidance_scale: float = 6.0,
    num_inference_steps: int = 25,
    sampler: str = "Euler a",
    aspect_ratio_selector: str = DEFAULT_ASPECT_RATIO,
    style_selector: str = "(None)",
    use_upscaler: bool = False,
    upscaler_strength: float = 0.55,
    upscale_by: float = 1.5,
    add_quality_tags: bool = True,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> Tuple[List[str], Dict]:
    """Generate images based on the given parameters.

    Args:
        prompt: Text prompt; validated (and translated from Korean when
            needed) by ``validate_prompt``.
        negative_prompt: Text describing what to avoid.
        seed: Seed handed to ``utils.seed_everything`` to build the generator.
        custom_width: Requested width; validated against the size limits and
            passed to ``utils.aspect_ratio_handler``.
        custom_height: Requested height; handled like ``custom_width``.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps forwarded to the
            pipeline.
        sampler: Sampler name resolved by ``utils.get_scheduler``.
        aspect_ratio_selector: Aspect-ratio choice passed to
            ``utils.aspect_ratio_handler``.
        style_selector: Style preset name applied by
            ``utils.preprocess_prompt``.
        use_upscaler: When true, generate latents first, upscale them and
            refine with an img2img pipeline built from the same components.
        upscaler_strength: ``strength`` for the img2img refinement pass.
        upscale_by: Latent upscale factor.
        add_quality_tags: When true, prepend the fixed quality tags to the
            prompt.
        progress: Gradio progress tracker used while saving images.

    Returns:
        A tuple of (list of saved image paths, metadata dict describing the
        generation). The list is empty if the pipeline produced no images.

    Raises:
        gr.Error: On validation failures or any error during generation.
    """
    start_time = time.time()
    upscaler_pipe = None
    backup_scheduler = None
    # Bound up front so the return below is valid even when the pipeline
    # yields no images.
    image_paths: List[str] = []

    try:
        # Start from a clean memory state.
        torch.cuda.empty_cache()
        gc.collect()

        # Validate and normalise the inputs.
        prompt = validate_prompt(prompt)
        if negative_prompt:
            negative_prompt = negative_prompt.encode('utf-8').decode('utf-8')

        validate_dimensions(custom_width, custom_height)

        # Seed and resolve the target resolution.
        generator = utils.seed_everything(seed)
        width, height = utils.aspect_ratio_handler(
            aspect_ratio_selector,
            custom_width,
            custom_height,
        )

        if add_quality_tags:
            prompt = f"masterpiece, high score, great score, absurdres, {prompt}"

        prompt, negative_prompt = utils.preprocess_prompt(
            styles, style_selector, prompt, negative_prompt
        )

        width, height = utils.preprocess_image_dimensions(width, height)

        # The module-level pipeline is None when CUDA was unavailable at
        # import time; fail with a clear message instead of AttributeError.
        if pipe is None:
            raise GenerationError("Model pipeline is not loaded (CUDA is not available)")

        # Swap in the requested sampler; the original scheduler is restored
        # in the ``finally`` block.
        backup_scheduler = pipe.scheduler
        pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)

        if use_upscaler:
            upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)

        # Metadata returned to the UI and passed to utils.save_image().
        metadata = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "resolution": f"{width} x {height}",
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "style_preset": style_selector,
            "seed": seed,
            "sampler": sampler,
            "Model": "Animagine XL 4.0",
            "Model hash": "e3c47aedb0",
        }

        if use_upscaler:
            new_width = int(width * upscale_by)
            new_height = int(height * upscale_by)
            metadata["use_upscaler"] = {
                "upscale_method": "nearest-exact",
                "upscaler_strength": upscaler_strength,
                "upscale_by": upscale_by,
                "new_resolution": f"{new_width} x {new_height}",
            }
        else:
            metadata["use_upscaler"] = None

        logger.info(f"Starting generation with parameters: {json.dumps(metadata, indent=4)}")

        if use_upscaler:
            # Two-pass generation: base latents -> latent upscale -> img2img.
            latents = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                output_type="latent",
            ).images
            upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
            images = upscaler_pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=upscaled_latents,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                strength=upscaler_strength,
                generator=generator,
                output_type="pil",
            ).images
        else:
            images = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                output_type="pil",
            ).images

        # Save the results and report progress.
        if images:
            total = len(images)
            for idx, image in enumerate(images, 1):
                progress(idx/total, desc="Saving images...")
                path = utils.save_image(image, metadata, OUTPUT_DIR, IS_COLAB)
                image_paths.append(path)
                logger.info(f"Image {idx}/{total} saved as {path}")

        generation_time = time.time() - start_time
        logger.info(f"Generation completed successfully in {generation_time:.2f} seconds")
        metadata["generation_time"] = f"{generation_time:.2f}s"

        return image_paths, metadata

    except GenerationError as e:
        logger.warning(f"Generation validation error: {str(e)}")
        raise gr.Error(str(e)) from e
    except Exception as e:
        logger.exception("Unexpected error during generation")
        raise gr.Error(f"Generation failed: {str(e)}") from e
    finally:
        # Cleanup runs on success and on failure.
        torch.cuda.empty_cache()
        gc.collect()

        if upscaler_pipe is not None:
            del upscaler_pipe

        # Put the original scheduler back on the shared pipeline.
        if backup_scheduler is not None and pipe is not None:
            pipe.scheduler = backup_scheduler

        utils.free_memory()
|
|
|
|
|
# Build the shared generation pipeline at import time. Without CUDA no
# pipeline is created and ``pipe`` stays None.
if not torch.cuda.is_available():
    logger.warning("CUDA not available, running on CPU")
    pipe = None
else:
    try:
        logger.info("Loading VAE and pipeline...")
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        pipe = utils.load_pipeline(MODEL, device, vae=vae)
        logger.info("Pipeline loaded successfully on GPU!")
    except Exception as exc:
        # Any failure above (VAE download or pipeline load) triggers a retry
        # without the custom VAE.
        logger.error(f"Error loading VAE, falling back to default: {exc}")
        pipe = utils.load_pipeline(MODEL, device)
|
|
|
|
|
# Lookup table: style name -> (prompt, negative_prompt) pair, built from the
# entries of config.style_list.
styles = {
    preset["name"]: (preset["prompt"], preset["negative_prompt"])
    for preset in style_list
}
|
|
|
|
|
# Stylesheet passed to gr.Blocks(css=...). The section comments inside the
# string were mis-encoded (UTF-8 Korean rendered as Thai characters); they are
# restored to the original Korean here. CSS comments do not affect rendering.
custom_css = """
/* 배경 및 글자 색상 변경 */
body {
    background-color: #f7f9fc;
    color: #333;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}

/* 헤더 스타일 */
.header {
    text-align: center;
    padding: 20px;
}
.header .title {
    font-size: 3em;
    font-weight: bold;
    color: #2c3e50;
}
.header .subtitle {
    font-size: 1.2em;
    color: #7f8c8d;
}
a {
    text-decoration: none;
    color: #3498db;
}

/* Discord 버튼 스타일 */
.discord-btn {
    display: flex;
    align-items: center;
    justify-content: center;
    padding: 10px 20px;
    background: #7289da;
    color: white;
    border-radius: 8px;
    font-weight: bold;
    margin-top: 20px;
}
.discord-btn:hover {
    background: #5b6eae;
}
.discord-icon {
    width: 24px;
    height: 24px;
    margin-right: 8px;
}

/* Gradio 갤러리 스타일 개선 */
.gradio-gallery {
    border: none;
    box-shadow: none;
}
"""
|
|
|
# Gradio UI definition and event wiring.
# NOTE(review): the original indentation was lost, so the widget nesting below
# is reconstructed from the order of the context managers — confirm the layout
# against the running app.
with gr.Blocks(css=custom_css, theme="default") as demo:
    # Page header.
    gr.HTML(
        """
        <div class="header">
            <div class="title">Multilingual Animagine</div>
        </div>
        """
    )

    with gr.Row():
        # Left column: generation controls.
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.TabItem("Generate"):
                    with gr.Group():
                        prompt = gr.Textbox(
                            label="Prompt",
                            lines=4,
                            placeholder="Describe what you want to generate...",
                            # The Korean sentence was mis-encoded and split
                            # across two lines in the source; restored here.
                            info="Enter your image generation prompt here. 한글 입력 시 자동으로 영어로 번역됩니다.",
                        )
                        negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            lines=4,
                            placeholder="Describe what you want to avoid",
                            value=DEFAULT_NEGATIVE_PROMPT,
                            info="Specify elements you don't want in the image.",
                        )
                        add_quality_tags = gr.Checkbox(
                            label="Quality Tags",
                            value=True,
                            info="Automatically add quality-enhancing tags to your prompt.",
                        )
                    with gr.Accordion(label="More Settings", open=False):
                        with gr.Column():
                            aspect_ratio_selector = gr.Radio(
                                label="Aspect Ratio",
                                choices=aspect_ratios,
                                value=DEFAULT_ASPECT_RATIO,
                                container=True,
                                info="Choose the dimensions of your image.",
                            )
                            # Hidden until the "Custom" aspect ratio is picked
                            # (see aspect_ratio_selector.change below).
                            with gr.Row(visible=False) as custom_resolution:
                                custom_width = gr.Slider(
                                    label="Width",
                                    minimum=MIN_IMAGE_SIZE,
                                    maximum=MAX_IMAGE_SIZE,
                                    step=8,
                                    value=1024,
                                    info=f"Image width (between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE})",
                                )
                                custom_height = gr.Slider(
                                    label="Height",
                                    minimum=MIN_IMAGE_SIZE,
                                    maximum=MAX_IMAGE_SIZE,
                                    step=8,
                                    value=1024,
                                    info=f"Image height (between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE})",
                                )
                    with gr.Accordion(label="Advanced Parameters", open=False):
                        with gr.Row():
                            style_selector = gr.Dropdown(
                                label="Style Preset",
                                choices=list(styles.keys()),
                                value="(None)",
                                info="Apply a predefined style to your generation.",
                            )
                            sampler = gr.Dropdown(
                                label="Sampler",
                                choices=sampler_list,
                                value="Euler a",
                                info="Different samplers can produce varying results.",
                            )
                        with gr.Row():
                            seed = gr.Slider(
                                label="Seed",
                                minimum=0,
                                maximum=utils.MAX_SEED,
                                step=1,
                                value=0,
                                info="Set a specific seed for reproducible results.",
                            )
                            randomize_seed = gr.Checkbox(
                                label="Randomize seed",
                                value=True,
                                info="Generate a new random seed for each image.",
                            )
                        with gr.Row():
                            guidance_scale = gr.Slider(
                                label="Guidance scale",
                                minimum=1,
                                maximum=12,
                                step=0.1,
                                value=6.0,
                                info="Higher values make the image more closely match your prompt.",
                            )
                            num_inference_steps = gr.Slider(
                                label="Inference steps",
                                minimum=1,
                                maximum=50,
                                step=1,
                                value=25,
                                info="More steps generally yield higher quality images.",
                            )
                        with gr.Row():
                            use_upscaler = gr.Checkbox(
                                label="Use Upscaler",
                                value=False,
                                info="Enable high-resolution upscaling.",
                            )
                            # Both sliders stay hidden until "Use Upscaler"
                            # is ticked (see use_upscaler.change below).
                            upscaler_strength = gr.Slider(
                                label="Upscaler Strength",
                                minimum=0,
                                maximum=1,
                                step=0.05,
                                value=0.55,
                                visible=False,
                                info="Control how strongly the upscaler affects the image.",
                            )
                            upscale_by = gr.Slider(
                                label="Upscale By",
                                minimum=1,
                                maximum=1.5,
                                step=0.1,
                                value=1.5,
                                visible=False,
                                info="Multiplier for the final image resolution.",
                            )
                with gr.TabItem("Examples"):
                    gr.Markdown(
                        """
                        ### Example Prompts
                        - **Scenic Landscape:** A breathtaking view of a mountain landscape during sunrise.
                        - **Cyberpunk City:** A futuristic cyberpunk city with neon lights and towering skyscrapers.
                        - **Fantasy Character:** A majestic wizard with a long beard and glowing magical staff.
                        """
                    )
                    # Clicking an example only fills the prompt box.
                    # NOTE(review): no `fn` is passed, so turning on
                    # CACHE_EXAMPLES looks likely to fail at startup —
                    # confirm against the Gradio Examples documentation.
                    gr.Examples(
                        examples=examples,
                        inputs=prompt,
                        outputs=[],
                        cache_examples=CACHE_EXAMPLES,
                    )
            run_button = gr.Button("Generate", variant="primary", elem_id="generate-button")

        # Right column: results.
        with gr.Column(scale=4):
            result = gr.Gallery(
                label="Generated Images",
                columns=2,
                height="600px",
                show_label=True,
                elem_classes="gradio-gallery",
            )
            with gr.Accordion(label="Generation Parameters", open=False):
                gr_metadata = gr.JSON(
                    label="Image Metadata",
                    show_label=True,
                )

    # Footer: Discord invitation button.
    with gr.Row():
        gr.HTML(
            """
            <div style="width:100%; display:flex; justify-content:center;">
                <a href="https://discord.gg/openfreeai" target="_blank" class="discord-btn">
                    <svg class="discord-icon" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 127.14 96.36">
                        <path fill="currentColor" d="M107.7,8.07A105.15,105.15,0,0,0,81.47,0a72.06,72.06,0,0,0-3.36,6.83A97.68,97.68,0,0,0,49,6.83,72.37,72.37,0,0,0,45.64,0,105.89,105.89,0,0,0,19.39,8.09C2.79,32.65-1.71,56.6.54,80.21h0A105.73,105.73,0,0,0,32.71,96.36,77.7,77.7,0,0,0,39.6,85.25a68.42,68.42,0,0,1-10.85-5.18c.91-.66,1.8-1.34,2.66-2a75.57,75.57,0,0,0,64.32,0c.87.71,1.76,1.39,2.66,2a68.68,68.68,0,0,1-10.87,5.19,77,77,0,0,0,6.89,11.1A105.25,105.25,0,0,0,126.6,80.22h0C129.24,52.84,122.09,29.11,107.7,8.07ZM42.45,65.69C36.18,65.69,31,60,31,53s5-12.74,11.43-12.74S54,46,53.89,53,48.84,65.69,42.45,65.69Zm42.24,0C78.41,65.69,73.25,60,73.25,53s5-12.74,11.44-12.74S96.23,46,96.12,53,91.08,65.69,84.69,65.69Z"/>
                    </svg>
                    <span class="discord-text">Join our Discord Server</span>
                </a>
            </div>
            """
        )

    # Show the upscaler sliders only while "Use Upscaler" is checked.
    use_upscaler.change(
        fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
        inputs=use_upscaler,
        outputs=[upscaler_strength, upscale_by],
        queue=False,
        api_name=False,
    )
    # Show the width/height sliders only for the "Custom" aspect ratio.
    aspect_ratio_selector.change(
        fn=lambda x: gr.update(visible=x == "Custom"),
        inputs=aspect_ratio_selector,
        outputs=custom_resolution,
        queue=False,
        api_name=False,
    )

    # Generation chain: refresh the seed -> lock the button -> generate ->
    # unlock the button.
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=utils.randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=lambda: gr.update(interactive=False, value="Generating..."),
        outputs=run_button,
    ).then(
        fn=generate,
        # Order must match the positional parameters of generate().
        inputs=[
            prompt,
            negative_prompt,
            seed,
            custom_width,
            custom_height,
            guidance_scale,
            num_inference_steps,
            sampler,
            aspect_ratio_selector,
            style_selector,
            use_upscaler,
            upscaler_strength,
            upscale_by,
            add_quality_tags,
        ],
        outputs=[result, gr_metadata],
    ).then(
        fn=lambda: gr.update(interactive=True, value="Generate"),
        outputs=run_button,
    )
|
|
|
if __name__ == "__main__":
    # Cap the request queue at 20 entries, then start the app. Debug mode
    # and public sharing are both driven by IS_COLAB.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(debug=IS_COLAB, share=IS_COLAB)
|
|