Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,073 Bytes
e04ce6b 1dcd95e e04ce6b 1dcd95e e0358af 1dcd95e e04ce6b 1dcd95e e04ce6b 1dcd95e e04ce6b 1dcd95e e0358af e04ce6b 1dcd95e e04ce6b da3da8b 94bd8c8 e04ce6b 4167141 1dcd95e b604e8c 1dcd95e 6b8f8c9 1dcd95e 3ed9f62 1dcd95e b604e8c 1dcd95e b604e8c 1dcd95e b604e8c 1dcd95e 6b8f8c9 1dcd95e 0000d2f 1dcd95e e04ce6b 1dcd95e e04ce6b 1dcd95e e0358af 1dcd95e b604e8c 1dcd95e 11c664b 1dcd95e e04ce6b 1dcd95e 7af58fc 94bd8c8 1dcd95e e04ce6b 7af58fc 94bd8c8 da3da8b 94bd8c8 da3da8b 94bd8c8 da3da8b 94bd8c8 da3da8b 1dcd95e 6b8f8c9 1dcd95e 94bd8c8 1dcd95e 94bd8c8 1dcd95e 94bd8c8 1dcd95e e04ce6b 1dcd95e e04ce6b 1dcd95e e04ce6b 94bd8c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, pipeline
from diffusers import DiffusionPipeline
import random
import numpy as np
import os
from qwen_vl_utils import process_vision_info
# ---------------------------------------------------------------------------
# Model initialisation: everything below runs once, at import time.
# ---------------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# Weight dtype used for the FLUX text-to-image pipeline.
dtype = torch.bfloat16
# Optional access token read from the environment and passed to the FLUX download.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Hub repository IDs. Keeping each in one place ensures that a model and its
# processor are always loaded from the same repository.
FLUX_MODEL_ID = "black-forest-labs/FLUX.1-dev"
CAPTIONER_MODEL_ID = "prithivMLmods/JSONify-Flux"
ENHANCER_MODEL_ID = "prithivMLmods/t5-Flan-Prompt-Enhance"

# FLUX.1-dev text-to-image pipeline.
pipe = DiffusionPipeline.from_pretrained(
    FLUX_MODEL_ID, torch_dtype=dtype, token=huggingface_token
).to(device)

# Qwen2-VL based image captioner, loaded in float16 (the same dtype that
# qwen_caption autocasts to) and put in eval mode for inference.
# NOTE(review): trust_remote_code=True allows the repo to run custom code. It is
# probably unneeded for the native Qwen2VL class; confirm before removing it.
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    CAPTIONER_MODEL_ID, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()
qwen_processor = AutoProcessor.from_pretrained(CAPTIONER_MODEL_ID, trust_remote_code=True)

# Prompt enhancer, exposed through a "summarization" text2text pipeline.
enhancer_long = pipeline("summarization", model=ENHANCER_MODEL_ID, device=device)

# Upper bound for seeds (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width/height sliders, in pixels.
MAX_IMAGE_SIZE = 2048
@spaces.GPU
def qwen_caption(image):
    """Generate a detailed caption for ``image`` with the Qwen2-VL captioner.

    Args:
        image: A ``PIL.Image.Image``, or an array accepted by
            ``PIL.Image.fromarray`` (e.g. what a default ``gr.Image`` delivers).

    Returns:
        The decoded caption text for the single input image.
    """
    # The chat template and process_vision_info are fed a PIL image.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Generate a detailed and optimized caption for the given image."},
            ],
        }
    ]
    text = qwen_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = qwen_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Put every input tensor on the same device as the model.
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # torch.cuda.amp.autocast() has no ``device_type`` parameter: passing one
    # raises TypeError. torch.autocast is the device-generic API. Autocast is
    # enabled only on CUDA, so a CPU fallback runs in the model's own dtype.
    with torch.no_grad(), torch.autocast(
        device_type="cuda", dtype=torch.float16, enabled=(device == "cuda")
    ):
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=1024)

    # Each output sequence starts with its prompt tokens; keep only the new ones.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    return qwen_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
def enhance_prompt(input_prompt):
    """Rewrite ``input_prompt`` into a richer description using the T5 enhancer.

    Args:
        input_prompt: The prompt text to expand (concatenated as a str).

    Returns:
        The ``summary_text`` field of the enhancer pipeline's first result.
    """
    instruction_prefix = "Enhance the description: "
    outputs = enhancer_long(instruction_prefix + input_prompt)
    best = outputs[0]
    return best["summary_text"]
