# Hugging Face Space app by Gopalag.
# Last commit: "Update app.py" (155ee81, verified).
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline
from PIL import Image
# Model setup: load FLUX.1-schnell in bfloat16 once at import time and move it
# to the GPU when one is visible, otherwise to the CPU.
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype,
)
pipe = pipe.to(device)
# Largest seed the UI offers / randomizes to: the maximum signed 32-bit value.
MAX_SEED = int(np.iinfo(np.int32).max)
# Upper bound (pixels) for the width and height sliders.
MAX_IMAGE_SIZE = 2 ** 11
# Style dropdown label -> prompt fragment appended by enhance_prompt_for_pattern.
# "None" (empty fragment, skipped by the enhancer) comes first, mirroring
# FABRIC_OPTIONS, so the dropdown's default value "None" is a real choice.
STYLE_OPTIONS = {
    "None": "",
    "Vintage": "vintage style, retro aesthetic, aged appearance",
    "Realistic": "photorealistic, detailed, true-to-life",
    "Geometric": "geometric shapes, precise lines, mathematical patterns",
    "Abstract": "abstract design, non-representational, artistic",
    "Minimalist": "simple, clean lines, understated",
    "Bohemian": "boho style, free-spirited, eclectic",
    "Traditional": "classical design, timeless patterns",
    "Contemporary": "modern style, current trends",
}
# Fabric texture label -> prompt fragment. "None" maps to an empty fragment; it
# is the dropdown's default and enhance_prompt_for_pattern adds nothing for it.
FABRIC_OPTIONS = dict(
    [
        ("None", ""),
        ("Cotton", "cotton textile texture, natural fiber appearance"),
        ("Silk", "silk fabric texture, smooth and lustrous"),
        ("Linen", "linen texture, natural weave pattern"),
        ("Velvet", "velvet texture, plush surface"),
        ("Canvas", "canvas texture, sturdy weave pattern"),
        ("Wool", "wool texture, natural fiber appearance"),
    ]
)
def enhance_prompt_for_pattern(prompt, style, fabric, rng=None):
    """Build the model prompt for a seamless, tileable textile pattern.

    The result is ``prompt``, one randomly chosen "pattern" phrase, the
    selected style and fabric fragments (when any), and a fixed
    textile-printing suffix, all joined with ", ".

    Args:
        prompt: User-supplied pattern description.
        style: Key into ``STYLE_OPTIONS``. Falsy values, ``"None"`` and keys
            that are not in the dict add nothing.
        fabric: Key into ``FABRIC_OPTIONS``; handled the same way as ``style``.
        rng: Optional ``random.Random`` used to pick the pattern phrase.
            Defaults to the module-level ``random`` functions (global state);
            pass a seeded instance to make the prompt reproducible.

    Returns:
        The enhanced prompt string.
    """
    # "high-quality fabric design" is deliberately not in this pool: the fixed
    # suffix below always contains it, so picking it here duplicated the phrase.
    pattern_terms = [
        "seamless pattern",
        "tileable textile design",
        "repeating pattern",
        "continuous pattern",
    ]
    chooser = random if rng is None else rng
    parts = [prompt, chooser.choice(pattern_terms)]

    # .get() so an unexpected dropdown value degrades to "no modifier" instead
    # of raising KeyError in the middle of a request.
    if style and style != "None":
        style_fragment = STYLE_OPTIONS.get(style, "")
        if style_fragment:
            parts.append(style_fragment)
    if fabric and fabric != "None":
        fabric_fragment = FABRIC_OPTIONS.get(fabric, "")
        if fabric_fragment:
            parts.append(fabric_fragment)

    parts.append("suitable for textile printing, high-quality fabric design, seamless edges")
    return ", ".join(parts)
def add_logo(image, logo_path="logo.png", width_fraction=0.2, padding=20):
    """Return a copy of ``image`` with a logo pasted in the bottom-right corner.

    The logo is scaled to ``width_fraction`` of the image width (aspect ratio
    preserved) and placed ``padding`` pixels from the right and bottom edges.
    Logos with transparency are composited through their alpha channel.

    This is best-effort decoration. If the logo cannot be loaded or placed,
    the error is printed and ``image`` is returned unchanged. The input image
    is never modified in place.

    Args:
        image: PIL image to decorate.
        logo_path: Path of the logo file (relative to the working directory).
        width_fraction: Logo width as a fraction of the image width.
        padding: Gap in pixels between the logo and the right/bottom edges.

    Returns:
        A new PIL image with the logo, or ``image`` itself on failure or when
        the image is too small for a logo at least 1 px wide and tall.
    """
    try:
        # Image.open is lazy and keeps the file open; the context manager
        # closes it. The previous version leaked one handle per call.
        with Image.open(logo_path) as logo:
            logo_width = int(image.size[0] * width_fraction)
            logo_height = int(logo_width * (logo.size[1] / logo.size[0]))
            if logo_width < 1 or logo_height < 1:
                # Too small to place a visible logo; leave the image as-is.
                return image
            # Converting to RGBA gives every logo an alpha band (fully opaque
            # for RGB logos), so one paste path covers RGBA, LA and paletted
            # logos. convert/resize load the pixels before the file closes.
            scaled_logo = logo.convert("RGBA").resize(
                (logo_width, logo_height), Image.Resampling.LANCZOS
            )

        position = (
            image.size[0] - logo_width - padding,
            image.size[1] - logo_height - padding,
        )
        result = image.copy()
        # The third argument uses the logo's alpha band as the paste mask.
        result.paste(scaled_logo, position, scaled_logo)
        return result
    except Exception as e:  # best-effort: a missing/bad logo must not fail generation
        print(f"Error adding logo: {e}")
        return image
