|
import io
import warnings

import gradio as gr
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
from PIL import Image
from stability_sdk import client
|
|
|
|
|
# App-wide look: Gradio's Monochrome base theme with indigo primary and blue
# secondary accents over slate neutrals, small corner radii, and the
# "Open Sans" Google font followed by generic sans-serif fallbacks.
theme = gr.themes.Monochrome(
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)
|
|
|
def infer(prompt, api_key):
    """Generate one image for *prompt* with Stability AI's SDXL beta engine.

    Args:
        prompt: Text prompt sent to the generation API.
        api_key: The user's Stability AI API key, as typed into the UI.
            Surrounding whitespace is stripped before use.

    Returns:
        A ``PIL.Image.Image`` decoded from the last image artifact in the
        API response (the request asks for a single sample).

    Raises:
        gr.Error: If the API key or prompt is empty, or if the response
            contained no image artifact. Gradio shows this message in the UI.
    """
    # Fail fast with a readable message instead of sending a doomed request.
    if not api_key or not api_key.strip():
        raise gr.Error("Please enter your StabilityAI API key.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    stability_api = client.StabilityInference(
        key=api_key.strip(),
        verbose=True,
        engine="stable-diffusion-xl-beta-v2-2-2",
    )

    answers = stability_api.generate(
        prompt=prompt,
        # Hard-coded seed; presumably chosen so a given prompt reproduces
        # the same image across runs — confirm this is intended.
        seed=992446758,
        steps=30,
        cfg_scale=8.0,
        width=512,
        height=512,
        samples=1,
        sampler=generation.SAMPLER_K_DPMPP_2M,
    )

    img = None
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                # NOTE(review): this only reaches the server log, not the
                # browser; consider gr.Error/gr.Warning if users should see it.
                warnings.warn(
                    "Your request activated the API's safety filters and could not be processed. "
                    "Please modify the prompt and try again.")
            if artifact.type == generation.ARTIFACT_IMAGE:
                img = Image.open(io.BytesIO(artifact.binary))

    # Previously an empty/imageless response raised UnboundLocalError here.
    if img is None:
        raise gr.Error("The API response did not contain an image.")
    return img
|
|
|
def _infer_for_gallery(prompt, api_key):
    """Adapt ``infer``'s single image to the list value ``gr.Gallery`` expects.

    ``gr.Gallery`` iterates over its output value, so a bare PIL image
    (which is not iterable) would make the click handler fail.
    """
    return [infer(prompt, api_key)]


with gr.Blocks(theme=theme) as demo:
    gr.Markdown("# Stable Diffusion XL")

    api_key_input = gr.Textbox(type="password", label="Enter your StabilityAI API key here")

    # The prompt box and button use complementary border/rounded tuples
    # (square inner edges, rounded outer edges), which only make sense when
    # the two sit next to each other, so place them in one row.
    with gr.Row():
        text = gr.Textbox(
            label="Enter your prompt",
            show_label=True,
            max_lines=1,
            placeholder="Enter your prompt",
            elem_id="prompt-text-input",
        ).style(
            border=(True, False, True, True),
            rounded=(True, False, False, True),
            container=False,
        )
        btn = gr.Button("Generate image").style(
            margin=False,
            rounded=(False, True, True, False),
            full_width=False,
        )

    gallery = gr.Gallery(
        label="Generated images", show_label=False, elem_id="gallery"
    ).style(grid=[2], height="auto")

    btn.click(_infer_for_gallery, inputs=[text, api_key_input], outputs=[gallery])

demo.launch()