# app_api.py
import os
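# Fall back to the CPU for any operator not implemented on Apple's MPS backend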
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import base64
from io import BytesIO
import random
import numpy as np
import cv2
from PIL import Image
import torch
from segment_anything import SamPredictor, sam_model_registry
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, List
from celery import states
from celery_app import celery_app
from celery.result import AsyncResult
import requests

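# Limit PyTorch to a single intra-op CPU thread; this is a common way to avoid
# thread oversubscription when several worker processes share one machine.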
torch.set_num_threads(1)

# Initialize FastAPI
app = FastAPI()

# Pick the best available device: prefer MPS, then CUDA, then CPU
def get_device():
    if torch.backends.mps.is_available():
        return "mps"
    elif torch.cuda.is_available():
        return "cuda"
    else:
        return "cpu"

# Device configuration
DEVICE = get_device()

# Load models once at module import time. Note that despite the 'mobile_' variable
# names, this is the full SAM ViT-H checkpoint.
mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to(DEVICE)
mobile_sam.eval()
mobile_predictor = SamPredictor(mobile_sam)
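# Note: process_api below requires a client-supplied mask, so this SAM predictor is
# currently unused by the /inpaint path.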

# Load BrushNet and Stable Diffusion pipeline
base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE"  # Update to your base model path
brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt"  # Update to your BrushNet checkpoint path

# Device placement is handled by enable_model_cpu_offload below
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path,
    brushnet=brushnet,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,
)

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
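# enable_model_cpu_offload keeps submodules on the CPU and moves each one to DEVICE
# only while it is running, trading some speed for a lower peak memory footprint.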
pipe.enable_model_cpu_offload(device=DEVICE)

# Pydantic models for request and response validation
class InpaintingRequest(BaseModel):
    input_image: str  # Base64-encoded image
    input_mask: Optional[str] = None  # Base64-encoded mask image
    prompt: str
    negative_prompt: Optional[str] = 'ugly, low quality'
    control_strength: float = 1.0
    guidance_scale: float = 12.0
    num_inference_steps: int = 50
    seed: int = 0  # Used when randomize_seed is False
    randomize_seed: bool = False
    blended: bool = False
    invert_mask: bool = True
    count: int = 1
    webhook_url: Optional[str] = None  # URL to send the webhook to

class InpaintingResponse(BaseModel):
    images: List[str]  # Base64-encoded images

# Resize so the shortest side matches `resolution`, rounding both dimensions to multiples of 64
def resize_image(input_image, resolution):
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    H = int(np.round(H / 64.0)) * 64
    W = int(np.round(W / 64.0)) * 64
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img

# Core inpainting routine: decode the request, prepare the mask and init images,
# run the pipeline, and return base64-encoded PNGs
def process_api(request: InpaintingRequest):
    global mobile_predictor, pipe, DEVICE

    # Decode the base64-encoded input image
    input_image_data = base64.b64decode(request.input_image)
    original_image = Image.open(BytesIO(input_image_data)).convert("RGB")
    original_image_np = np.array(original_image)

    # Resize the original image if its smallest dimension is greater than 512 pixels
    if min(original_image_np.shape[:2]) > 512:
        original_image_np = resize_image(original_image_np, 512)

    # A mask is required by this endpoint even though the schema marks it optional
    if not request.input_mask:
        raise ValueError("Input mask is required")
    input_mask_data = base64.b64decode(request.input_mask)
    input_mask_image = Image.open(BytesIO(input_mask_data)).convert("RGB")
    input_mask_np = np.array(input_mask_image)

    # Prepare other parameters
    if request.randomize_seed:
        seed = random.randint(0, 2147483647)
    else:
        seed = request.seed  # This will always be an int due to the default value

    generator = torch.Generator(DEVICE).manual_seed(seed)

    # Prepare mask and init images
    H, W = original_image_np.shape[:2]
    original_mask = cv2.resize(input_mask_np, (W, H))

    if request.invert_mask:
        original_mask = 255 - original_mask

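    # Convention: mask == 1 marks the pixels to repaint; the init image has those pixels zeroed out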
    mask = 1.0 * (original_mask.sum(-1) > 255)[:, :, np.newaxis]
    masked_image = original_image_np * (1 - mask)

    init_image = Image.fromarray(masked_image.astype(np.uint8)).convert("RGB")
    mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("RGB")

    # Generate images
    images = pipe(
        [request.prompt] * request.count,
        init_image,
        mask_image,
        num_inference_steps=request.num_inference_steps,
        guidance_scale=request.guidance_scale,
        generator=generator,
        brushnet_conditioning_scale=float(request.control_strength),
        negative_prompt=[request.negative_prompt] * request.count,
    ).images

    # Optionally blend the generated region back into the original image using a blurred mask to soften seams
    if request.blended:
        if request.control_strength < 1.0:
            raise ValueError('Using blurred blending with control strength less than 1.0 is not allowed')
        blended_images = []
        mask_np = np.array(mask_image) / 255.0
        mask_blurred = cv2.GaussianBlur(mask_np, (21, 21), 0)
        mask_combined = 1 - (1 - mask_np) * (1 - mask_blurred)
        for img in images:
            img_np = np.array(img)
            image_pasted = original_image_np * (1 - mask_combined) + img_np * mask_combined
            image_pasted = image_pasted.astype(np.uint8)
            blended_images.append(Image.fromarray(image_pasted))
        images = blended_images

    # Encode images to base64
    output_images = []
    for img in images:
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
        output_images.append(img_str)

    return output_images

# Celery task function
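# The task receives a plain dict (so Celery can serialize it through the broker)
# and re-validates it into an InpaintingRequest before processing.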
@celery_app.task(bind=True)
def generate_image_task(self, request_data):
    try:
        # Convert request_data to InpaintingRequest object
        request = InpaintingRequest(**request_data)
        # Call the processing function
        output_images = process_api(request)
        # If webhook URL is provided, send the images to the webhook
        webhook_url = request.webhook_url
        if webhook_url:
            payload = {'images': output_images}
            requests.post(webhook_url, json=payload)
        return output_images
    except Exception as e:
        # Optionally, log the error
        print(f"Exception in generate_image_task: {e}")
        # Re-raise the exception to let Celery handle it
        raise

# API endpoint to start the image generation task
@app.post("/inpaint")
def inpaint(request: InpaintingRequest):
    try:
        task = generate_image_task.delay(request.dict())
        return {"job_id": task.id}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# API endpoint to check task status
@app.get("/status/{job_id}")
def get_status(job_id: str):
    task = AsyncResult(job_id, app=celery_app)
    if task.state == states.PENDING:
        return {"status": "PENDING"}
    elif task.state == states.STARTED:
        return {"status": "STARTED"}
    elif task.state == states.SUCCESS:
        return {"status": "SUCCESS"}
    elif task.state == states.FAILURE:
        # Extract the error message
        error_info = task.info
        error_message = str(error_info)
        return {"status": "FAILURE", "error": error_message}
    else:
        return {"status": str(task.state)}

# API endpoint to retrieve results
@app.get("/result/{job_id}")
def get_result(job_id: str):
    task = AsyncResult(job_id, app=celery_app)
    if task.state == states.SUCCESS:
        return {"images": task.result}
    elif task.state == states.FAILURE:
        raise HTTPException(status_code=500, detail=str(task.info))
    else:
        raise HTTPException(status_code=202, detail="Task not completed yet.")

# Run the app
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app_api:app", host="0.0.0.0", port=8000)
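
# Example client flow (illustrative sketch; assumes the API is reachable at
# http://localhost:8000 and that a Celery worker is consuming the queue, e.g.
# started with `celery -A celery_app worker --loglevel=info`):
#
#   import base64, time, requests
#
#   with open("input.png", "rb") as f:
#       image_b64 = base64.b64encode(f.read()).decode()
#   with open("mask.png", "rb") as f:
#       mask_b64 = base64.b64encode(f.read()).decode()
#
#   job = requests.post("http://localhost:8000/inpaint", json={
#       "input_image": image_b64,
#       "input_mask": mask_b64,
#       "prompt": "a vase of flowers on a wooden table",
#   }).json()
#
#   status = "PENDING"
#   while status not in ("SUCCESS", "FAILURE"):
#       time.sleep(2)
#       status = requests.get(f"http://localhost:8000/status/{job['job_id']}").json()["status"]
#
#   if status == "SUCCESS":
#       images = requests.get(f"http://localhost:8000/result/{job['job_id']}").json()["images"]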