File size: 2,593 Bytes
71b16c4
 
 
 
 
 
 
0ab38e4
 
 
 
 
 
 
 
 
 
 
 
 
71b16c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ab38e4
 
 
 
 
71b16c4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import base64
import io
import logging
from typing import Any, Dict

import torch
from diffusers import AutoPipelineForImage2Image
from PIL import Image

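# Pinned runtime dependencies:
# torch==2.5.1+cu124
# torchvision==0.18.0+cu124
# torchaudio==2.5.1+cu124
# diffusers==0.17.0.dev0
# Pillow==10.0.0
# fastapi
# pydantic
# uvicorn
# NOTE(review): torchvision 0.18.x was released for torch 2.3 (torch 2.5.x pairs
# with torchvision 0.20.x), and AutoPipelineForImage2Image is not available in
# diffusers 0.17 — confirm these pins match what is actually installed.

# Alternative pins without the +cu124 suffix:
# torchvision==0.18.0
# torchaudio==2.5.1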


class EndpointHandler:
    """Serve image-to-image generation requests with a diffusers pipeline."""

    # Hub repository the pipeline weights are loaded from.
    MODEL_ID = "cjwalch/kandinsky-endpoint"

    def __init__(self, path=""):
        """Load the image-to-image pipeline.

        Args:
            path: Model directory supplied by the serving runtime. Currently
                unused: weights are always loaded from ``MODEL_ID``. The
                parameter is kept for interface compatibility.
        """
        use_cuda = torch.cuda.is_available()
        self.pipeline = AutoPipelineForImage2Image.from_pretrained(
            self.MODEL_ID,
            # Half precision only makes sense on GPU; on CPU many fp16 ops are
            # unsupported or extremely slow, so fall back to fp32 there.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            use_safetensors=True,
        )
        if use_cuda:
            # Model CPU offload moves each sub-model to the GPU only while it
            # runs. Calling .to("cuda") as well would defeat the offloading
            # (diffusers warns/raises about that combination), so it is the
            # only placement step. enable_model_cpu_offload also needs a GPU,
            # hence the guard.
            self.pipeline.enable_model_cpu_offload()

    @staticmethod
    def _decode_image(b64_data) -> "Image.Image":
        """Decode a base64 string/bytes (optionally a data URI) to an RGB image."""
        if isinstance(b64_data, str) and b64_data.startswith("data:"):
            # Browser clients often send "data:<mime>;base64,<payload>".
            # Keep only the payload after the first comma.
            b64_data = b64_data.partition(",")[2]
        image_bytes = base64.b64decode(b64_data)
        # convert() returns a fully loaded copy, so the source handle can be closed.
        with Image.open(io.BytesIO(image_bytes)) as img:
            return img.convert("RGB")

    @staticmethod
    def _encode_image(image: "Image.Image") -> str:
        """Serialize a PIL image to a base64-encoded PNG string."""
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run image-to-image generation for a single request.

        Args:
            data: Request payload. Keys read:
                ``inputs``: text prompt (default ``""``).
                ``init_image``: base64-encoded source image, optionally as a
                    ``data:<mime>;base64,`` URI (required).
                ``strength``: float in (0, 1] (default 0.6).
                ``guidance_scale``: float (default 7.0).
                ``negative_prompt``: str (default
                    ``"blurry, ugly, deformed"``).

        Returns:
            ``{"generated_image": <base64 PNG>}`` on success, otherwise
            ``{"error": <message>}``. Exceptions are not propagated to the
            caller.
        """
        try:
            prompt = data.get("inputs", "")
            strength = float(data.get("strength", 0.6))
            guidance_scale = float(data.get("guidance_scale", 7.0))
            negative_prompt = data.get("negative_prompt", "blurry, ugly, deformed")

            init_image_b64 = data.get("init_image")
            if not init_image_b64:
                return {"error": "Missing 'init_image' in input data"}
            # Written as `not (lo < x <= hi)` so NaN is rejected too. Strength 0
            # would leave the scheduler with no denoising steps.
            if not 0.0 < strength <= 1.0:
                return {"error": f"'strength' must be in (0, 1], got {strength}"}

            init_image = self._decode_image(init_image_b64)
            output_image = self.pipeline(
                prompt=prompt,
                image=init_image,
                strength=strength,
                guidance_scale=guidance_scale,
                negative_prompt=negative_prompt,
            ).images[0]
            return {"generated_image": self._encode_image(output_image)}

        except Exception as e:
            # Endpoint boundary: report the failure to the client as data, but
            # keep the full traceback in the server logs for debugging.
            logging.getLogger(__name__).exception("Image-to-image inference failed")
            return {"error": str(e)}

        finally:
            # Release cached CUDA allocator blocks after every request,
            # including failed ones. This is a no-op when CUDA is not initialized.
            torch.cuda.empty_cache()