import gradio as gr
import os
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from PIL import Image
from styles import css, header_html, footer_html
from examples import examples
from transformers import pipeline

# Speech recognition pipeline (default model) used to transcribe the spoken prompt
asr_model = pipeline("automatic-speech-recognition")

model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda" if torch.cuda.is_available() else "cpu"

# If you are running this code locally, you need to either run `huggingface-cli login` or
# pass a User Access Token from https://huggingface.co/settings/tokens to the
# use_auth_token argument below (here it is read from the `auth_token` environment variable).
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, use_auth_token=os.environ.get('auth_token'), revision="fp16", torch_dtype=torch.float16)
pipe = pipe.to(device)
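
# Note: the "fp16" revision loaded above is intended for GPU inference; on a CPU-only
# machine the full-precision (float32) weights are normally used instead.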


def transcribe(audio):
    """Transcribe the recorded audio file to text with the ASR pipeline."""
    text = asr_model(audio)["text"]
    return text


def infer(audio, samples, steps, scale, seed):
    # Transcribe the spoken prompt, then generate `samples` images from it
    prompt = transcribe(audio)
    generator = torch.Generator(device=device).manual_seed(seed)
    # The `with autocast("cuda")` context is only used on GPU; on CPU the pipeline is called directly
    if device == "cuda":
        with autocast("cuda"):
            images_list = pipe(
                [prompt] * samples,
                num_inference_steps=steps,
                guidance_scale=scale,
                generator=generator,
            )
    else:
        images_list = pipe(
            [prompt] * samples,
            num_inference_steps=steps,
            guidance_scale=scale,
            generator=generator,
        )

    # Replace any image flagged by the safety checker with a placeholder image
    images = []
    safe_image = Image.open(r"unsafe.png")
    for i, image in enumerate(images_list["sample"]):
        if images_list["nsfw_content_detected"][i]:
            images.append(safe_image)
        else:
            images.append(image)
    return images


block = gr.Blocks(css=css)

with block:
    gr.HTML(header_html)
    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                audio = gr.Audio(
                    label="Describe a prompt",
                    source="microphone",
                    type="filepath"
                    # ).style(
                    #     border=(True, False, True, True),
                    #     rounded=(True, False, False, True),
                    #     container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")

        advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")

        with gr.Row(elem_id="advanced-options"):
            samples = gr.Slider(label="Images", minimum=1,
                                maximum=4, value=4, step=1)
            steps = gr.Slider(label="Steps", minimum=1,
                              maximum=50, value=45, step=1)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=2147483647,
                step=1,
                randomize=True,
            )

        # ex = gr.Examples(fn=infer, inputs=[
        #     audio, samples, steps, scale, seed], outputs=gallery)
        # ex.dataset.headers = [""]

        # audio.submit(infer, inputs=[audio, samples,
        #              steps, scale, seed], outputs=gallery)
        btn.click(infer, inputs=[audio, samples, steps,
                                 scale, seed], outputs=gallery)

        # Toggle the visibility of the advanced options row in the browser
        advanced_button.click(
            None,
            [],
            audio,
            _js="""
            () => {
                const options = document.querySelector("body > gradio-app").querySelector("#advanced-options");
                options.style.display = ["none", ""].includes(options.style.display) ? "flex" : "none";
            }""",
        )

    gr.HTML(footer_html)

block.queue(max_size=25).launch()