Spaces:
Build error
Build error
File size: 8,032 Bytes
bddcd79 6454648 bddcd79 bfefba6 bddcd79 e5f65c8 bddcd79 bfefba6 bddcd79 bfefba6 bddcd79 bfefba6 bddcd79 bfefba6 bddcd79 bfefba6 bddcd79 bfefba6 bddcd79 bfefba6 bddcd79 f3a0251 bddcd79 bfefba6 bddcd79 fc34ff7 bfefba6 fc34ff7 bddcd79 bfefba6 bddcd79 bfefba6 bddcd79 bfefba6 bddcd79 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 |
# app_api.py
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import base64
from io import BytesIO
import random
import numpy as np
import cv2
from PIL import Image
import torch
from segment_anything import SamPredictor, sam_model_registry
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, List
from celery import states
from celery_app import celery_app
from celery.result import AsyncResult
import requests
# Initialize FastAPI
app = FastAPI()
# Restrict torch to a single intra-op CPU thread (presumably to avoid
# oversubscribing cores when several processes share a host — confirm intent).
torch.set_num_threads(1)
# Device configuration function
def get_device():
if torch.backends.mps.is_available():
return "mps"
elif torch.cuda.is_available():
return "cuda"
else:
return "cpu"
# Device configuration: resolved once at import time and shared by the models
# below and by the torch.Generator created in process_api.
DEVICE = get_device()
# Initialize models at the module level (this runs on every import of the module).
# Load models
# NOTE(review): despite the `mobile_` names this loads the full ViT-H SAM
# checkpoint, not MobileSAM. `mobile_predictor` is not used by process_api
# (the mask is supplied by the client), so this load appears to cost memory
# and startup time only — confirm nothing else imports it before removing.
mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to(DEVICE)
mobile_sam.eval()
mobile_predictor = SamPredictor(mobile_sam)
# Load BrushNet and Stable Diffusion pipeline
base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE"  # local diffusers-format base model; update for your setup
brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt"  # local BrushNet checkpoint; update for your setup
# NOTE(review): weights are loaded as float16 whatever DEVICE is; half-precision
# inference on "cpu" may fail or be very slow for some ops — confirm before
# deploying on CPU-only hosts.
# NOTE(review): `device=` is not a documented `from_pretrained` argument in
# upstream diffusers (it may be ignored with a warning), so placement is
# effectively left to `enable_model_cpu_offload` below — verify against the
# BrushNet diffusers fork in use.
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16, device=DEVICE)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path,
    brushnet=brushnet,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,
    device=DEVICE
)
# Replace the pipeline's default scheduler with UniPC, reusing its config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Enable diffusers' model CPU offload targeting DEVICE (per diffusers docs, whole
# sub-models are moved to the device only when they run).
pipe.enable_model_cpu_offload(device=DEVICE)
# Pydantic model for the /inpaint request body. Images travel as base64 strings;
# every field is consumed by process_api (webhook_url by generate_image_task).
class InpaintingRequest(BaseModel):
    # Base64-encoded source image; decoded with PIL and converted to RGB.
    input_image: str
    # Base64-encoded mask image. Optional in the schema, but process_api
    # rejects requests that omit it.
    input_mask: Optional[str] = None
    # Text prompt, repeated once per generated image.
    prompt: str
    # Negative prompt, repeated once per generated image.
    negative_prompt: Optional[str] = 'ugly, low quality'
    # Forwarded as the pipeline's brushnet_conditioning_scale; blended mode
    # requires a value >= 1.0.
    control_strength: float = 1.0
    # Guidance scale passed straight to the pipeline.
    guidance_scale: float = 12.0
    # Number of denoising steps passed to the pipeline.
    num_inference_steps: int = 50
    # Generator seed; ignored when randomize_seed is True.
    seed: int = 0
    # Draw a fresh random seed instead of using `seed`.
    randomize_seed: bool = False
    # Paste generated content back onto the original through a blurred mask.
    blended: bool = False
    # Invert the mask (255 - value) before use; with the default True, dark
    # mask areas mark the region that gets regenerated.
    invert_mask: bool = True
    # How many images to generate from this one request.
    count: int = 1
    # If set, the generated images are POSTed here as {"images": [...]}.
    webhook_url: Optional[str] = None
# Response shape for a finished job: one base64-encoded image string per result.
# NOTE(review): no endpoint in this file declares it as a response_model.
class InpaintingResponse(BaseModel):
    images: List[str]  # base64-encoded images, in generation order
