# app_api.py import os os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' import base64 from io import BytesIO import random import numpy as np import cv2 from PIL import Image import torch from segment_anything import SamPredictor, sam_model_registry from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler from fastapi import FastAPI, HTTPException from pydantic import BaseModel from typing import Optional, List from celery import states from celery_app import celery_app from celery.result import AsyncResult import requests torch.set_num_threads(1) # Initialize FastAPI app = FastAPI() # Device configuration function def get_device(): if torch.backends.mps.is_available(): return "mps" elif torch.cuda.is_available(): return "cuda" else: return "cpu" # Device configuration DEVICE = get_device() # Initialize models at the module level # Load models mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to(DEVICE) mobile_sam.eval() mobile_predictor = SamPredictor(mobile_sam) # Load BrushNet and Stable Diffusion pipeline base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE" # Update to your base model path brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt" # Update to your BrushNet checkpoint path brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16, device=DEVICE) pipe = StableDiffusionBrushNetPipeline.from_pretrained( base_model_path, brushnet=brushnet, torch_dtype=torch.float16, low_cpu_mem_usage=False, device=DEVICE ) pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) pipe.enable_model_cpu_offload(device=DEVICE) # Pydantic models for request and response validation class InpaintingRequest(BaseModel): input_image: str # Base64-encoded image input_mask: Optional[str] = None # Base64-encoded mask image prompt: str negative_prompt: Optional[str] = 'ugly, low quality' control_strength: float = 1.0 guidance_scale: float = 12.0 num_inference_steps: int = 50 seed: int = 0 # Set default seed to 0 randomize_seed: bool = False blended: bool = False invert_mask: bool = True count: int = 1 webhook_url: Optional[str] = None # URL to send the webhook to class InpaintingResponse(BaseModel): images: List[str] # Base64-encoded images # Function to resize images def resize_image(input_image, resolution): H, W, C = input_image.shape H = float(H) W = float(W) k = float(resolution) / min(H, W) H *= k W *= k H = int(np.round(H / 64.0)) * 64 W = int(np.round(W / 64.0)) * 64 img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) return img # Main processing function def process_api(request: InpaintingRequest): global mobile_predictor, pipe, DEVICE # Decode the base64-encoded input image input_image_data = base64.b64decode(request.input_image) original_image = Image.open(BytesIO(input_image_data)).convert("RGB") original_image_np = np.array(original_image) # Resize the original image if its smallest dimension is greater than 512 pixels if min(original_image_np.shape[:2]) > 512: original_image_np = resize_image(original_image_np, 512) # Handle the mask image if provided if request.input_mask: input_mask_data = base64.b64decode(request.input_mask) input_mask_image = Image.open(BytesIO(input_mask_data)).convert("RGB") input_mask_np = np.array(input_mask_image) else: raise ValueError("Input mask is required") # Prepare other parameters if request.randomize_seed: seed = random.randint(0, 2147483647) else: seed = request.seed # This will always be an int due to the default value generator = torch.Generator(DEVICE).manual_seed(seed) # Prepare mask and init images H, W = original_image_np.shape[:2] original_mask = cv2.resize(input_mask_np, (W, H)) if request.invert_mask: original_mask = 255 - original_mask mask = 1.0 * (original_mask.sum(-1) > 255)[:, :, np.newaxis] masked_image = original_image_np * (1 - mask) init_image = Image.fromarray(masked_image.astype(np.uint8)).convert("RGB") mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("RGB") # Generate images images = pipe( [request.prompt] * request.count, init_image, mask_image, num_inference_steps=request.num_inference_steps, guidance_scale=request.guidance_scale, generator=generator, brushnet_conditioning_scale=float(request.control_strength), negative_prompt=[request.negative_prompt] * request.count, ).images # Blended images if requested if request.blended: if request.control_strength < 1.0: raise ValueError('Using blurred blending with control strength less than 1.0 is not allowed') blended_images = [] mask_np = np.array(mask_image) / 255.0 mask_blurred = cv2.GaussianBlur(mask_np, (21, 21), 0) mask_combined = 1 - (1 - mask_np) * (1 - mask_blurred) for img in images: img_np = np.array(img) image_pasted = original_image_np * (1 - mask_combined) + img_np * mask_combined image_pasted = image_pasted.astype(np.uint8) blended_images.append(Image.fromarray(image_pasted)) images = blended_images # Encode images to base64 output_images = [] for img in images: buffered = BytesIO() img.save(buffered, format="PNG") img_str = base64.b64encode(buffered.getvalue()).decode('utf-8') output_images.append(img_str) return output_images # Celery task function @celery_app.task(bind=True) def generate_image_task(self, request_data): try: # Convert request_data to InpaintingRequest object request = InpaintingRequest(**request_data) # Call the processing function output_images = process_api(request) # If webhook URL is provided, send the images to the webhook webhook_url = request.webhook_url if webhook_url: payload = {'images': output_images} requests.post(webhook_url, json=payload) return output_images except Exception as e: # Optionally, log the error print(f"Exception in generate_image_task: {e}") # Re-raise the exception to let Celery handle it raise # API endpoint to start the image generation task @app.post("/inpaint") def inpaint(request: InpaintingRequest): try: task = generate_image_task.delay(request.dict()) return {"job_id": task.id} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) # API endpoint to check task status @app.get("/status/{job_id}") def get_status(job_id: str): task = AsyncResult(job_id, app=celery_app) if task.state == states.PENDING: return {"status": "PENDING"} elif task.state == states.STARTED: return {"status": "STARTED"} elif task.state == states.SUCCESS: return {"status": "SUCCESS"} elif task.state == states.FAILURE: # Extract the error message error_info = task.info error_message = str(error_info) return {"status": "FAILURE", "error": error_message} else: return {"status": str(task.state)} # API endpoint to retrieve results @app.get("/result/{job_id}") def get_result(job_id: str): task = AsyncResult(job_id, app=celery_app) if task.state == states.SUCCESS: return {"images": task.result} elif task.state == states.FAILURE: raise HTTPException(status_code=500, detail=str(task.info)) else: raise HTTPException(status_code=202, detail="Task not completed yet.") # Run the app if __name__ == "__main__": import uvicorn uvicorn.run("app_api:app", host="0.0.0.0", port=8000)