# qr-art / app.py — author: aqlanhadi
# Commit 4b31ebb: "defined params for qr details" (4.97 kB)
import torch
import gradio as gr
from PIL import Image
import qrcode
from pathlib import Path
from multiprocessing import cpu_count
import requests
import io
import os
from PIL import Image
from diffusers import (
StableDiffusionPipeline,
StableDiffusionControlNetImg2ImgPipeline,
ControlNetModel,
DDIMScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
HeunDiscreteScheduler,
EulerDiscreteScheduler,
)
# Module-level QR encoder: starts at QR version 1 with the highest ("H")
# error-correction level, 10 pixels per box and a 4-box border.
# NOTE(review): `inference` constructs its own `qrcode.QRCode` on every call and
# does not use this instance in the visible code — confirm whether it is still needed.
qrcode_generator = qrcode.QRCode(
    border=4,
    box_size=10,
    error_correction=qrcode.constants.ERROR_CORRECT_H,
    version=1,
)
# Load the QR-code ControlNet and attach it to Stable Diffusion 1.5 in
# img2img mode; both models are loaded in float16 and the pipeline is moved to CUDA.
controlnet = ControlNetModel.from_pretrained(
    "DionTimmer/controlnet_qrcode-control_v1p_sd15",
    torch_dtype=torch.float16,
)
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    controlnet=controlnet,
    safety_checker=None,  # safety checker explicitly disabled
)
pipe = pipe.to("cuda")
# Memory-efficient attention; presumably requires the `xformers` package to be
# installed in the runtime — TODO confirm in the deployment requirements.
pipe.enable_xformers_memory_efficient_attention()
def resize_for_condition_image(
    input_image: Image.Image, resolution: int, multiple: int = 64
) -> Image.Image:
    """Resize an image so its shorter side is about ``resolution`` pixels.

    The image is converted to RGB and scaled uniformly so that its shorter side
    equals ``resolution``. Each side is then rounded to the nearest multiple of
    ``multiple`` (64 by default, the original behaviour) and the image is
    resampled with Lanczos.

    Args:
        input_image: Source image; it must provide ``convert``, ``size`` and
            ``resize`` like a PIL image.
        resolution: Target length, in pixels, of the shorter side before rounding.
        multiple: Each output dimension is rounded to a multiple of this value.

    Returns:
        A new RGB image whose width and height are positive multiples of
        ``multiple``.

    Raises:
        ValueError: If ``resolution`` or ``multiple`` is not positive, or if the
            input image has a zero-length side.
    """
    if resolution <= 0:
        raise ValueError(f"resolution must be positive, got {resolution}")
    if multiple <= 0:
        raise ValueError(f"multiple must be positive, got {multiple}")

    rgb_image = input_image.convert("RGB")
    width, height = rgb_image.size
    if min(width, height) == 0:
        raise ValueError("cannot resize an image with a zero-length side")
    scale = float(resolution) / min(height, width)

    def _snap(length: float) -> int:
        # Round to the nearest multiple, but never below one multiple: for a
        # small `resolution` plain rounding yields 0, which `resize` rejects.
        return max(multiple, int(round(length / multiple)) * multiple)

    target_size = (_snap(width * scale), _snap(height * scale))
    return rgb_image.resize(target_size, resample=Image.LANCZOS)
# Maps a human-readable sampler name to a factory that builds the matching
# diffusers scheduler from an existing scheduler config (e.g. `pipe.scheduler.config`).
# The Karras variants must pass `use_karras_sigmas=True`: the previously used
# `use_karras=True` is not a DPMSolverMultistepScheduler config key, so diffusers
# dropped it and the Karras sigma schedule was never actually applied.
SAMPLER_MAP = {
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
    ),
    "DPM++ Karras": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True
    ),
    "Heun": lambda config: HeunDiscreteScheduler.from_config(config),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
    "DDIM": lambda config: DDIMScheduler.from_config(config),
    "DEIS": lambda config: DEISMultistepScheduler.from_config(config),
}
def _mecard_escape(value: str, reserved: str = ";") -> str:
    """Backslash-escape MECARD reserved characters in a single field value."""
    # Escape the backslash itself first so the escapes added below are not doubled.
    escaped = value.replace("\\", "\\\\")
    for char in reserved:
        escaped = escaped.replace(char, "\\" + char)
    return escaped


def _build_mecard_content(
    first_name: str, last_name: str, telephone_number: str, email_address: str, url: str
) -> str:
    """Build the MECARD payload that is encoded into the QR code.

    Without special characters the output is exactly
    ``MECARD:N:<last>,<first>;TEL:<tel>;EMAIL:<email>;URL:<url>;``.
    ``;`` (field terminator) is escaped in every field and ``,`` (the name
    separator) additionally in the name parts; blank TEL/EMAIL/URL fields are omitted.
    """
    first = _mecard_escape((first_name or "").strip(), reserved=";,")
    last = _mecard_escape((last_name or "").strip(), reserved=";,")
    fields = [f"N:{last},{first}"]
    for tag, raw_value in (("TEL", telephone_number), ("EMAIL", email_address), ("URL", url)):
        value = (raw_value or "").strip()
        if value:
            fields.append(f"{tag}:{_mecard_escape(value)}")
    return "MECARD:" + "".join(f"{field};" for field in fields)