# Function to resize images
def resize_image(input_image, resolution, multiple=64):
    """Resize so the shorter side is about ``resolution`` px, snapping sides to ``multiple``.

    Args:
        input_image: NumPy image array of shape (H, W) or (H, W, C).
        resolution: target length of the shorter side before snapping.
        multiple: each output side is rounded to the nearest multiple of this
            value (default 64, the original behavior), never going below one
            multiple.

    Returns:
        The array produced by ``cv2.resize``: LANCZOS4 interpolation when
        upscaling, AREA when downscaling.

    Raises:
        ValueError: if the input has a zero-length height or width.
    """
    # Only the first two axes matter, so 2-D (grayscale) arrays work too.
    height, width = input_image.shape[:2]
    if min(height, width) == 0:
        raise ValueError("input_image must have non-zero height and width")
    k = float(resolution) / min(height, width)
    # Round each scaled side to the nearest multiple; clamp so very small
    # targets never collapse to 0, which cv2.resize rejects.
    new_height = max(multiple, int(np.round(height * k / multiple)) * multiple)
    new_width = max(multiple, int(np.round(width * k / multiple)) * multiple)
    interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
    # cv2 takes the target size as (width, height).
    return cv2.resize(input_image, (new_width, new_height), interpolation=interpolation)
def _decode_base64_image(data):
    """Decode a base64 string into an RGB PIL image."""
    return Image.open(BytesIO(base64.b64decode(data))).convert("RGB")


def _encode_images_base64(images):
    """Encode each PIL image as PNG and return the base64 strings (UTF-8 text)."""
    encoded = []
    for img in images:
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        encoded.append(base64.b64encode(buffered.getvalue()).decode('utf-8'))
    return encoded


def _blend_with_original(images, original_image_np, mask_image):
    """Paste generated pixels over the original image through a blur-widened mask.

    A generated image whose size differs from ``original_image_np`` is resized
    to match first. The pipeline is not guaranteed to return the exact input
    size (diffusers' image processor typically snaps sizes to the VAE scale
    factor), and without this the element-wise blend below fails to broadcast.
    """
    height, width = original_image_np.shape[:2]
    mask_np = np.array(mask_image) / 255.0
    mask_blurred = cv2.GaussianBlur(mask_np, (21, 21), 0)
    # Union of the hard and blurred masks: full weight inside the mask,
    # feathered weight just outside it.
    mask_combined = 1 - (1 - mask_np) * (1 - mask_blurred)
    blended_images = []
    for img in images:
        img_np = np.array(img)
        if img_np.shape[:2] != (height, width):
            img_np = cv2.resize(img_np, (width, height), interpolation=cv2.INTER_LANCZOS4)
        image_pasted = original_image_np * (1 - mask_combined) + img_np * mask_combined
        blended_images.append(Image.fromarray(image_pasted.astype(np.uint8)))
    return blended_images


# Main processing function
def process_api(request: InpaintingRequest):
    """Run BrushNet inpainting for one request and return base64 PNG strings.

    The function:
      1. Decodes the image and shrinks it so the shorter side is 512 px when
         larger.
      2. Decodes the mask and builds the masked init image.
      3. Generates ``request.count`` images in one ``pipe`` call.
      4. Optionally blends them back onto the original.
      5. Encodes the results.

    Raises:
        ValueError: if no mask is supplied, ``count`` is below 1, or
            ``blended`` is requested with ``control_strength`` below 1.0.
    """
    # Validate cheap request properties first, so bad requests fail before any
    # decoding and, above all, before the expensive diffusion run.
    if not request.input_mask:
        raise ValueError("Input mask is required")
    if request.count < 1:
        raise ValueError("count must be at least 1")
    if request.blended and request.control_strength < 1.0:
        raise ValueError('Using blurred blending with control strength less than 1.0 is not allowed')

    original_image_np = np.array(_decode_base64_image(request.input_image))
    # Shrink large inputs so the shorter side is 512 px (sides snapped to multiples of 64).
    if min(original_image_np.shape[:2]) > 512:
        original_image_np = resize_image(original_image_np, 512)
    input_mask_np = np.array(_decode_base64_image(request.input_mask))

    seed = random.randint(0, 2147483647) if request.randomize_seed else request.seed
    generator = torch.Generator(DEVICE).manual_seed(seed)

    # Bring the mask to the working image size, optionally inverting it.
    H, W = original_image_np.shape[:2]
    original_mask = cv2.resize(input_mask_np, (W, H))
    if request.invert_mask:
        original_mask = 255 - original_mask
    # Binary (H, W, 1) mask: 1 where the summed RGB channels exceed 255.
    mask = 1.0 * (original_mask.sum(-1) > 255)[:, :, np.newaxis]
    # Blank out the region to be regenerated in the conditioning image.
    masked_image = original_image_np * (1 - mask)
    init_image = Image.fromarray(masked_image.astype(np.uint8)).convert("RGB")
    # NOTE(review): the pipeline gets the raw (resized/inverted) mask, not the
    # binarized `mask` used for init_image; they differ for soft-edged masks.
    mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("RGB")

    # A missing negative prompt must reach the pipeline as None, not [None] * count.
    negative_prompt = (
        [request.negative_prompt] * request.count
        if request.negative_prompt is not None
        else None
    )
    images = pipe(
        [request.prompt] * request.count,
        init_image,
        mask_image,
        num_inference_steps=request.num_inference_steps,
        guidance_scale=request.guidance_scale,
        generator=generator,
        brushnet_conditioning_scale=float(request.control_strength),
        negative_prompt=negative_prompt,
    ).images

    if request.blended:
        images = _blend_with_original(images, original_image_np, mask_image)

    return _encode_images_base64(images)