def create_fabric_preview(image, columns=4, rows=2):
    """Tile ``image`` into a ``columns`` x ``rows`` grid to preview it as fabric.

    Tiles are pasted left-to-right, top-to-bottom onto a new RGB canvas, and
    ``add_logo`` is then applied once to the whole canvas.

    Args:
        image: PIL image of a single pattern tile.
        columns: Number of tiles across (default 4).
        rows: Number of tiles down (default 2).

    Returns:
        The tiled preview as a new PIL image.

    Raises:
        ValueError: If ``columns`` or ``rows`` is less than 1.
    """
    if columns < 1 or rows < 1:
        raise ValueError(f"columns and rows must be >= 1, got {columns}x{rows}")

    tile_width, tile_height = image.size
    preview = Image.new("RGB", (tile_width * columns, tile_height * rows))
    for row in range(rows):
        for column in range(columns):
            preview.paste(image, (column * tile_width, row * tile_height))

    # The logo goes on the assembled preview, not on each tile.
    return add_logo(preview)
@spaces.GPU()
def infer(prompt, style, fabric, seed=42, randomize_seed=False, width=1024, height=1024,
          num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    """Generate a textile pattern tile and its tiled fabric preview.

    Args:
        prompt: User pattern description.
        style: Style dropdown value, passed to ``enhance_prompt_for_pattern``.
        fabric: Fabric dropdown value, passed to ``enhance_prompt_for_pattern``.
        seed: Seed for the torch generator; replaced when ``randomize_seed``.
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Denoising steps passed to the pipeline.
        progress: Gradio progress tracker (``track_tqdm=True``).

    Returns:
        ``(pattern_with_logo, fabric_preview, seed)``: the single tile with
        the logo, the tiled preview, and the seed actually used, so the UI's
        seed slider can display it.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; torch's manual_seed and the pipeline's
    # size/step arguments need plain ints.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    # NOTE(review): the phrase picked inside enhance_prompt_for_pattern uses the
    # global `random` state by default, so a fixed seed alone does not reproduce
    # the exact prompt (and therefore image).
    enhanced_prompt = enhance_prompt_for_pattern(prompt, style, fabric)
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=enhanced_prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=generator,
        guidance_scale=0.0,
    ).images[0]

    # The pipeline's default output is a PIL image; the fallback handles array
    # outputs. Float arrays are assumed to be in [0, 1] (diffusers' "np"
    # output convention); a bare uint8 cast would truncate them to 0/1.
    if isinstance(image, Image.Image):
        pil_image = image
    else:
        array = np.asarray(image)
        if np.issubdtype(array.dtype, np.floating):
            array = (np.clip(array, 0.0, 1.0) * 255.0).round()
        pil_image = Image.fromarray(array.astype(np.uint8))

    pattern_with_logo = add_logo(pil_image)
    # The preview tiles the logo-free image; add_logo stamps the grid once.
    fabric_preview = create_fabric_preview(pil_image)
    return pattern_with_logo, fabric_preview, seed
# Rows for gr.Examples, one per example: [prompt, style, fabric], matching the
# first three inputs of infer(). Kept as lists because that is what was passed
# before.
examples = [
    list(row)
    for row in (
        ("geometric Art Deco shapes in gold and navy", "Geometric", "None"),
        ("abstract watercolor spots in pastel colors", "Abstract", "Silk"),
        ("traditional paisley design in earth tones", "Traditional", "Linen"),
        ("delicate floral motifs with small roses and leaves tileable textile design", "Vintage", "Cotton"),
        ("modern minimalist lines and circles", "Minimalist", "Canvas"),
    )
]
# Custom CSS for the Blocks UI (visual polish and mobile responsiveness). The
# selectors target the elem_id / elem_classes values used by the gr.Blocks
# layout.
css = """
#col-container {
    margin: 0 auto;
    max-width: 1200px !important;
    padding: 20px;
}
.main-title {
    text-align: center;
    color: #2d3748;
    margin-bottom: 1rem;
    font-family: 'Poppins', sans-serif;
}
.subtitle {
    text-align: center;
    color: #4a5568;
    margin-bottom: 2rem;
    font-family: 'Inter', sans-serif;
    font-size: 0.95rem;
    line-height: 1.5;
}
.pattern-input {
    border: 2px solid #e2e8f0;
    border-radius: 10px;
    padding: 12px !important;
    margin-bottom: 1rem !important;
    font-size: 1rem;
    transition: all 0.3s ease;
}
/* elem_classes land on the Textbox wrapper, which never takes focus itself;
   :focus-within matches while the inner input is focused. */