def _make_qr_condition_image(content: str, resolution: int = 768):
    """Encode `content` as a black-on-white QR image resized for the ControlNet."""
    qr = qrcode.QRCode(
        version=1,
        error_correction=qrcode.constants.ERROR_CORRECT_H,
        box_size=10,
        border=4,
    )
    qr.add_data(content)
    qr.make(fit=True)  # let qrcode pick a larger version if the data needs it
    qr_image = qr.make_image(fill_color="black", back_color="white")
    return resize_for_condition_image(qr_image, resolution)


def inference(
    first_name: str = "John",
    last_name: str = "Doe",
    telephone_number: str = "+60123456789",
    email_address: str = "[email protected]",
    url: str = "https://example.com",
    prompt: str = "Sky view of highly aesthetic, ancient greek thermal baths in beautiful nature",
    negative_prompt: str = "ugly, disfigured, low quality, blurry, nsfw",
    *,
    guidance_scale: float = 7.5,
    controlnet_conditioning_scale: float = 1.5,
    strength: float = 0.9,
    seed: int = -1,
    sampler: str = "DPM++ Karras SDE",
    num_inference_steps: int = 40,
):
    """Generate a QR-code artwork that encodes a MECARD contact card.

    The contact details are encoded into a QR code, which is used both as the
    img2img init image and as the ControlNet conditioning image for the
    module-level `pipe`.

    Args:
        first_name, last_name, telephone_number, email_address, url: Contact
            fields written into the MECARD payload.
        prompt: Text prompt for the diffusion model; must not be blank.
        negative_prompt: Concepts to steer away from.
        guidance_scale: Classifier-free guidance scale.
        controlnet_conditioning_scale: Weight of the QR ControlNet conditioning.
        strength: img2img strength.
        seed: Random seed; ``-1`` draws a fresh non-deterministic seed (printed
            so the result can be reproduced).
        sampler: Key of `SAMPLER_MAP` selecting the scheduler.
        num_inference_steps: Number of denoising steps.

    Returns:
        The first generated image from the pipeline output.

    Raises:
        gr.Error: If the prompt is blank, every contact field is blank, or the
            sampler name is unknown.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt is required")
    contact_fields = (first_name, last_name, telephone_number, email_address, url)
    if not any((field or "").strip() for field in contact_fields):
        raise gr.Error("Content is required")
    if sampler not in SAMPLER_MAP:
        raise gr.Error(f"Unknown sampler: {sampler}")

    qr_code_content = _build_mecard_content(
        first_name, last_name, telephone_number, email_address, url
    )

    # NOTE: this mutates the shared module-level pipeline's scheduler.
    pipe.scheduler = SAMPLER_MAP[sampler](pipe.scheduler.config)

    # A bare `torch.Generator()` always starts from the same default seed, so
    # "-1" must explicitly draw a random seed. A private generator (instead of
    # `torch.manual_seed`) also leaves the global torch RNG untouched.
    if seed == -1:
        generator = torch.Generator()
        seed = generator.seed()
    else:
        generator = torch.Generator().manual_seed(seed)
    print(f"Using seed {seed}")

    print("Generating QR Code from content")
    size = 768
    qrcode_image = _make_qr_condition_image(qr_code_content, size)

    out = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=qrcode_image,
        control_image=qrcode_image,  # type: ignore
        width=size,  # type: ignore
        height=size,  # type: ignore
        guidance_scale=float(guidance_scale),
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),  # type: ignore
        generator=generator,
        strength=float(strength),
        num_inference_steps=num_inference_steps,
    )
    return out.images[0]  # type: ignore
# Example payload format (N is "last,first"): MECARD:N:Doe,John;TEL:+60123456789;EMAIL:john.doe@example.com;URL:https://example.com;
# Gradio UI: one textbox per `inference` argument, listed in the same positional
# order as its parameters and pre-filled with the same default values.
_TEXTBOX_SPECS = (
    ("First Name", "John"),
    ("Last Name", "Doe"),
    ("Telephone Number", "+60123456789"),
    ("Email Address", "[email protected]"),
    ("URL", "https://example.com"),
    ("Prompt", "Sky view of highly aesthetic, ancient greek thermal baths in beautiful nature"),
    ("Negative Prompt", "ugly, disfigured, low quality, blurry, nsfw"),
)

generator = gr.Interface(
    fn=inference,
    inputs=[gr.Textbox(label=label, value=default) for label, default in _TEXTBOX_SPECS],
    outputs="image",
)
if __name__ == "__main__":
    # Process one request at a time and hold at most 20 waiting requests.
    # NOTE(review): `concurrency_count` is a Gradio 3.x argument (Gradio 4 uses
    # `default_concurrency_limit`) — confirm against the pinned gradio version.
    generator.queue(max_size=20, concurrency_count=1)
    generator.launch()