StableDiffusion / app.py
Pranav Pandey
Add application file
22a3fc5
raw
history blame
2.93 kB
import streamlit as st
import time
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from datasets import load_dataset
from PIL import Image
import re
# --- Page header and prompt input -------------------------------------------
st.title("Text-to-Image generation using Stable Diffusion")
st.subheader("Text Prompt")
# Free-form prompt text; check_and_infer() refuses prompts shorter than
# 5 characters.
text_prompt = st.text_area('Enter here:', height=100)
# Four side-by-side columns, one per generation parameter.
sl1, sl2, sl3, sl4 = st.columns(4)
# Slider positional arguments below are presumably (label, min, max, default)
# -- confirm against the Streamlit version in use.
num_samples = sl1.slider('Number of Images', 1, 4, 1)
num_steps = sl2.slider('Diffusion steps', 10, 150, 10)
# Forwarded to the pipeline as guidance_scale inside infer().
scale = sl3.slider('Configuration scale', 0, 20, 10)
# Presumably (label, min, max, default, step) -- confirm. The value is used
# to seed the torch generator inside infer().
seed = sl4.number_input("Enter seed", 0, 25000, 47, 1)
# Identifier of the pretrained pipeline to load.
model_id = "CompVis/stable-diffusion-v1-4"
# Device the pipeline is moved to and the torch generator is created on.
# infer() additionally hard-codes autocast("cuda").
device = "cuda"
# Load the "fp16" revision of the pipeline with float16 weights.
# use_auth_token=True presumably relies on a stored Hugging Face token --
# TODO confirm how credentials are provided in deployment.
# NOTE(review): this runs at module top level, i.e. on every execution of the
# script. If the script is re-executed per user interaction, the pipeline is
# reloaded each time -- consider caching (verify against the Streamlit
# version in use).
pipe = StableDiffusionPipeline.from_pretrained(
model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16)
pipe = pipe.to(device)
# Blocked-word list that infer() checks prompts against before generating.
word_list_dataset = load_dataset(
"stabilityai/word-list", data_files="list.txt", use_auth_token=True)
# The 'text' column of the 'train' split; infer() iterates it treating each
# entry as one blocked word.
word_list = word_list_dataset["train"]['text']
def infer(prompt, samples, steps, scale, seed):
    """Generate ``samples`` images for ``prompt`` with the module-level pipeline.

    Args:
        prompt: Text prompt. Rejected when it contains any entry of the
            module-level ``word_list`` as a whole word.
        samples: Number of images to generate; the prompt is repeated this
            many times in a single pipeline call.
        steps: Forwarded to the pipeline as ``num_inference_steps``.
        scale: Forwarded to the pipeline as ``guidance_scale``.
        seed: Seed applied to the torch generator handed to the pipeline.

    Returns:
        A list with one entry per generated sample: the image itself, or the
        ``unsafe.png`` placeholder wherever the pipeline reported
        ``nsfw_content_detected`` for that sample.

    Raises:
        Exception: If the prompt contains a blocked word.
    """
    for blocked_word in word_list:
        # A blank entry would compile to r"\b\b", which matches any prompt
        # containing a word character and would reject everything.
        if not blocked_word.strip():
            continue
        # Escape the entry so regex metacharacters in a list word are matched
        # literally instead of being interpreted as pattern syntax (or
        # raising re.error on an invalid pattern).
        # NOTE(review): matching is case-sensitive, as in the original.
        if re.search(rf"\b{re.escape(blocked_word)}\b", prompt):
            raise Exception(
                "Unsafe content found. Please try again with different prompts.")

    generator = torch.Generator(device=device).manual_seed(seed)
    with autocast("cuda"):
        images_list = pipe(
            [prompt] * samples,
            num_inference_steps=steps,
            guidance_scale=scale,
            generator=generator,
        )

    # Placeholder shown in place of any sample the pipeline flagged.
    safe_image = Image.open(r"unsafe.png")
    return [
        safe_image if flagged else image
        for image, flagged in zip(
            images_list["sample"], images_list["nsfw_content_detected"])
    ]
def check_and_infer():
    """Validate the prompt, then generate and display the images.

    Reads the module-level widget values (``text_prompt``, ``num_samples``,
    ``num_steps``, ``scale``, ``seed``). A prompt shorter than 5 characters
    produces a hint and a rerun request instead of a generation; otherwise
    every image returned by ``infer`` is rendered with the prompt as caption,
    followed by a success banner and balloons.
    """
    prompt_is_long_enough = len(text_prompt) >= 5
    if not prompt_is_long_enough:
        st.write("Prompt too small, enter some more detail")
        # NOTE(review): a rerun right after writing the hint may discard the
        # hint before the user sees it -- confirm the intended behaviour.
        st.experimental_rerun()
        return

    # Keep the spinner visible while the pipeline is running.
    with st.spinner('Wait for it...'):
        results = infer(text_prompt, num_samples, num_steps, scale, seed)

    for picture in results:
        st.image(picture, caption=text_prompt)
    st.success('Image generated!')
    st.balloons()
# Generation trigger: check_and_infer is passed as the button's on_click
# callback. The button's return value is kept in button_clicked, which is not
# read anywhere in the visible code.
button_clicked = st.button(
"Generate Image", on_click=check_and_infer, disabled=False)
# Horizontal rule separating the controls from the preview area.
st.markdown("""---""")

# 1:6:1 column split: the two narrow outer columns act as empty gutters so
# the placeholder image sits in the wide middle column.
col1, col2, col3 = st.columns([1, 6, 1])
with col1:
    col1.write("")
with col2:
    # Single-element container holding the placeholder picture.
    placeholder = col2.empty()
    placeholder.image("pl2.png")
with col3:
    # The right-hand spacer belongs to col3 itself, not to col1.
    col3.write("")

# Horizontal rule separating the preview area from the help text.
st.markdown("""---""")
# Plain-text legend describing each control; one st.text call per entry, in
# the same order as the controls appear.
for _legend_line in (
    "Number of Images: Number of samples(Images) to generate",
    "Diffusion steps: How many steps to spend generating (diffusing) your image.",
    "Configuration scale: Scale adjusts how close the image will be to your prompt. Higher values keep your image closer to your prompt.",
    "Enter seed: Seed value to use for the model.",
):
    st.text(_legend_line)