# BrushNetApi / app_api.py
# Juan Leal
# Fixes MPS fallback error
# 6454648
# app_api.py
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import base64
from io import BytesIO
import random
import numpy as np
import cv2
from PIL import Image
import torch
from segment_anything import SamPredictor, sam_model_registry
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, List
from celery import states
from celery_app import celery_app
from celery.result import AsyncResult
import requests
# Cap PyTorch's CPU intra-op parallelism at a single thread.
# NOTE(review): presumably to avoid CPU oversubscription when several worker
# processes import this module - confirm.
torch.set_num_threads(1)
# Initialize FastAPI; the route decorators further down register on this instance.
app = FastAPI()
# Device configuration function
def get_device():
    """Pick the torch device string to run on.

    Returns:
        "mps" when the MPS backend reports available, otherwise "cuda" when
        CUDA is available, otherwise "cpu".
    """
    # MPS is probed first, so it wins if both backends report available.
    if torch.backends.mps.is_available():
        return "mps"
    return "cuda" if torch.cuda.is_available() else "cpu"
# Device configuration
# Resolved once at import time; every model below is loaded for this device.
DEVICE = get_device()
# Initialize models at the module level.
# Importing this module therefore loads all checkpoints below as a side effect.
# Load the Segment Anything model (registry key 'vit_h'), switch it to eval mode
# and wrap it in a predictor.
# NOTE(review): in the code visible here, mobile_sam / mobile_predictor are never
# used beyond a `global` declaration in process_api - confirm the load is still needed.
mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to(DEVICE)
mobile_sam.eval()
mobile_predictor = SamPredictor(mobile_sam)
# Load BrushNet and Stable Diffusion pipeline
base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE" # Update to your base model path
brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt" # Update to your BrushNet checkpoint path
# NOTE(review): float16 is requested regardless of DEVICE; half precision may not be
# supported for every op when DEVICE is "cpu" - confirm.
# NOTE(review): `device=` is passed to from_pretrained here and below; verify that the
# installed diffusers build honours it rather than silently ignoring it.
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16, device=DEVICE)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path,
    brushnet=brushnet,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,
    device=DEVICE
)
# Replace the pipeline's scheduler with UniPC, built from the existing scheduler's config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Enable model CPU offload, targeting DEVICE as the execution device.
# NOTE(review): when DEVICE is "cpu" this offloads from CPU to CPU - confirm that is intended.
pipe.enable_model_cpu_offload(device=DEVICE)
# Pydantic models for request and response validation
# Request body accepted by POST /inpaint; it is also rebuilt from a plain dict
# inside the Celery task before being handed to process_api.
class InpaintingRequest(BaseModel):
    # Source picture, base64-encoded; decoded and converted to RGB by process_api.
    input_image: str
    # Mask picture, base64-encoded. Optional in the schema, but process_api
    # raises ValueError when it is missing or empty.
    input_mask: Optional[str] = None
    # Text prompt; repeated `count` times when the pipeline is called.
    prompt: str
    # Negative prompt; repeated `count` times when the pipeline is called.
    negative_prompt: Optional[str] = 'ugly, low quality'
    # Passed to the pipeline as brushnet_conditioning_scale; must be >= 1.0
    # when `blended` is true.
    control_strength: float = 1.0
    # Passed to the pipeline as guidance_scale.
    guidance_scale: float = 12.0
    # Passed to the pipeline as num_inference_steps.
    num_inference_steps: int = 50
    # Seed for the torch generator, used when randomize_seed is false.
    seed: int = 0 # Set default seed to 0
    # When true, a random seed in [0, 2147483647] replaces `seed`.
    randomize_seed: bool = False
    # When true, each generated image is pasted back over the original
    # through a Gaussian-blurred copy of the mask.
    blended: bool = False
    # When true, the mask is inverted (255 - mask) before use.
    invert_mask: bool = True
    # Number of images to generate (length of the prompt list sent to the pipeline).
    count: int = 1
    # If set, the Celery task POSTs {'images': [...]} to this URL when done.
    webhook_url: Optional[str] = None # URL to send the webhook to
# Response schema holding the generated images.
# NOTE(review): no endpoint visible in this file references this model - confirm
# whether it is used elsewhere or is meant to be a response_model.
class InpaintingResponse(BaseModel):
    # Generated images, each a base64-encoded string.
    images: List[str]