# Seconds to wait for a webhook receiver before giving up on delivery.
WEBHOOK_TIMEOUT_SECONDS = 30


def _notify_webhook(webhook_url, output_images):
    """POST the generated images to ``webhook_url`` as ``{"images": [...]}``.

    Delivery is best-effort: connection errors, timeouts and non-2xx responses
    are logged and swallowed. A flaky receiver therefore does not turn a
    successful generation into a failed task; the images are still the task's
    return value and can be fetched via /result.
    """
    # NOTE(review): webhook_url comes straight from the client request, so the
    # worker will POST to any address it can reach (SSRF risk). Consider an
    # allow-list if this API is exposed to untrusted callers.
    try:
        response = requests.post(
            webhook_url,
            json={'images': output_images},
            timeout=WEBHOOK_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
    except requests.RequestException as e:
        print(f"Webhook delivery to {webhook_url} failed: {e}")


# Celery task function
@celery_app.task(bind=True)
def generate_image_task(self, request_data):
    """Celery entry point: run inpainting for a serialized InpaintingRequest.

    Args:
        request_data: dict form of an InpaintingRequest (as produced by /inpaint).

    Returns:
        List of base64-encoded PNG strings, stored as the task result.

    Raises:
        Any exception from request validation or generation is logged and
        re-raised, so Celery marks the task FAILURE.
    """
    try:
        request = InpaintingRequest(**request_data)
        output_images = process_api(request)
    except Exception as e:
        print(f"Exception in generate_image_task: {e}")
        # Re-raise so Celery records the failure and its exception.
        raise
    if request.webhook_url:
        _notify_webhook(request.webhook_url, output_images)
    return output_images
# API endpoint to start the image generation task
@app.post("/inpaint")
def inpaint(request: InpaintingRequest):
    # Serialize the validated request, enqueue it on Celery, and hand back the
    # task id. Any exception raised while doing so is reported as an HTTP 500.
    try:
        payload = request.dict()
        queued = generate_image_task.delay(payload)
        response_body = {"job_id": queued.id}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return response_body
# API endpoint to check task status
@app.get("/status/{job_id}")
def get_status(job_id: str):
    """Report the Celery state of ``job_id``; failures also include the error text.

    ``task.state`` is read once. For tasks that have not finished, each access
    may query the result backend again, so repeated reads cost extra
    round-trips and can observe different states within a single request.
    Per Celery's documentation, unknown ids are reported as PENDING.
    """
    task = AsyncResult(job_id, app=celery_app)
    state = task.state
    if state == states.FAILURE:
        # For FAILURE, Celery exposes the raised exception via task.info.
        return {"status": "FAILURE", "error": str(task.info)}
    # PENDING, STARTED, SUCCESS and any other state are echoed verbatim.
    return {"status": str(state)}
# API endpoint to retrieve results
@app.get("/result/{job_id}")
def get_result(job_id: str):
    """Return the generated images for a finished job.

    Responds 200 with ``{"images": [...]}`` on success, 500 with the error text
    on failure, and 202 while the job is not yet finished. ``task.state`` is
    read once, so the branches see a single consistent snapshot without
    repeated backend queries.
    """
    task = AsyncResult(job_id, app=celery_app)
    state = task.state
    if state == states.SUCCESS:
        return {"images": task.result}
    if state == states.FAILURE:
        raise HTTPException(status_code=500, detail=str(task.info))
    # Not an error: 202 means "accepted, still processing". FastAPI renders
    # this as {"detail": "Task not completed yet."}.
    raise HTTPException(status_code=202, detail="Task not completed yet.")
# Run the app
if __name__ == "__main__":
    # Served through the "app_api:app" import string, so (assuming this file is
    # app_api.py, per its header) uvicorn imports it again as module `app_api`,
    # and the module-level model loading runs a second time.
    # NOTE(review): passing the `app` object would avoid the double load, but
    # the endpoints would then enqueue tasks registered under `__main__`, which
    # a worker importing `app_api` may not recognize. Check Celery task naming
    # before changing this.
    from uvicorn import run

    run("app_api:app", host="0.0.0.0", port=8000)
|