# Compare-6 / app.py
# Author: Nymbo — last commit "Update app.py" (509cb06, verified)
# File size: 7.64 kB
import gradio as gr
import requests
import io
import random
import os
import time
from PIL import Image
import json
from threading import RLock
# Project by Nymbo

# Default Hugging Face Inference API endpoint (Stable Diffusion XL base 1.0).
API_URL = (
    "https://api-inference.huggingface.co/models/"
    "stabilityai/stable-diffusion-xl-base-1.0"
)

# Primary read token taken from the environment; None when HF_READ_TOKEN is unset.
API_TOKEN = os.environ.get("HF_READ_TOKEN")

# Authorization header derived from the primary token.
headers = {"Authorization": "Bearer %s" % (API_TOKEN,)}

# Timeout passed to the HTTP requests made against the inference API.
timeout = 100

# Re-entrant lock. NOTE(review): not referenced anywhere in the visible code.
lock = RLock()
# Function to query the Hugging Face API for image generation
def query(prompt, model, negative_prompt, steps, cfg_scale, sampler, seed, strength, width, height):
# Debug log to indicate function start
print("Starting query function...")
# Print the parameters for debugging purposes
print(f"Prompt: {prompt}")
print(f"Model: {model}")
print(f"Parameters - Steps: {steps}, CFG Scale: {cfg_scale}, Seed: {seed}, Strength: {strength}, Width: {width}, Height: {height}")
# Check if the prompt is empty or None
if prompt == "" or prompt is None:
print("Prompt is empty or None. Exiting query function.") # Debug log
return None
# Randomly select an API token from available options to distribute the load
API_TOKEN = random.choice([os.getenv("HF_READ_TOKEN"), os.getenv("HF_READ_TOKEN_2"), os.getenv("HF_READ_TOKEN_3"), os.getenv("HF_READ_TOKEN_4"), os.getenv("HF_READ_TOKEN_5")])
headers = {"Authorization": f"Bearer {API_TOKEN}"}
print(f"Selected API token: {API_TOKEN}") # Debug log
# Enhance the prompt with additional details for better quality
prompt = f"{prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
print(f'Generation: {prompt}') # Debug log
# Set the API URL based on the selected model
if model == 'Stable Diffusion XL':
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
# Add more model options as needed
print(f"API URL set to: {API_URL}") # Debug log
# Define the payload for the request
payload = {
"inputs": prompt,
"negative_prompt": negative_prompt,
"steps": steps, # Number of sampling steps
"cfg_scale": cfg_scale, # Scale for controlling adherence to prompt
"seed": seed if seed != -1 else random.randint(1, 1000000000), # Random seed for reproducibility
"strength": strength, # How strongly the model should transform the image
"parameters": {
"width": width, # Width of the generated image
"height": height # Height of the generated image
}
}
print(f"Payload: {json.dumps(payload, indent=2)}") # Debug log
# Make a request to the API to generate the image
try:
response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
print(f"Response status code: {response.status_code}") # Debug log
except requests.exceptions.RequestException as e:
# Log any request exceptions and raise an error for the user
print(f"Request failed: {e}") # Debug log
raise gr.Error(f"Request failed: {e}")
# Check if the response status is not successful
if response.status_code != 200:
print(f"Error: Failed to retrieve image. Response status: {response.status_code}") # Debug log
print(f"Response content: {response.text}") # Debug log
if response.status_code == 400:
raise gr.Error(f"{response.status_code}: Bad Request - There might be an issue with the input parameters.")
elif response.status_code == 401:
raise gr.Error(f"{response.status_code}: Unauthorized - Please check your API token.")
elif response.status_code == 403:
raise gr.Error(f"{response.status_code}: Forbidden - You do not have permission to access this model.")
elif response.status_code == 404:
raise gr.Error(f"{response.status_code}: Not Found - The requested model could not be found.")
elif response.status_code == 503:
raise gr.Error(f"{response.status_code}: The model is being loaded. Please try again later.")
else:
raise gr.Error(f"{response.status_code}: An unexpected error occurred.")
try:
# Attempt to read the image from the response content
image_bytes = response.content
image = Image.open(io.BytesIO(image_bytes))
print(f'Generation completed! ({prompt})') # Debug log
return image
except Exception as e:
# Handle any errors that occur when opening the image
print(f"Error while trying to open image: {e}") # Debug log
return None
# Stylesheet meant to hide the page footer. It only takes effect when passed
# to gr.Blocks(css=...). The leading "* {}" rule is empty and has no effect.
css = (
    "\n"
    "* {}\n"
    "footer {visibility: hidden !important;}\n"
)
print("Initializing Gradio interface...")  # Debug log
# Define the Gradio interface. `css` is passed here; without it the
# footer-hiding stylesheet defined above is never applied.
with gr.Blocks(theme='Nymbo/Nymbo_Theme', css=css) as demo:
    # Tab for basic settings
    with gr.Tab('Basic Settings'):
        txt_input = gr.Textbox(label='Your prompt:', lines=4)
        # NOTE(review): query() only assigns an endpoint for "Stable Diffusion XL";
        # the other choices are not handled there yet.
        model = gr.Radio(label="Select a model", value="Stable Diffusion XL", choices=["Stable Diffusion XL", "Stable Diffusion 3", "FLUX.1 [Schnell]", "RealVisXL v4.0", "Duchaiten Real3D NSFW XL", "Tempest v0.1"], interactive=True)
        gen_button = gr.Button('Generate Image')
    # Tab for advanced settings
    with gr.Tab("Advanced Settings"):
        with gr.Row():
            # Textbox for specifying elements to exclude from the image
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="What should not be in the image", value="(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, misspellings, typos", lines=3, elem_id="negative-prompt-text-input")
        with gr.Row():
            # Slider for selecting the image width
            width = gr.Slider(label="Width", value=1024, minimum=64, maximum=1216, step=32)
            # Slider for selecting the image height
            height = gr.Slider(label="Height", value=1024, minimum=64, maximum=1216, step=32)
        with gr.Row():
            # Slider for setting the number of sampling steps
            steps = gr.Slider(label="Sampling steps", value=35, minimum=1, maximum=100, step=1)
        with gr.Row():
            # Slider for adjusting the CFG scale (guidance scale)
            cfg = gr.Slider(label="CFG Scale", value=7, minimum=1, maximum=20, step=1)
        with gr.Row():
            # Slider for adjusting the transformation strength
            strength = gr.Slider(label="Strength", value=0.7, minimum=0, maximum=1, step=0.001)
        with gr.Row():
            # Slider for setting the seed; -1 asks query() for a random seed
            seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=1000000000, step=1)
        with gr.Row():
            # Radio buttons for selecting the sampling method
            method = gr.Radio(label="Sampling method", value="DPM++ 2M Karras", choices=["DPM++ 2M Karras", "DPM++ SDE Karras", "Euler", "Euler a", "Heun", "DDIM"])
    # Output component that receives the PIL image returned by query()
    image_output = gr.Image(type="pil", label="Generated Image")
    # Set up button click event to call the query function; input order must
    # match query()'s positional parameters.
    gen_button.click(query, inputs=[txt_input, model, negative_prompt, steps, cfg, method, seed, strength, width, height], outputs=image_output)
# Only start the server when run as a script, so importing this module
# (e.g. for tests or tooling) does not launch the app as a side effect.
if __name__ == "__main__":
    print("Launching Gradio interface...")  # Debug log
    # Launch the Gradio interface without showing the API page
    demo.launch(show_api=False, max_threads=400)