|
import os |
|
import torch |
|
from PIL import Image |
|
from diffusers import FluxControlNetModel |
|
from diffusers.pipelines import FluxControlNetPipeline |
|
from io import BytesIO |
|
import logging |
|
|
|
class EndpointHandler:
    """Inference endpoint wrapping a FLUX.1-dev ControlNet upscaling pipeline.

    Loads ControlNet weights from ``model_dir`` plus the FLUX.1-dev base
    pipeline (gated — requires the ``HF_TOKEN`` env var), then serves
    ``(prompt, control_image) -> PNG buffer`` via :meth:`inference`.
    """

    def __init__(self, model_dir: str = "huyai123/Flux.1-dev-Image-Upscaler"):
        """Load both models and apply memory-saving settings.

        Raises:
            ValueError: if the ``HF_TOKEN`` environment variable is unset.
        """
        # Reduce CUDA memory fragmentation when loading the large FLUX weights.
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

        hf_token = os.getenv('HF_TOKEN')
        if not hf_token:
            raise ValueError("HF_TOKEN environment variable is not set")

        logging.basicConfig(level=logging.INFO)
        logging.info("Using HF_TOKEN")

        # Guard: empty_cache() raises on hosts without a CUDA runtime.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # `token` replaces the deprecated `use_auth_token` kwarg.
        self.controlnet = FluxControlNetModel.from_pretrained(
            model_dir, torch_dtype=torch.float16, token=hf_token
        )
        self.pipe = FluxControlNetPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
            token=hf_token,
        )
        # BUG FIX: do NOT call self.pipe.to("cuda") here. Sequential CPU
        # offload manages device placement itself and expects the pipeline
        # to start on the CPU; moving it to CUDA first is an error in
        # recent diffusers versions.
        self.pipe.enable_attention_slicing("auto")
        self.pipe.enable_sequential_cpu_offload()
        try:
            # BUG FIX: the original called the nonexistent
            # `enable_memory_efficient_attention()` (AttributeError); the
            # real diffusers method is the xformers-prefixed one.
            self.pipe.enable_xformers_memory_efficient_attention()
        except Exception as exc:
            # xformers is an optional dependency — degrade gracefully.
            logging.info("xformers unavailable, continuing without it: %s", exc)

    def preprocess(self, data):
        """Extract the control image from the request payload.

        Args:
            data: request dict; must contain a ``"control_image"`` entry
                (a file path or file-like object accepted by ``Image.open``).

        Returns:
            A 512x512 RGB ``PIL.Image``.

        Raises:
            ValueError: if ``control_image`` is missing or falsy.
        """
        image_file = data.get("control_image", None)
        if not image_file:
            raise ValueError("Missing control_image in input.")
        image = Image.open(image_file)
        # Convert so RGBA/grayscale uploads don't break the fp16 pipeline.
        return image.convert("RGB").resize((512, 512))

    def postprocess(self, output):
        """Serialize a PIL image to an in-memory PNG buffer, rewound to 0."""
        buffer = BytesIO()
        output.save(buffer, format="PNG")
        buffer.seek(0)
        return buffer

    def inference(self, data):
        """Run the full preprocess -> pipeline -> PNG-encode cycle.

        Args:
            data: dict with ``"control_image"`` (required) and an optional
                ``"prompt"`` string (defaults to "").

        Returns:
            A ``BytesIO`` buffer holding the generated PNG.
        """
        control_image = self.preprocess(data)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        output_image = self.pipe(
            prompt=data.get("prompt", ""),
            control_image=control_image,
            controlnet_conditioning_scale=0.5,
            num_inference_steps=10,
            # Generate at the (post-resize) control image resolution.
            height=control_image.size[1],
            width=control_image.size[0],
        ).images[0]
        return self.postprocess(output_image)
|
|
|
if __name__ == "__main__":
    # Smoke test: push one local image through the handler end to end.
    request = {
        'control_image': 'path/to/your/image.png',
        'prompt': 'Your prompt here',
    }
    endpoint = EndpointHandler()
    result = endpoint.inference(request)
    print(result)
|
|