import gradio as gr
import requests
import io
import random
import os
import time
from PIL import Image
import json
from threading import RLock

# Project by Nymbo

# Base API URL for Hugging Face inference
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
# Retrieve the API token from environment variables
API_TOKEN = os.getenv("HF_READ_TOKEN")
headers = {"Authorization": f"Bearer {API_TOKEN}"}
# Timeout for API requests, in seconds
timeout = 100
lock = RLock()
# Function to query the Hugging Face API for image generation
def query(prompt, model, negative_prompt, steps, cfg_scale, sampler, seed, strength, width, height):
    # Debug log to indicate function start
    print("Starting query function...")
    # Print the parameters for debugging purposes
    print(f"Prompt: {prompt}")
    print(f"Model: {model}")
    print(f"Parameters - Steps: {steps}, CFG Scale: {cfg_scale}, Seed: {seed}, Strength: {strength}, Width: {width}, Height: {height}")

    # Check if the prompt is empty or None
    if prompt == "" or prompt is None:
        print("Prompt is empty or None. Exiting query function.")  # Debug log
        return None

    # Randomly select an API token from the available options to distribute the load
    API_TOKEN = random.choice([os.getenv("HF_READ_TOKEN"), os.getenv("HF_READ_TOKEN_2"), os.getenv("HF_READ_TOKEN_3"), os.getenv("HF_READ_TOKEN_4"), os.getenv("HF_READ_TOKEN_5")])
    headers = {"Authorization": f"Bearer {API_TOKEN}"}
    print(f"Selected API token: {API_TOKEN}")  # Debug log

    # Enhance the prompt with additional details for better quality
    prompt = f"{prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
    print(f'Generation: {prompt}')  # Debug log

    # Set the API URL based on the selected model; unmapped models fall back to the SDXL endpoint
    API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
    if model == 'Stable Diffusion XL':
        API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
    # Add more model-to-endpoint mappings here as needed
    print(f"API URL set to: {API_URL}")  # Debug log
    # Define the payload for the request
    payload = {
        "inputs": prompt,
        "negative_prompt": negative_prompt,
        "steps": steps,  # Number of sampling steps
        "cfg_scale": cfg_scale,  # Guidance scale controlling adherence to the prompt
        "seed": seed if seed != -1 else random.randint(1, 1000000000),  # Use the given seed, or pick a random one when seed is -1
        "strength": strength,  # How strongly the model should transform the image
        "parameters": {
            "width": width,  # Width of the generated image
            "height": height  # Height of the generated image
        }
    }
    print(f"Payload: {json.dumps(payload, indent=2)}")  # Debug log
    # Make a request to the API to generate the image
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        print(f"Response status code: {response.status_code}")  # Debug log
    except requests.exceptions.RequestException as e:
        # Log any request exceptions and raise an error for the user
        print(f"Request failed: {e}")  # Debug log
        raise gr.Error(f"Request failed: {e}")

    # Check if the response status is not successful
    if response.status_code != 200:
        print(f"Error: Failed to retrieve image. Response status: {response.status_code}")  # Debug log
        print(f"Response content: {response.text}")  # Debug log
        if response.status_code == 400:
            raise gr.Error(f"{response.status_code}: Bad Request - There might be an issue with the input parameters.")
        elif response.status_code == 401:
            raise gr.Error(f"{response.status_code}: Unauthorized - Please check your API token.")
        elif response.status_code == 403:
            raise gr.Error(f"{response.status_code}: Forbidden - You do not have permission to access this model.")
        elif response.status_code == 404:
            raise gr.Error(f"{response.status_code}: Not Found - The requested model could not be found.")
        elif response.status_code == 503:
            raise gr.Error(f"{response.status_code}: The model is being loaded. Please try again later.")
        else:
            raise gr.Error(f"{response.status_code}: An unexpected error occurred.")

    try:
        # Attempt to read the image from the response content
        image_bytes = response.content
        image = Image.open(io.BytesIO(image_bytes))
        print(f'Generation completed! ({prompt})')  # Debug log
        return image
    except Exception as e:
        # Handle any errors that occur when opening the image
        print(f"Error while trying to open image: {e}")  # Debug log
        return None
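# Example of calling query() directly for a quick smoke test (hypothetical values; assumes at
# least HF_READ_TOKEN is set in the environment):
#
#   image = query("a lighthouse at dusk, oil painting", "Stable Diffusion XL", "blurry, low quality",
#                 35, 7, "DPM++ 2M Karras", -1, 0.7, 1024, 1024)
#   if image is not None:
#       image.save("sample.png")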
# Custom CSS to hide the footer in the interface
css = """
footer {visibility: hidden !important;}
"""

print("Initializing Gradio interface...")  # Debug log
# Define the Gradio interface
with gr.Blocks(theme='Nymbo/Nymbo_Theme', css=css) as demo:
    # Tab for basic settings
    with gr.Tab('Basic Settings'):
        txt_input = gr.Textbox(label='Your prompt:', lines=4)
        model = gr.Radio(label="Select a model", value="Stable Diffusion XL", choices=["Stable Diffusion XL", "Stable Diffusion 3", "FLUX.1 [Schnell]", "RealVisXL v4.0", "Duchaiten Real3D NSFW XL", "Tempest v0.1"], interactive=True)
        gen_button = gr.Button('Generate Image')

    # Tab for advanced settings
    with gr.Tab("Advanced Settings"):
        with gr.Row():
            # Textbox for specifying elements to exclude from the image
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="What should not be in the image", value="(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, misspellings, typos", lines=3, elem_id="negative-prompt-text-input")
        with gr.Row():
            # Slider for selecting the image width
            width = gr.Slider(label="Width", value=1024, minimum=64, maximum=1216, step=32)
            # Slider for selecting the image height
            height = gr.Slider(label="Height", value=1024, minimum=64, maximum=1216, step=32)
        with gr.Row():
            # Slider for setting the number of sampling steps
            steps = gr.Slider(label="Sampling steps", value=35, minimum=1, maximum=100, step=1)
        with gr.Row():
            # Slider for adjusting the CFG scale (guidance scale)
            cfg = gr.Slider(label="CFG Scale", value=7, minimum=1, maximum=20, step=1)
        with gr.Row():
            # Slider for adjusting the transformation strength
            strength = gr.Slider(label="Strength", value=0.7, minimum=0, maximum=1, step=0.001)
        with gr.Row():
            # Slider for setting the seed for reproducibility
            seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=1000000000, step=1)
        with gr.Row():
            # Radio buttons for selecting the sampling method
            method = gr.Radio(label="Sampling method", value="DPM++ 2M Karras", choices=["DPM++ 2M Karras", "DPM++ SDE Karras", "Euler", "Euler a", "Heun", "DDIM"])

    # Set up the button click event to call the query function
    gen_button.click(query, inputs=[txt_input, model, negative_prompt, steps, cfg, method, seed, strength, width, height], outputs=gr.Image(type="pil", label="Generated Image"))
print("Launching Gradio interface...") # Debug log | |
# Launch the Gradio interface without showing the API or sharing externally | |
demo.launch(show_api=False, max_threads=400) |