Spaces: Running on Zero
# utils/ai_generator.py | |
import gradio as gr | |
import os | |
import time | |
#from turtle import width # Added for implementing delays | |
from torch import cuda | |
import random | |
from utils.ai_generator_diffusers_flux import generate_ai_image_local | |
#from pathlib import Path | |
from huggingface_hub import InferenceClient | |
import requests | |
import io | |
from PIL import Image | |
from tempfile import NamedTemporaryFile | |
import utils.constants as constants | |
def generate_image_from_text(text, model_name="flax-community/dalle-mini", image_width=768, image_height=512, progress=gr.Progress(track_tqdm=True)):
    """Generate an image for ``text`` with the Hugging Face Inference API and resize it.

    Args:
        text: Prompt describing the desired image.
        model_name: Hub model id used for text-to-image inference.
        image_width: Width in pixels of the returned image.
        image_height: Height in pixels of the returned image.
        progress: Gradio progress tracker; not used by this function.

    Returns:
        A ``PIL.Image.Image`` resized to ``(image_width, image_height)``.

    Raises:
        Any error raised by ``InferenceClient`` propagates to the caller.
    """
    client = InferenceClient()
    # InferenceClient is not callable (unlike the legacy InferenceApi), and
    # text_to_image already decodes the response into a PIL image, so there
    # are no raw response bytes to read.
    image = client.text_to_image(text, model=model_name)
    return image.resize((image_width, image_height))
def generate_ai_image(
    map_option,
    prompt_textbox_value,
    neg_prompt_textbox_value,
    model,
    lora_weights=None,
    conditioned_image=None,
    pipeline="FluxPipeline",
    width=912,
    height=512,
    strength=0.5,
    seed=0,
    progress=gr.Progress(track_tqdm=True),
    *args,
    **kwargs
):
    """Generate an image on the local GPU when CUDA is available, otherwise remotely.

    Args:
        map_option: Prompt preset selector, forwarded unchanged to the backend.
        prompt_textbox_value: User prompt text, forwarded unchanged.
        neg_prompt_textbox_value: User negative-prompt text, forwarded unchanged.
        model: Model id forwarded to the backend.
        lora_weights: Forwarded to local generation only.
        conditioned_image: Optional input image, used by local generation only.
            When given, ``pipeline`` is switched to ``"FluxImg2ImgPipeline"``.
        pipeline: Pipeline name for local generation.
        width: Requested output width in pixels.
        height: Requested output height in pixels.
        strength: Forwarded to local generation only (presumably img2img
            strength — confirm in generate_ai_image_local).
        seed: ``0`` or ``None`` selects a random seed in ``[0, constants.MAX_SEED]``.
        progress: Gradio progress tracker; not used directly here.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        The result of ``generate_ai_image_local`` (GPU path) or
        ``generate_ai_image_remote`` (CPU path).
    """
    # A missing seed is treated like 0 so the backends never receive None.
    if seed is None or seed == 0:
        seed = random.randint(0, constants.MAX_SEED)

    if cuda.is_available() and cuda.device_count() >= 1:  # Check if a local GPU is available
        print("Local GPU available. Generating image locally.")
        if conditioned_image is not None:
            pipeline = "FluxImg2ImgPipeline"
        return generate_ai_image_local(
            map_option,
            prompt_textbox_value,
            neg_prompt_textbox_value,
            model,
            lora_weights=lora_weights,
            seed=seed,
            conditioned_image=conditioned_image,
            pipeline_name=pipeline,
            strength=strength,
            height=height,
            width=width,
        )

    print("No local GPU available. Sending request to Hugging Face API.")
    if conditioned_image is not None or lora_weights:
        # The remote call below is text-to-image only and receives none of these
        # arguments; say so instead of silently returning an unconditioned image.
        print("Warning: conditioned_image, lora_weights and strength are ignored for remote generation.")
    return generate_ai_image_remote(
        map_option,
        prompt_textbox_value,
        neg_prompt_textbox_value,
        model,
        height=height,
        width=width,
        seed=seed,
    )
def _resolve_remote_prompts(map_option, prompt_textbox_value, neg_prompt_textbox_value):
    """Return ``(prompt, negative_prompt_terms)`` for the selected map option.

    For any option other than ``"Prompt"`` both values come from
    ``constants.PROMPTS`` / ``constants.NEGATIVE_PROMPTS``; for ``"Prompt"``
    they come from the textboxes. The negative prompt is a comma-separated
    string split into a list of stripped, non-empty terms.
    """
    if map_option != "Prompt":
        prompt = constants.PROMPTS[map_option]
        negative_prompt_str = constants.NEGATIVE_PROMPTS.get(map_option, "") or ""
    else:
        prompt = prompt_textbox_value
        negative_prompt_str = neg_prompt_textbox_value or ""
    negative_prompt = [p.strip() for p in negative_prompt_str.split(",") if p.strip()]
    return prompt, negative_prompt


