File size: 4,063 Bytes
2ec881e 62b87fe 2ec881e 62b87fe 2ec881e 62b87fe 2ec881e 62b87fe 24dc687 2ec881e 62b87fe 2ec881e 62b87fe 24dc687 2ec881e 62b87fe 2ec881e 62b87fe 2ec881e 62b87fe 2ec881e 62b87fe 2ec881e 62b87fe 2ec881e 62b87fe 2ec881e 62b87fe 2ec881e 62b87fe 2ec881e 62b87fe 2ec881e 62b87fe 2ec881e 62b87fe 24dc687 62b87fe 2ec881e 62b87fe 2ec881e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import gradio as gr
import numpy as np
import random
import torch
from diffusers import DiffusionPipeline
# === Configuration ===
# Upper bound (in pixels) for the width/height sliders in the UI.
MAX_IMAGE_SIZE = 1024
# Largest seed the UI exposes: the maximum signed 32-bit integer (2**31 - 1).
MAX_SEED = int(np.iinfo(np.int32).max)
# Repository id handed to DiffusionPipeline.from_pretrained by get_pipe().
MODEL_REPO_ID = "stabilityai/sdxl-turbo"
def get_torch_dtype():
    """Choose the floating-point dtype used to load the model weights.

    Returns ``torch.float16`` when CUDA is available and ``torch.float32``
    otherwise. Half precision is only requested on GPU because it is typically
    unsupported or slow on CPU.
    """
    if not torch.cuda.is_available():
        return torch.float32
    return torch.float16
def get_device():
    """Return the device string for inference: ``"cuda"`` if a GPU is visible, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
# === Lazy load the diffusion model ===
def get_pipe():
    """Return the shared diffusion pipeline, loading it on first use.

    The first call loads ``MODEL_REPO_ID`` via ``DiffusionPipeline.from_pretrained``
    with the dtype from ``get_torch_dtype()``, moves it to ``get_device()``, and
    caches it on the function object as ``get_pipe.pipe``. Later calls return
    that cached object. If loading raises, nothing is cached and the next call
    retries.
    """
    # NOTE(review): there is no lock here; two concurrent first calls could
    # each load the model. Confirm the app never issues them in parallel.
    try:
        return get_pipe.pipe
    except AttributeError:
        pass  # first call: nothing cached yet
    loaded = DiffusionPipeline.from_pretrained(MODEL_REPO_ID, torch_dtype=get_torch_dtype())
    get_pipe.pipe = loaded.to(get_device())
    return get_pipe.pipe
# === Define custom prompt builder ===
def build_prompt(word):
    """Build the text-to-image prompt that illustrates ``word``.

    The word appears twice, wrapped in single quotes. The first mention asks
    for an image conveying its meaning; the second asks for the word itself to
    be written into the scene.

    Args:
        word: The vocabulary word, inserted verbatim.

    Returns:
        The full prompt string.
    """
    # NOTE(review): this prompt is likely longer than the 77-token limit of
    # SDXL's CLIP tokenizers, so its tail (e.g. the "Format:" sentence) may be
    # truncated. Check the pipeline's truncation warning.
    quoted = f"'{word}'"
    sentences = [
        "Create a powerful, emotionally resonant image that vividly illustrates "
        f"the meaning of the word {quoted}, "
        "so that even someone who doesn’t speak English can understand it instantly.",
        "The visual should be sharp, symbolic, and universally relatable.",
        f"Seamlessly weave the word {quoted} into the scene—clearly spelled but not overpowering—"
        "so it supports the concept without drawing attention away.",
        "Format: 1080x1080 pixels (square) for Instagram in a (.png) format.",
    ]
    return " ".join(sentences)
# === Image generation function ===
def generate_image(prompt, negative_prompt, guidance_scale, num_inference_steps, width, height, seed):
    """Run the diffusion pipeline once and return the first generated image.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Passed through as the pipeline's ``negative_prompt``.
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.
        width, height: Requested output size in pixels.
        seed: Seed for a fresh ``torch.Generator``, so equal inputs should
            reproduce the same image.

    Returns:
        Element 0 of the pipeline output's ``images``.
    """
    # A freshly seeded generator per call (default device: CPU) makes the
    # noise depend only on ``seed``, not on earlier calls.
    rng = torch.Generator().manual_seed(seed)
    call_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "width": width,
        "height": height,
        "generator": rng,
    }
    with torch.inference_mode():
        # get_pipe() is called inside inference_mode, so a first-call model
        # load happens under the same autograd mode as generation.
        output = get_pipe()(**call_kwargs)
        first_image = output.images[0]
    return first_image
# === Inference wrapper ===
def infer(
    word,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio click handler: turn a vocabulary word into an image.

    Args:
        word: Vocabulary word entered by the user. Surrounding whitespace is
            ignored.
        negative_prompt: Forwarded unchanged to ``generate_image``.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width, height: Requested output size in pixels.
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker. ``track_tqdm=True`` lets Gradio
            mirror tqdm progress bars (presumably the pipeline's denoising loop)
            in the UI. It is not referenced directly.

    Returns:
        ``(image, seed)``: the generated image and the integer seed actually
        used, so the seed slider can display it for reproduction.

    Raises:
        gr.Error: If ``word`` is empty or whitespace-only.
    """
    word = (word or "").strip()
    if not word:
        # Fail fast with a user-visible message instead of spending a full
        # diffusion run on a prompt about the empty word ''.
        raise gr.Error("Please enter a vocabulary word.")
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider/API values can arrive as floats (e.g. 42.0), but
    # torch.Generator.manual_seed only accepts integers.
    seed = int(seed)
    prompt = build_prompt(word)
    image = generate_image(prompt, negative_prompt, guidance_scale, num_inference_steps, width, height, seed)
    return image, seed
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# Default edge length (pixels) for the square output. The prompt targets
# Instagram's 1080x1080, but the width/height sliders below cap at
# MAX_IMAGE_SIZE and move in 32 px steps from 256. 1080 was therefore above the
# slider maximum and off its step grid. 1024 is a value the sliders can hold.
DEFAULT_IMAGE_SIZE = 1024

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Word-to-Image Generator for Instagram 🎨")
        with gr.Row():
            word = gr.Text(
                label="Vocabulary Word",
                show_label=False,
                max_lines=1,
                placeholder="Enter a vocabulary word",
                container=False,
            )
            run_button = gr.Button("Generate Image", scale=0, variant="primary")
        result = gr.Image(label="Generated Image", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            # Hidden from users, but still listed in the click inputs below, so
            # its value is forwarded to infer() as the negative prompt.
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=False,
            )
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=DEFAULT_IMAGE_SIZE,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=DEFAULT_IMAGE_SIZE,
                )
            with gr.Row():
                # NOTE(review): the diffusers SDXL Turbo docs recommend
                # guidance_scale=0.0 because the model was trained without
                # classifier-free guidance. Confirm the 3.5 default is intended.
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, step=0.1, value=3.5)
                num_inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=50, step=1, value=4)
    # infer returns (image, seed_used). Writing the seed back to the slider lets
    # the user see and reuse a randomized seed.
    run_button.click(
        fn=infer,
        inputs=[word, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()
|