import os
import torch
from PIL import Image
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from io import BytesIO
import logging

class EndpointHandler:
    def __init__(self, model_dir="huyai123/Flux.1-dev-Image-Upscaler"):
        # Mitigate CUDA memory fragmentation (max_split_size_mb caps splittable block size)
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

        # Hugging Face access token (needed for the gated FLUX.1-dev weights)
        HF_TOKEN = os.getenv("HF_TOKEN")
        if not HF_TOKEN:
            raise ValueError("HF_TOKEN environment variable is not set")

        logging.basicConfig(level=logging.INFO)
        logging.info("Using HF_TOKEN")

        # Clear GPU memory
        torch.cuda.empty_cache()

        # Load the ControlNet upscaler weights and the base FLUX.1-dev pipeline
        self.controlnet = FluxControlNetModel.from_pretrained(
            model_dir, torch_dtype=torch.float16, token=HF_TOKEN
        )
        self.pipe = FluxControlNetPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
            token=HF_TOKEN,
        )
        self.pipe.enable_attention_slicing("auto")
        # Sequential CPU offload handles device placement, so .to("cuda") is omitted.
        self.pipe.enable_sequential_cpu_offload()
        try:
            self.pipe.enable_xformers_memory_efficient_attention()  # needs xformers
        except Exception:
            logging.info("xformers not available; skipping memory-efficient attention")

    def preprocess(self, data):
        image_file = data.get("control_image", None)
        if not image_file:
            raise ValueError("Missing control_image in input.")
        image = Image.open(image_file).convert("RGB")
        return image.resize((512, 512))  # Downscale the control image to limit memory usage

    def postprocess(self, output):
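        # Encode the output PIL image as PNG bytes in an in-memory buffer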
        buffer = BytesIO()
        output.save(buffer, format="PNG")
        buffer.seek(0)
        return buffer

    def inference(self, data):
        control_image = self.preprocess(data)
        torch.cuda.empty_cache()
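        # ControlNet-guided upscaling pass: a moderate conditioning scale (0.5) and
        # only 10 inference steps keep memory and latency low at some cost in detail.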
        output_image = self.pipe(
            prompt=data.get("prompt", ""),
            control_image=control_image,
            controlnet_conditioning_scale=0.5,
            num_inference_steps=10,
            height=control_image.size[1],
            width=control_image.size[0],
        ).images[0]
        return self.postprocess(output_image)
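
    # If this module is deployed as a custom handler on Hugging Face Inference
    # Endpoints, the toolkit invokes the handler like a function; this wrapper
    # (an assumption about that interface) simply delegates to inference().
    def __call__(self, data):
        return self.inference(data)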

if __name__ == "__main__":
    data = {"control_image": "path/to/your/image.png", "prompt": "Your prompt here"}
    handler = EndpointHandler()
    output = handler.inference(data)
    # Write the returned PNG buffer to disk (example output path)
    with open("upscaled.png", "wb") as f:
        f.write(output.getvalue())