# app_api.py
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # must be set before torch is imported

import base64
from io import BytesIO
import random
import numpy as np
import cv2
from PIL import Image
import torch
from segment_anything import SamPredictor, sam_model_registry
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, List
from celery import states
from celery_app import celery_app
from celery.result import AsyncResult
import requests

# Limit intra-op parallelism; each Celery worker process gets a single thread.
torch.set_num_threads(1)
# Initialize FastAPI
app = FastAPI()

# Device configuration
def get_device():
    if torch.backends.mps.is_available():
        return "mps"
    elif torch.cuda.is_available():
        return "cuda"
    else:
        return "cpu"

DEVICE = get_device()
# Load models once at module import
mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to(DEVICE)
mobile_sam.eval()
mobile_predictor = SamPredictor(mobile_sam)

# Load BrushNet and the Stable Diffusion pipeline.
# Note: from_pretrained does not take a `device` kwarg; placement is handled
# by enable_model_cpu_offload below.
base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE"  # Update to your base model path
brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt"  # Update to your BrushNet checkpoint path
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path,
    brushnet=brushnet,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload(device=DEVICE)
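# enable_model_cpu_offload keeps each sub-model on the CPU and moves it to
# DEVICE only while it runs, trading speed for a much smaller memory footprint.
# With enough VRAM, pipe.to(DEVICE) would be the faster alternative.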
# Pydantic models for request and response validation
class InpaintingRequest(BaseModel):
    input_image: str  # Base64-encoded image
    input_mask: Optional[str] = None  # Base64-encoded mask image
    prompt: str
    negative_prompt: Optional[str] = 'ugly, low quality'
    control_strength: float = 1.0
    guidance_scale: float = 12.0
    num_inference_steps: int = 50
    seed: int = 0
    randomize_seed: bool = False
    blended: bool = False
    invert_mask: bool = True
    count: int = 1
    webhook_url: Optional[str] = None  # URL to POST results to when done

class InpaintingResponse(BaseModel):
    images: List[str]  # Base64-encoded images
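# Example request body (illustrative values only; both images are the base64
# encoding of PNG/JPEG file bytes):
# {
#     "input_image": "<base64 image>",
#     "input_mask": "<base64 mask>",
#     "prompt": "a red armchair in a sunlit room",
#     "num_inference_steps": 50,
#     "count": 2
# }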
# Resize so the short side equals `resolution`, rounding both sides to multiples of 64
def resize_image(input_image, resolution):
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    H = int(np.round(H / 64.0)) * 64
    W = int(np.round(W / 64.0)) * 64
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img
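# Worked example (illustrative): a 768x1024 input has short side 768, so k = 512/768.
# The scaled size is 512x682.7, which rounds to the nearest multiples of 64: 512x704.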
# Main processing function
def process_api(request: InpaintingRequest):
    global mobile_predictor, pipe, DEVICE

    # Decode the base64-encoded input image
    input_image_data = base64.b64decode(request.input_image)
    original_image = Image.open(BytesIO(input_image_data)).convert("RGB")
    original_image_np = np.array(original_image)

    # Downscale if the smaller dimension exceeds 512 pixels
    if min(original_image_np.shape[:2]) > 512:
        original_image_np = resize_image(original_image_np, 512)

    # Decode the mask image (required)
    if request.input_mask:
        input_mask_data = base64.b64decode(request.input_mask)
        input_mask_image = Image.open(BytesIO(input_mask_data)).convert("RGB")
        input_mask_np = np.array(input_mask_image)
    else:
        raise ValueError("Input mask is required")

    # Seed handling
    if request.randomize_seed:
        seed = random.randint(0, 2147483647)
    else:
        seed = request.seed
    generator = torch.Generator(DEVICE).manual_seed(seed)

    # Prepare mask and init images
    H, W = original_image_np.shape[:2]
    original_mask = cv2.resize(input_mask_np, (W, H))
    if request.invert_mask:
        original_mask = 255 - original_mask
    # A pixel counts as masked when its RGB sum exceeds 255 (i.e. it is not near-black)
    mask = 1.0 * (original_mask.sum(-1) > 255)[:, :, np.newaxis]
    masked_image = original_image_np * (1 - mask)  # black out the region to inpaint
    init_image = Image.fromarray(masked_image.astype(np.uint8)).convert("RGB")
    mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("RGB")
    # Generate `count` images
    images = pipe(
        [request.prompt] * request.count,
        init_image,
        mask_image,
        num_inference_steps=request.num_inference_steps,
        guidance_scale=request.guidance_scale,
        generator=generator,
        brushnet_conditioning_scale=float(request.control_strength),
        negative_prompt=[request.negative_prompt] * request.count,
    ).images

    # Optionally blend the generated region back onto the original with a soft edge
    if request.blended:
        if request.control_strength < 1.0:
            raise ValueError('Using blurred blending with control strength less than 1.0 is not allowed')
        blended_images = []
        mask_np = np.array(mask_image) / 255.0
        mask_blurred = cv2.GaussianBlur(mask_np, (21, 21), 0)
        # Keep the hard mask inside and feather only the outer edge
        mask_combined = 1 - (1 - mask_np) * (1 - mask_blurred)
        for img in images:
            img_np = np.array(img)
            image_pasted = original_image_np * (1 - mask_combined) + img_np * mask_combined
            image_pasted = image_pasted.astype(np.uint8)
            blended_images.append(Image.fromarray(image_pasted))
        images = blended_images

    # Encode results as base64 PNG strings
    output_images = []
    for img in images:
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
        output_images.append(img_str)
    return output_images
# Celery task: runs the pipeline in a worker process
# (the decorator is required for .delay() and the bound `self` parameter)
@celery_app.task(bind=True)
def generate_image_task(self, request_data):
    try:
        # Rebuild the request object from the serialized payload
        request = InpaintingRequest(**request_data)
        output_images = process_api(request)

        # If a webhook URL was provided, POST the images to it
        webhook_url = request.webhook_url
        if webhook_url:
            payload = {'images': output_images}
            requests.post(webhook_url, json=payload)
        return output_images
    except Exception as e:
        print(f"Exception in generate_image_task: {e}")
        # Re-raise so Celery records the failure state
        raise
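# The task above assumes a worker is running against the same celery_app, e.g.:
#   celery -A celery_app worker --loglevel=info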
# API endpoint to start the image generation task (route path assumed)
@app.post("/inpaint")
def inpaint(request: InpaintingRequest):
    try:
        task = generate_image_task.delay(request.dict())
        return {"job_id": task.id}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# API endpoint to check task status (route path assumed)
@app.get("/status/{job_id}")
def get_status(job_id: str):
    task = AsyncResult(job_id, app=celery_app)
    if task.state == states.PENDING:
        return {"status": "PENDING"}
    elif task.state == states.STARTED:
        return {"status": "STARTED"}
    elif task.state == states.SUCCESS:
        return {"status": "SUCCESS"}
    elif task.state == states.FAILURE:
        # task.info holds the exception raised by the worker
        return {"status": "FAILURE", "error": str(task.info)}
    else:
        return {"status": str(task.state)}

# API endpoint to retrieve results (route path assumed)
@app.get("/result/{job_id}")
def get_result(job_id: str):
    task = AsyncResult(job_id, app=celery_app)
    if task.state == states.SUCCESS:
        return {"images": task.result}
    elif task.state == states.FAILURE:
        raise HTTPException(status_code=500, detail=str(task.info))
    else:
        raise HTTPException(status_code=202, detail="Task not completed yet.")
# Run the app
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app_api:app", host="0.0.0.0", port=8000)
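# Example client usage (a minimal sketch; assumes the server is on localhost:8000
# and the route paths chosen above):
#
#   import base64, requests
#   with open("input.png", "rb") as f:
#       image_b64 = base64.b64encode(f.read()).decode("utf-8")
#   with open("mask.png", "rb") as f:
#       mask_b64 = base64.b64encode(f.read()).decode("utf-8")
#   job = requests.post("http://localhost:8000/inpaint", json={
#       "input_image": image_b64,
#       "input_mask": mask_b64,
#       "prompt": "a wooden coffee table",
#   }).json()
#   status = requests.get(f"http://localhost:8000/status/{job['job_id']}").json()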