# Function to resize images
def resize_image(input_image, resolution):
    """Rescale an image so its shorter side is about ``resolution`` pixels.

    The aspect ratio is kept (up to rounding) and both output sides are
    snapped to the nearest multiple of 64.

    Args:
        input_image: Image as a numpy array whose first two axes are
            (height, width). A trailing channel axis is optional.
        resolution: Target length of the shorter side, in pixels.

    Returns:
        The resized image as returned by ``cv2.resize``.

    Raises:
        ZeroDivisionError: If the image has a zero-length side.
    """
    # Only the first two axes matter, so 2-D (single-channel) arrays work too.
    height, width = input_image.shape[:2]
    scale = float(resolution) / min(height, width)
    # Snap each side to a multiple of 64. The floor of 64 keeps a very small
    # `resolution` from rounding a side down to 0, which cv2.resize rejects.
    new_height = max(64, int(np.round(height * scale / 64.0)) * 64)
    new_width = max(64, int(np.round(width * scale / 64.0)) * 64)
    # Lanczos when enlarging, area averaging when shrinking.
    interpolation = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
    return cv2.resize(input_image, (new_width, new_height), interpolation=interpolation)
# Main processing function and its private helpers
def _decode_rgb_array(encoded):
    """Decode a base64-encoded image into an RGB numpy array."""
    raw_bytes = base64.b64decode(encoded)
    return np.array(Image.open(BytesIO(raw_bytes)).convert("RGB"))


def _encode_png_base64(image):
    """Serialize a PIL image as PNG and return it as a base64 text string."""
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def process_api(request: "InpaintingRequest"):
    """Run BrushNet inpainting for one request.

    Args:
        request: The inpainting parameters (see ``InpaintingRequest``).

    Returns:
        A list of base64-encoded PNG strings, one per generated image.

    Raises:
        ValueError: If no mask is supplied, if ``count`` is below 1, or if
            ``blended`` is requested with ``control_strength`` below 1.0.
    """
    # Validate the cheap things first so a bad request fails before any image
    # decoding or the expensive diffusion run.
    if not request.input_mask:
        raise ValueError("Input mask is required")
    if request.count < 1:
        raise ValueError("count must be at least 1")
    if request.blended and request.control_strength < 1.0:
        raise ValueError('Using blurred blending with control strength less than 1.0 is not allowed')

    # Decode the source image; shrink it when its shorter side exceeds 512 px.
    original_image_np = _decode_rgb_array(request.input_image)
    if min(original_image_np.shape[:2]) > 512:
        original_image_np = resize_image(original_image_np, 512)
    input_mask_np = _decode_rgb_array(request.input_mask)

    # Seed the generator on the same device the pipeline was configured for.
    seed = random.randint(0, 2147483647) if request.randomize_seed else request.seed
    generator = torch.Generator(DEVICE).manual_seed(seed)

    # Bring the mask to the image size and optionally invert it.
    height, width = original_image_np.shape[:2]
    original_mask = cv2.resize(input_mask_np, (width, height))
    if request.invert_mask:
        original_mask = 255 - original_mask
    # A pixel counts as masked when its channel sum exceeds 255; masked pixels
    # are zeroed in the init image. The pipeline receives the unthresholded mask.
    mask = 1.0 * (original_mask.sum(-1) > 255)[:, :, np.newaxis]
    masked_image = original_image_np * (1 - mask)
    init_image = Image.fromarray(masked_image.astype(np.uint8)).convert("RGB")
    mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("RGB")

    # Generate `count` images from the same prompt / negative prompt.
    images = pipe(
        [request.prompt] * request.count,
        init_image,
        mask_image,
        num_inference_steps=request.num_inference_steps,
        guidance_scale=request.guidance_scale,
        generator=generator,
        brushnet_conditioning_scale=float(request.control_strength),
        negative_prompt=[request.negative_prompt] * request.count,
    ).images

    if request.blended:
        # Soften the mask edge: combine the hard mask with a blurred copy so
        # the generated region fades into the original picture.
        mask_np = np.array(mask_image) / 255.0
        mask_blurred = cv2.GaussianBlur(mask_np, (21, 21), 0)
        mask_combined = 1 - (1 - mask_np) * (1 - mask_blurred)
        blended_images = []
        for generated in images:
            generated_np = np.array(generated)
            # Defensive: if the pipeline returned a different size than the
            # input, align it instead of failing on a broadcasting error.
            if generated_np.shape[:2] != (height, width):
                generated_np = cv2.resize(generated_np, (width, height))
            pasted = original_image_np * (1 - mask_combined) + generated_np * mask_combined
            blended_images.append(Image.fromarray(pasted.astype(np.uint8)))
        images = blended_images

    # Encode images to base64
    return [_encode_png_base64(image) for image in images]
