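# Gradio demo: the same prompt is rendered with the original BF16 FLUX.1-dev
# pipeline and with a bitsandbytes-quantized variant (8-bit or 4-bit), the two
# images are shuffled, and the user guesses which one came from the quantized
# model.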
import torch
import gradio as gr
from diffusers import FluxPipeline, FluxTransformer2DModel
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import T5EncoderModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
import gc
import random
from PIL import Image
import os
import time

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

DEFAULT_HEIGHT = 1024
DEFAULT_WIDTH = 1024
DEFAULT_GUIDANCE_SCALE = 3.5
DEFAULT_NUM_INFERENCE_STEPS = 50
DEFAULT_MAX_SEQUENCE_LENGTH = 512
GENERATION_SEED = 0  # could use a random number generator to set this, for more variety
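# Keeping the seed fixed means both pipelines start from the same initial
# noise, so any visible difference between the two images comes from
# quantization rather than sampling randomness.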

def clear_gpu_memory(*args):
    allocated_before = torch.cuda.memory_allocated(0) / 1024**3 if DEVICE == "cuda" else 0
    reserved_before = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
    print(f"Before clearing: Allocated={allocated_before:.2f} GB, Reserved={reserved_before:.2f} GB")
    deleted_types = []
    for arg in args:
        if arg is not None:
            deleted_types.append(str(type(arg)))
            del arg
    if deleted_types:
        print(f"Deleted objects of types: {', '.join(deleted_types)}")
    else:
        print("No objects passed to clear_gpu_memory.")
    gc.collect()
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
    allocated_after = torch.cuda.memory_allocated(0) / 1024**3 if DEVICE == "cuda" else 0
    reserved_after = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
    print(f"After clearing: Allocated={allocated_after:.2f} GB, Reserved={reserved_after:.2f} GB")
    print("-" * 20)
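
# Cache of already-loaded pipelines, keyed by model repo id, so switching back
# to a previously used quantization level does not reload the checkpoint.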
CACHED_PIPES = {}

def load_bf16_pipeline():
    """Loads the original FLUX.1-dev pipeline in BF16 precision."""
    print("Loading BF16 pipeline...")
    MODEL_ID = "black-forest-labs/FLUX.1-dev"
    if MODEL_ID in CACHED_PIPES:
        return CACHED_PIPES[MODEL_ID]
    start_time = time.time()
    try:
        pipe = FluxPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16
        )
        pipe.to(DEVICE)
        # pipe.enable_model_cpu_offload()
        end_time = time.time()
        mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
        print(f"BF16 Pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
        # CACHED_PIPES[MODEL_ID] = pipe
        return pipe
    except Exception as e:
        print(f"Error loading BF16 pipeline: {e}")
        raise  # Re-raise exception to be caught in generate_images
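
# Both quantized loaders below simply call from_pretrained() without a
# quantization config, which presumably works because these repos ship
# checkpoints whose weights are already bitsandbytes-quantized.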
def load_bnb_8bit_pipeline():
    """Loads the FLUX.1-dev pipeline with 8-bit quantized components."""
    print("Loading 8-bit BNB pipeline...")
    MODEL_ID = "derekl35/FLUX.1-dev-bnb-8bit"
    if MODEL_ID in CACHED_PIPES:
        return CACHED_PIPES[MODEL_ID]
    start_time = time.time()
    try:
        pipe = FluxPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16
        )
        pipe.to(DEVICE)
        # pipe.enable_model_cpu_offload()
        end_time = time.time()
        mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
        print(f"8-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
        CACHED_PIPES[MODEL_ID] = pipe
        return pipe
    except Exception as e:
        print(f"Error loading 8-bit BNB pipeline: {e}")
        raise

def load_bnb_4bit_pipeline():
    """Loads the FLUX.1-dev pipeline with 4-bit quantized components."""
    print("Loading 4-bit BNB pipeline...")
    MODEL_ID = "derekl35/FLUX.1-dev-nf4"
    if MODEL_ID in CACHED_PIPES:
        return CACHED_PIPES[MODEL_ID]
    start_time = time.time()
    try:
        pipe = FluxPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16
        )
        pipe.to(DEVICE)
        # pipe.enable_model_cpu_offload()
        end_time = time.time()
        mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
        print(f"4-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
        CACHED_PIPES[MODEL_ID] = pipe
        return pipe
    except Exception as e:
        print(f"Error loading 4-bit BNB pipeline: {e}")
        raise

# --- Image Generation and Shuffling Function ---
def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
    """Loads original and selected quantized model, generates one image each, clears memory, shuffles results."""
    if not prompt:
        return None, {}, gr.update(value="Please enter a prompt.", interactive=False), gr.update(choices=[], value=None)
    if not quantization_choice:
        # Return updates for all outputs to clear them or show warning
        return None, {}, gr.update(value="Please select a quantization method.", interactive=False), gr.update(choices=[], value=None)

    # Determine which quantized model to load
    if quantization_choice == "8-bit":
        quantized_load_func = load_bnb_8bit_pipeline
        quantized_label = "Quantized (8-bit)"
    elif quantization_choice == "4-bit":
        quantized_load_func = load_bnb_4bit_pipeline
        quantized_label = "Quantized (4-bit)"
    else:
        # Should not happen with Radio choices, but good practice
        return None, {}, gr.update(value="Invalid quantization choice.", interactive=False), gr.update(choices=[], value=None)

    model_configs = [
        ("Original", load_bf16_pipeline),
        (quantized_label, quantized_load_func),  # Use the specific label here
    ]

    results = []
    pipe_kwargs = {
        "prompt": prompt,
        "height": DEFAULT_HEIGHT,
        "width": DEFAULT_WIDTH,
        "guidance_scale": DEFAULT_GUIDANCE_SCALE,
        "num_inference_steps": DEFAULT_NUM_INFERENCE_STEPS,
        "max_sequence_length": DEFAULT_MAX_SEQUENCE_LENGTH,
    }

    current_pipe = None  # Keep track of the current pipe for cleanup
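
    # Generate with the original pipeline and the selected quantized pipeline
    # back to back, using identical settings and the same seed so the two
    # outputs are directly comparable.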
    for i, (label, load_func) in enumerate(model_configs):
        progress(i / len(model_configs), desc=f"Loading {label} model...")
        print(f"\n--- Loading {label} Model ---")
        load_start_time = time.time()
        try:
            # Ensure previous pipe is cleared *before* loading the next
            # if current_pipe:
            #     print(f"--- Clearing memory before loading {label} Model ---")
            #     clear_gpu_memory(current_pipe)
            #     current_pipe = None

            current_pipe = load_func()
            load_end_time = time.time()
            print(f"{label} model loaded in {load_end_time - load_start_time:.2f} seconds.")

            progress((i + 0.5) / len(model_configs), desc=f"Generating with {label} model...")
            print(f"--- Generating with {label} Model ---")
            gen_start_time = time.time()
            image_list = current_pipe(**pipe_kwargs, generator=torch.manual_seed(GENERATION_SEED)).images
            image = image_list[0]
            gen_end_time = time.time()
            results.append({"label": label, "image": image})
            print(f"--- Finished Generation with {label} Model in {gen_end_time - gen_start_time:.2f} seconds ---")

            mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
            print(f"Memory reserved: {mem_reserved:.2f} GB")
        except Exception as e:
            print(f"Error during {label} model processing: {e}")
            # Attempt cleanup
            if current_pipe:
                print(f"--- Clearing memory after error with {label} Model ---")
                clear_gpu_memory(current_pipe)
                current_pipe = None
            # Return error state to Gradio - update all outputs
            return None, {}, gr.update(value=f"Error processing {label} model: {e}", interactive=False), gr.update(choices=[], value=None)
        # No finally block needed here, cleanup happens before next load or after loop

    # Final cleanup after the loop finishes successfully
    # if current_pipe:
    #     print(f"--- Clearing memory after last model ({label}) ---")
    #     clear_gpu_memory(current_pipe)
    #     current_pipe = None

    if len(results) != len(model_configs):
        print("Generation did not complete for all models.")
        # Update all outputs
        return None, {}, gr.update(value="Failed to generate images for all model types.", interactive=False), gr.update(choices=[], value=None)

    # Shuffle the results for display
    shuffled_results = results.copy()
    random.shuffle(shuffled_results)

    # Create the gallery data: [(image, caption), (image, caption)]
    shuffled_data_for_gallery = [(res["image"], f"Image {i+1}") for i, res in enumerate(shuffled_results)]

    # Create the mapping: display_index -> correct_label (e.g., {0: 'Original', 1: 'Quantized (8-bit)'})
    correct_mapping = {i: res["label"] for i, res in enumerate(shuffled_results)}
    print("Correct mapping (hidden):", correct_mapping)

    guess_radio_update = gr.update(choices=["Image 1", "Image 2"], value=None, interactive=True)

    # Return shuffled images, the correct mapping state, status message, and update the guess radio
    return shuffled_data_for_gallery, correct_mapping, gr.update(value="Generation complete! Make your guess.", interactive=False), guess_radio_update

# --- Guess Verification Function ---
def check_guess(user_guess, correct_mapping_state):
    """Compares the user's guess with the correct mapping stored in the state."""
    if not isinstance(correct_mapping_state, dict) or not correct_mapping_state:
        return "Please generate images first (state is empty or invalid)."
    if user_guess is None:
        return "Please select which image you think is quantized."

    # Find which display index (0 or 1) corresponds to the quantized image
    quantized_image_index = -1
    quantized_label_actual = ""
    for index, label in correct_mapping_state.items():
        if "Quantized" in label:  # Check if the label indicates quantization
            quantized_image_index = index
            quantized_label_actual = label  # Store the full label e.g. "Quantized (8-bit)"
            break

    if quantized_image_index == -1:
        # This shouldn't happen if generation was successful
        return "Error: Could not find the quantized image in the mapping data."

    # Determine what the user *should* have selected based on the index
    correct_guess_label = f"Image {quantized_image_index + 1}"  # "Image 1" or "Image 2"

    if user_guess == correct_guess_label:
        feedback = f"Correct! {correct_guess_label} used the {quantized_label_actual} model."
    else:
        feedback = f"Incorrect. The quantized image ({quantized_label_actual}) was {correct_guess_label}."
    return feedback
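
# --- Gradio Interface ---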
with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# FLUX Model Quantization Challenge")
    gr.Markdown(
        "Compare the original FLUX.1-dev (BF16) model against a quantized version (4-bit or 8-bit). "
        "Enter a prompt, choose the quantization method, and generate two images. "
        "The images will be shuffled. Can you guess which one used quantization?"
    )

    with gr.Row():
        prompt_input = gr.Textbox(label="Enter Prompt", placeholder="e.g., A photorealistic portrait of an astronaut on Mars", scale=3)
        quantization_choice_radio = gr.Radio(
            choices=["8-bit", "4-bit"],
            label="Select Quantization",
            value="8-bit",  # Default choice
            scale=1
        )
        generate_button = gr.Button("Generate & Compare", variant="primary", scale=1)

    output_gallery = gr.Gallery(
        label="Generated Images (Original vs. Quantized)",
        columns=2,
        height=512,
        object_fit="contain",
        allow_preview=True,
        show_label=True,  # Shows "Image 1", "Image 2" captions we provide
    )

    gr.Markdown("### Which image used the selected quantization method?")

    with gr.Row():
        # Centered guess radio and submit button
        with gr.Column(scale=1):  # Dummy column for spacing
            pass
        with gr.Column(scale=2):  # Column for the radio button
            guess_radio = gr.Radio(
                choices=[],
                label="Your Guess",
                info="Select the image you believe was generated with the quantized model.",
                interactive=False  # Disabled until images are generated
            )
        with gr.Column(scale=1):  # Column for the button
            submit_guess_button = gr.Button("Submit Guess")
        with gr.Column(scale=1):  # Dummy column for spacing
            pass

    feedback_box = gr.Textbox(label="Feedback", interactive=False, lines=1)

    # Hidden state to store the correct mapping after shuffling
    # e.g., {0: 'Original', 1: 'Quantized (8-bit)'} or {0: 'Quantized (4-bit)', 1: 'Original'}
    correct_mapping_state = gr.State({})

    generate_button.click(
        fn=generate_images,
        inputs=[prompt_input, quantization_choice_radio],
        outputs=[output_gallery, correct_mapping_state, feedback_box, guess_radio]
    ).then(
        lambda: "",  # Clear feedback box on new generation
        outputs=[feedback_box]
    )

    submit_guess_button.click(
        fn=check_guess,
        inputs=[guess_radio, correct_mapping_state],  # Pass the selected guess and the state
        outputs=[feedback_box]
    )

if __name__ == "__main__":
    # queue()
    # demo.queue().launch()  # Set share=True to create public link if needed
    demo.launch()