def generate_ai_image_remote(
    map_option,
    prompt_textbox_value,
    neg_prompt_textbox_value,
    model,
    height=512,
    width=912,
    num_inference_steps=30,
    guidance_scale=3.5,
    seed=777,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image through the Hugging Face Inference API and save it as a temp PNG.

    If the ``IS_SHARED_SPACE`` environment variable equals ``"True"``, the request
    goes through ``InferenceClient.text_to_image``. Otherwise a raw HTTP POST is
    sent to ``https://api-inference.huggingface.co/models/{model}``.

    Up to three attempts are made, and the delay between attempts starts at
    4 seconds and doubles after each one. These failures are retried:
    timeouts, other ``requests`` exceptions, HTTP 429, HTTP 5xx and other
    non-4xx, non-200 statuses. Other HTTP 4xx responses fail immediately.

    Args:
        map_option: Key into ``constants.PROMPTS`` / ``constants.NEGATIVE_PROMPTS``,
            or ``"Prompt"`` to use the textbox values.
        prompt_textbox_value: Prompt used when ``map_option == "Prompt"``.
        neg_prompt_textbox_value: Comma-separated negative terms used when
            ``map_option == "Prompt"``.
        model: Hub model id.
        height: Requested output height in pixels.
        width: Requested output width in pixels.
        num_inference_steps: Inference step count forwarded to the API.
        guidance_scale: Guidance scale forwarded to the API.
        seed: Seed forwarded to the API.
        progress: Gradio progress tracker; not used by this function.

    Returns:
        Path of the saved PNG (also appended to ``constants.temp_files``), or
        ``None`` if generation failed for any reason. Errors are printed, not raised.
    """
    max_retries = 3
    retry_delay = 4  # Initial delay in seconds; doubled after each failed attempt
    try:
        prompt, negative_prompt = _resolve_remote_prompts(
            map_option, prompt_textbox_value, neg_prompt_textbox_value
        )
        print("Remotely Generating image with the following parameters:")
        print(f"Prompt: {prompt}")
        print(f"Negative Prompt: {negative_prompt}")
        print(f"Height: {height}")
        print(f"Width: {width}")
        print(f"Number of Inference Steps: {num_inference_steps}")
        print(f"Guidance Scale: {guidance_scale}")
        print(f"Seed: {seed}")
        # NOTE(review): negative_prompt is logged but sent by neither request
        # below — confirm the target models accept it before forwarding it.

        # Loop-invariant request data, built once.
        use_client = os.getenv("IS_SHARED_SPACE") == "True"
        api_url = f"https://api-inference.huggingface.co/models/{model}"
        headers = {
            "Authorization": f"Bearer {constants.HF_API_TOKEN}",
            "Content-Type": "application/json",
        }
        payload = {
            "inputs": prompt,
            "parameters": {
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "width": width,
                "height": height,
                "max_sequence_length": 512,
                # Optional: Add 'scheduler' if needed
                "seed": seed,
            },
        }

        image = None
        for attempt in range(1, max_retries + 1):
            try:
                if use_client:
                    client = InferenceClient(model, token=constants.HF_API_TOKEN)
                    # text_to_image takes the prompt positionally and generation
                    # options as keyword arguments (there are no `inputs` /
                    # `parameters` kwargs) and returns a decoded PIL image.
                    # max_sequence_length is only sent on the raw HTTP path.
                    image = client.text_to_image(
                        prompt,
                        guidance_scale=guidance_scale,
                        num_inference_steps=num_inference_steps,
                        width=width,
                        height=height,
                        seed=seed,
                    )
                    break  # Success; without this the for/else below discards the image

                print(f"Attempt {attempt}: Sending POST request to Hugging Face API...")
                response = requests.post(api_url, headers=headers, json=payload, timeout=300)  # 300-second timeout
                status = response.status_code
                if status == 200:
                    image = Image.open(io.BytesIO(response.content))
                    break  # Exit the retry loop on success
                if status == 400:
                    # Handle 400 Bad Request specifically
                    print(f"Bad Request (400): {response.text}")
                    print("Check your request parameters and payload format.")
                    return None  # Do not retry on 400 errors
                if 400 <= status < 500 and status != 429:
                    # Other client errors (auth, unknown model, ...) cannot succeed on retry.
                    print(f"Received unexpected status code {status}: {response.text}")
                    return None
                # 429, 5xx and anything else unexpected: back off and retry.
                print(f"Received status code {status}. Retrying in {retry_delay} seconds...")
            except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectTimeout) as timeout_error:
                print(f"Timeout occurred: {timeout_error}. Retrying in {retry_delay} seconds...")
            except requests.exceptions.RequestException as req_error:
                print(f"Request exception: {req_error}. Retrying in {retry_delay} seconds...")
            # Only reached when this attempt failed in a retryable way.
            if attempt < max_retries:
                time.sleep(retry_delay)
                retry_delay *= 2  # Exponential backoff
        else:
            # The loop ended without a `break`: every attempt failed.
            print("Max retries exceeded. Failed to generate image.")
            return None

        # Write through the open handle instead of reopening tmp.name: reopening
        # a NamedTemporaryFile by name while it is still open is not portable.
        with NamedTemporaryFile(delete=False, suffix=".png") as tmp:
            image.save(tmp, format="PNG")
        constants.temp_files.append(tmp.name)
        print(f"Image saved to {tmp.name}")
        return tmp.name
    except Exception as e:
        print(f"Error generating AI image: {e}")
        return None