@spaces.GPU
def process_workflow(image, text_prompt, use_enhancer, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
    """Build a prompt (from an image caption or text), then render it with FLUX.

    Args:
        image: Optional input image (PIL image or array). When given, its
            Qwen2-VL caption becomes the prompt and ``text_prompt`` is ignored.
        text_prompt: Prompt used only when ``image`` is None.
        use_enhancer: If truthy, the prompt is rewritten by ``enhance_prompt``.
        seed: Seed for the FLUX generator; replaced when ``randomize_seed``.
        randomize_seed: If truthy, draw a fresh seed in ``[0, MAX_SEED]``.
        width, height: Output size in pixels, forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps forwarded to the pipeline.
        progress: Gradio progress tracker (``track_tqdm=True``); not referenced
            directly here, presumably so Gradio mirrors tqdm bars.

    Returns:
        ``(image, prompt, seed)``: the generated image, the prompt actually
        used, and the integer seed actually used.

    Raises:
        gr.Error: If no image is given and the text prompt is empty.
        RuntimeError: On CUDA out-of-memory (with a hint), or any other
            RuntimeError raised by the pipeline.
    """
    if image is not None:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        prompt = qwen_caption(image)
        print(prompt)
    else:
        prompt = text_prompt or ""
        # Fail fast with a visible message instead of rendering an empty prompt.
        if not prompt.strip():
            raise gr.Error("Please upload an image or enter a text prompt.")

    if use_enhancer:
        prompt = enhance_prompt(prompt)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders may deliver floats; manual_seed and the size/step args need ints.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    torch.cuda.empty_cache()
    try:
        result = pipe(
            prompt=prompt,
            generator=generator,
            num_inference_steps=int(num_inference_steps),
            width=int(width),
            height=int(height),
            guidance_scale=guidance_scale,
        ).images[0]
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            raise RuntimeError("CUDA out of memory. Try reducing image size or inference steps.") from e
        raise
    return result, prompt, seed
# Extra CSS passed to gr.Blocks; the selectors match the elem_classes in the layout.
# The hover rule uses the ``background`` shorthand on purpose. The base rule paints
# a gradient background-image, which covers background-color, so setting only
# background-color on hover had no visible effect.
custom_css = """
.input-group, .output-group {
    /* You can add styling here if needed */
}
.submit-btn {
    background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
    border: none !important;
    color: white !important;
}
.submit-btn:hover {
    background: #3498db !important;
}
"""
# Header HTML rendered at the top of the app. The model links and the tagline
# are sibling paragraphs: the original nested a <p> inside a <p> (and a
# block-level <center> inside it), which is invalid HTML.
title = """<h1 align="center">FLUX.1-dev with Qwen2VL Captioner and Prompt Enhancer</h1>
<p align="center">
<a href="https://huggingface.co/black-forest-labs/FLUX.1-dev" target="_blank">[FLUX.1-dev Model]</a>
<a href="https://huggingface.co/prithivMLmods/JSONify-Flux" target="_blank">[JSONify Flux Model]</a>
<a href="https://huggingface.co/prithivMLmods/t5-Flan-Prompt-Enhance" target="_blank">[Prompt Enhancer t5]</a>
</p>
<p align="center">Create long prompts from images or enhance your short prompts with prompt enhancer</p>
"""
# Gradio UI: a model-description sidebar, an input column and an output column,
# wired to process_workflow.
with gr.Blocks(css=custom_css) as demo:
    gr.HTML(title)

    # Sidebar with short descriptions of, and links to, the three models used.
    with gr.Sidebar(label="Parameters", open=True):
        gr.Markdown(
            """
### About
#### Flux.1-Dev
FLUX.1 [dev] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. FLUX.1 [dev] is an open-weight, guidance-distilled model for non-commercial applications. Directly distilled from FLUX.1 [pro], FLUX.1 [dev] obtains similar quality and prompt adherence capabilities, while being more efficient than a standard model of the same size.
[FLUX.1-dev Model](https://huggingface.co/black-forest-labs/FLUX.1-dev)
#### JSONify-Flux
JSONify-Flux is a multimodal image-text-text model trained on a dataset of FLUX-generated images with context-rich captions based on the Qwen2VL architecture. The JSON-based instruction has been manually removed to avoid JSON format captions.
[JSONify-Flux Model](https://huggingface.co/prithivMLmods/JSONify-Flux)
#### t5-Flan-Prompt-Enhance
t5-Flan-Prompt-Enhance is a prompt summarization model that enriches synthetic FLUX prompts with more detailed descriptions.
[t5-Flan-Prompt-Enhance Model](https://huggingface.co/prithivMLmods/t5-Flan-Prompt-Enhance)
"""
        )

    with gr.Row():
        # Left column: inputs and generation settings.
        with gr.Column(scale=1):
            with gr.Group(elem_classes="input-group"):
                input_image = gr.Image(label="Input Image (Qwen2VL Captioner)")
            with gr.Accordion("Advanced Settings", open=False):
                text_prompt = gr.Textbox(label="Text Prompt (optional, used if no image is uploaded)")
                use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
                guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=15, step=0.1, value=3.5)
                num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=32)
            generate_btn = gr.Button("Generate Image & Prompt", elem_classes="submit-btn")

        # Right column: generated image, the prompt actually used, and the seed.
        with gr.Column(scale=1):
            with gr.Group(elem_classes="output-group"):
                output_image = gr.Image(label="result", elem_id="gallery", show_label=False)
                final_prompt = gr.Textbox(label="prompt")
                # precision=0 so the seed is always shown as an integer.
                used_seed = gr.Number(label="seed", precision=0)

    # The input order must match process_workflow's positional parameters.
    generate_btn.click(
        fn=process_workflow,
        inputs=[
            input_image, text_prompt, use_enhancer, seed, randomize_seed,
            width, height, guidance_scale, num_inference_steps,
        ],
        outputs=[output_image, final_prompt, used_seed],
    )

# Launch only when run as a script (as Spaces does), not on import.
if __name__ == "__main__":
    demo.launch(debug=True)