.pattern-input:focus-within {
    border-color: #4299e1;
    box-shadow: 0 0 0 3px rgba(66, 153, 225, 0.1);
}
.generate-button {
    background-color: #4299e1 !important;
    color: white !important;
    padding: 12px 24px !important;
    border-radius: 8px !important;
    font-weight: 600 !important;
    transition: all 0.3s ease !important;
}
.generate-button:hover {
    background-color: #3182ce !important;
    transform: translateY(-1px);
}
.result-image {
    border-radius: 12px;
    box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
    margin-top: 1rem;
}
.advanced-settings {
    margin-top: 1.5rem;
    border: 1px solid #e2e8f0;
    border-radius: 10px;
    padding: 1rem;
}
.examples-section {
    margin-top: 2rem;
    padding: 1rem;
    background: #f7fafc;
    border-radius: 10px;
    border: none;
}
.preview-section {
    margin-top: 1rem;
    padding: 1rem;
    background: #ffffff;
    border-radius: 10px;
}
/* Mobile Responsiveness */
@media (max-width: 768px) {
    #col-container {
        padding: 12px;
    }
    .main-title {
        font-size: 1.5rem;
    }
    .subtitle {
        font-size: 0.9rem;
    }
    .pattern-input {
        font-size: 0.9rem;
    }
}
"""
# UI layout: prompt/style/fabric inputs, the two result images, advanced
# settings, cached examples, and the event wiring that calls infer().
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            "\n# 🎨 Professional Textile Pattern Generator\n",
            elem_classes=["main-title"],
        )
        gr.Markdown(
            "\n"
            "Create professional-grade, seamless patterns for textile manufacturing.\n"
            "Design unique patterns with style and fabric texture controls,\n"
            "perfect for commercial textile production and fashion design.\n",
            elem_classes=["subtitle"],
        )

        # Gradio column scales should be integers (a float such as 0.5 draws a
        # warning), so the original 2 : 1 : 1 : 0.5 proportions are expressed
        # as 4 : 2 : 2 : 1.
        with gr.Row():
            with gr.Column(scale=4):
                prompt = gr.Text(
                    label="Pattern Description",
                    show_label=False,
                    max_lines=1,
                    placeholder="Describe your dream pattern (e.g., 'geometric Art Deco shapes in gold and navy')",
                    container=False,
                    elem_classes=["pattern-input"],
                )
            with gr.Column(scale=2):
                style = gr.Dropdown(
                    # Gradio validates a Dropdown's value against `choices`, so
                    # the default "None" must be listed. Any "None" key already
                    # in the dict is skipped to avoid a duplicate entry.
                    choices=["None", *(name for name in STYLE_OPTIONS if name != "None")],
                    label="Style",
                    value="None",
                )
            with gr.Column(scale=2):
                fabric = gr.Dropdown(
                    choices=["None", *(name for name in FABRIC_OPTIONS if name != "None")],
                    label="Fabric Texture",
                    value="None",
                )
            with gr.Column(scale=1):
                run_button = gr.Button(
                    "✨ Generate",
                    elem_classes=["generate-button"],
                )

        with gr.Row():
            with gr.Column():
                pattern = gr.Image(
                    label="Generated Pattern",
                    show_label=True,
                    elem_classes=["result-image"],
                )
            with gr.Column():
                preview = gr.Image(
                    label="Fabric Preview",
                    show_label=True,
                    elem_classes=["result-image"],
                )

        with gr.Accordion("🔧 Advanced Settings", open=False):
            with gr.Group(elem_classes=["advanced-settings"]):
                seed = gr.Slider(
                    label="Pattern Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(
                    label="Randomize Pattern",
                    value=True,
                )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                num_inference_steps = gr.Slider(
                    label="Generation Quality (Steps)",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=4,
                )

        with gr.Group(elem_classes=["examples-section"]):
            # Only infer's first three inputs vary per example; the remaining
            # parameters fall back to infer's defaults. With cache_examples=True
            # the example outputs are precomputed.
            gr.Examples(
                examples=examples,
                fn=infer,
                inputs=[prompt, style, fabric],
                outputs=[pattern, preview, seed],
                cache_examples=True,
            )

    # Clicking the button or pressing Enter in the prompt both run generation.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, style, fabric, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[pattern, preview, seed],
    )

demo.launch()