import os
import random
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import spaces
import supervision as sv
import torch
from diffusers import FluxInpaintPipeline
from PIL import Image

from utils.florence import (
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
    load_florence_model,
    run_florence_inference,
)
from utils.sam import load_sam_image_model, run_sam_inference

# Set up device and environment
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HF_TOKEN = os.environ.get("HF_TOKEN", None)  # NOTE(review): not used anywhere in this file
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Load models once at import time; every request reuses them.
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
FLUX_PIPE = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to(DEVICE)

# Set up CUDA optimizations. Only meaningful when CUDA exists; on a CPU-only
# host a "cuda" autocast context would just emit a warning and do nothing.
if torch.cuda.is_available():
    # NOTE(review): this enters a bfloat16 autocast context for the importing
    # thread and never exits it; process_image carries its own autocast
    # decorator for the request path.
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        # Compute capability >= 8 (Ampere and newer) supports TF32 matmuls.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = MAX_IMAGE_SIZE,
    multiple_of: int = 32,
) -> Tuple[int, int]:
    """Compute a working resolution for the FLUX inpainting pipeline.

    If either side exceeds ``maximum_dimension``, the image is scaled down so
    that its longer side equals ``maximum_dimension``; the aspect ratio is
    preserved. Both sides are then floored to a multiple of ``multiple_of``
    so the pipeline receives aligned dimensions.

    Args:
        original_resolution_wh: ``(width, height)`` of the source image.
        maximum_dimension: Upper bound for the longer side.
        multiple_of: Granularity each returned side is floored to.

    Returns:
        ``(width, height)``. Each side is a positive multiple of
        ``multiple_of``. A side that would otherwise floor to 0 (tiny or very
        thin images) is clamped to ``multiple_of``.
    """
    width, height = original_resolution_wh

    if width > maximum_dimension or height > maximum_dimension:
        # Scale by the longer side so it lands exactly on the bound.
        scaling_factor = maximum_dimension / max(width, height)
        width = int(width * scaling_factor)
        height = int(height * scaling_factor)

    def _floor_to_multiple(value: int) -> int:
        # Never return 0: a zero-sized side would make PIL/FLUX fail.
        return max(multiple_of, value - (value % multiple_of))

    return _floor_to_multiple(width), _floor_to_multiple(height)


@spaces.GPU(duration=150)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(
    image_input,
    segmentation_text,
    inpaint_text,
    seed_slicer: int,
    randomize_seed: bool,
    strength: float,
    num_inference_steps: int,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
    """Segment the region described by a prompt and inpaint it with FLUX.

    Florence-2 open-vocabulary detection locates ``segmentation_text``.
    SAM turns the detections into masks. The first mask is resized together
    with the image and passed to the FLUX inpainting pipeline.

    Args:
        image_input: Uploaded PIL image (the UI requests RGB), or ``None``.
        segmentation_text: Prompt for open-vocabulary detection.
        inpaint_text: Prompt describing what to paint into the masked area.
        seed_slicer: Generator seed; ignored when ``randomize_seed`` is set.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` instead.
        strength: Inpainting strength forwarded to the FLUX pipeline.
        num_inference_steps: Denoising steps forwarded to the FLUX pipeline.
        progress: Gradio progress tracker. It is never read directly, but
            ``track_tqdm=True`` mirrors the pipeline's tqdm bars in the UI.

    Returns:
        ``(inpainted_image, mask)`` at the resized working resolution. Returns
        ``(None, None)`` when an input is missing or nothing was detected; in
        those cases the user is notified with ``gr.Info``.
    """
    if image_input is None:
        gr.Info("Please upload an image.")
        return None, None
    # Whitespace-only prompts are treated as missing.
    if not (segmentation_text and segmentation_text.strip()):
        gr.Info("Please enter a text prompt for segmentation.")
        return None, None
    if not (inpaint_text and inpaint_text.strip()):
        gr.Info("Please enter a text prompt for inpainting.")
        return None, None

    # Florence-SAM segmentation
    _, florence_result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image_input,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        text=segmentation_text,
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=florence_result,
        resolution_wh=image_input.size,
    )
    detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
    if len(detections) == 0 or detections.mask is None:
        gr.Info("No objects detected.")
        return None, None

    # Only the first detection's mask is inpainted.
    mask = Image.fromarray(detections.mask[0].astype("uint8") * 255)

    # Resize images for FLUX. NEAREST keeps the mask strictly binary.
    width, height = resize_image_dimensions(original_resolution_wh=image_input.size)
    resized_image = image_input.resize((width, height), Image.LANCZOS)
    resized_mask = mask.resize((width, height), Image.NEAREST)

    # FLUX inpainting. Casts are defensive: slider values may arrive as floats,
    # while manual_seed and the step count need ints.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed_slicer)
    generator = torch.Generator().manual_seed(seed)
    inpainted_image = FLUX_PIPE(
        prompt=inpaint_text,
        image=resized_image,
        mask_image=resized_mask,
        width=width,
        height=height,
        strength=strength,
        generator=generator,
        num_inference_steps=int(num_inference_steps),
    ).images[0]
    return inpainted_image, resized_mask


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# MonsterAPI Prompt Guided Inpainting")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                label='Upload image',
                type='pil',
                image_mode='RGB',
            )
            segmentation_text = gr.Textbox(
                label='Segmentation text prompt',
                placeholder='Enter text for segmentation',
            )
            inpaint_text = gr.Textbox(
                label='Inpainting text prompt',
                placeholder='Enter text for inpainting',
            )
            with gr.Accordion("Advanced Settings", open=False):
                seed_slicer = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                strength = gr.Slider(
                    label="Strength",
                    minimum=0,
                    maximum=1,
                    step=0.01,
                    value=0.75,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=20,
                )
            submit_button = gr.Button(value='Process', variant='primary')
        with gr.Column():
            output_image = gr.Image(label='Output image')
            with gr.Accordion("Generated Mask", open=False):
                output_mask = gr.Image(label='Segmentation mask')

    # Input order must match process_image's positional parameters.
    submit_button.click(
        fn=process_image,
        inputs=[
            image_input,
            segmentation_text,
            inpaint_text,
            seed_slicer,
            randomize_seed,
            strength,
            num_inference_steps,
        ],
        outputs=[output_image, output_mask],
    )

if __name__ == "__main__":
    demo.launch(debug=True, show_error=True, server_name="0.0.0.0", share=True)