# Celery task function
@celery_app.task(bind=True)
def generate_image_task(self, request_data):
    """Celery task: run one inpainting request and optionally notify a webhook.

    Args:
        self: The bound Celery task instance (not used in the body).
        request_data: Dict of ``InpaintingRequest`` fields.

    Returns:
        The list of base64-encoded images produced by ``process_api``.

    Raises:
        Exception: Anything raised while building the request or generating
            the images is logged and re-raised so Celery records the failure.
            Webhook delivery errors are logged but do NOT fail the task.
    """
    try:
        # Convert request_data to InpaintingRequest object
        request = InpaintingRequest(**request_data)
        # Call the processing function
        output_images = process_api(request)
    except Exception as e:
        # Optionally, log the error
        print(f"Exception in generate_image_task: {e}")
        # Re-raise the exception to let Celery handle it
        raise

    webhook_url = request.webhook_url
    if webhook_url:
        # NOTE(review): webhook_url is caller-supplied, so the worker will POST
        # to an arbitrary address - consider an allow-list (SSRF risk).
        try:
            # Without a timeout, an unresponsive receiver would block this
            # worker indefinitely. Values are (connect, read) seconds.
            response = requests.post(webhook_url, json={'images': output_images}, timeout=(5, 30))
            response.raise_for_status()
        except requests.RequestException as e:
            # Delivery is best-effort: the images were generated and remain
            # available as the task result, so do not turn this into a failure.
            print(f"Webhook delivery to {webhook_url} failed: {e}")
    return output_images
# API endpoint to start the image generation task
@app.post("/inpaint")
def inpaint(request: InpaintingRequest):
    """Queue an inpainting job on Celery and return its id.

    Returns:
        ``{"job_id": <celery task id>}``.

    Raises:
        HTTPException: 500 with the error text if the job could not be queued.
    """
    try:
        job_id = generate_image_task.delay(request.dict()).id
        return {"job_id": job_id}
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
# API endpoint to check task status
@app.get("/status/{job_id}")
def get_status(job_id: str):
    """Report the Celery state of a job.

    Args:
        job_id: The id returned by the /inpaint endpoint.

    Returns:
        ``{"status": <state name>}`` for every state; a FAILURE additionally
        carries ``"error"`` with the stringified task error.
    """
    task = AsyncResult(job_id, app=celery_app)
    # Read the state once: each access may query the result backend, and a
    # single snapshot keeps the comparison and the reported value consistent.
    state = task.state
    if state == states.FAILURE:
        # Extract the error message
        return {"status": "FAILURE", "error": str(task.info)}
    # PENDING, STARTED, SUCCESS and any other state are reported by name.
    return {"status": str(state)}
# API endpoint to retrieve results
@app.get("/result/{job_id}")
def get_result(job_id: str):
    """Return the generated images of a finished job.

    Args:
        job_id: The id returned by the /inpaint endpoint.

    Returns:
        ``{"images": <task result>}`` when the task succeeded.

    Raises:
        HTTPException: 500 with the task error if the task failed; 202 if it
            is in any other state.
    """
    job = AsyncResult(job_id, app=celery_app)
    if job.state == states.SUCCESS:
        return {"images": job.result}
    if job.state == states.FAILURE:
        raise HTTPException(status_code=500, detail=str(job.info))
    # Anything else means the task has not produced a result yet.
    raise HTTPException(status_code=202, detail="Task not completed yet.")
# Run the app
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when this file is run as a script.
    import uvicorn
    # Serve the app (referenced by import string "app_api:app") on all
    # interfaces, port 8000.
    uvicorn.run("app_api:app", host="0.0.0.0", port=8000)