HairSwapModel / app.py
miguelmuzo's picture
Update app.py
3d5b2b1 verified
raw
history blame
12.7 kB
import hashlib
import os
from io import BytesIO
import base64
import gradio as gr
from PIL import Image
from cachetools import LRUCache
import torch
import numpy as np
# HairFast is imported in-process (no gRPC service). Record whether the import
# worked so the rest of the app can report a clean error instead of crashing.
try:
    from hair_swap import HairFast, get_parser
except ImportError as e:
    print(f"❌ HairFast import failed: {e}")
    HAIRFAST_AVAILABLE = False
else:
    HAIRFAST_AVAILABLE = True
    print("✅ HairFast successfully imported!")

from utils.shape_predictor import align_face

# Module-level state shared by the handlers below.
hair_fast_model = None  # set by initialize_hairfast() on first use
align_cache = LRUCache(maxsize=10)  # aligned images, keyed by MD5 of the upload
def initialize_hairfast():
    """Build the global HairFast model from the parser's default arguments.

    Returns:
        True when the model was created and stored in ``hair_fast_model``;
        False when HairFast could not be imported or construction raised.
    """
    global hair_fast_model

    if not HAIRFAST_AVAILABLE:
        print("❌ HairFast not available")
        return False

    try:
        print("🔄 Initializing HairFast model...")
        # An empty argv makes the parser fall back to all of its defaults.
        args = get_parser().parse_args([])
        # Overrides for HF Spaces: pick the GPU when present, keep batches small.
        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        args.batch_size = 1
        hair_fast_model = HairFast(args)
        print(f"✅ HairFast initialized successfully on {args.device}!")
    except Exception as e:
        # Leave the global unset so a later call can retry initialisation.
        print(f"❌ HairFast initialization failed: {e}")
        hair_fast_model = None
        return False
    return True
def get_bytes(img):
    """Encode a PIL image as JPEG bytes.

    Args:
        img: A PIL image, or None.

    Returns:
        The JPEG-encoded bytes, or None when ``img`` is None.
    """
    if img is None:
        return None
    # JPEG has no alpha or palette support: saving e.g. an RGBA, LA or P image
    # raises OSError. Convert only those modes so that images which already
    # encoded successfully keep producing the same bytes.
    if img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img = img.convert("RGB")
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()
def bytes_to_image(image: bytes) -> Image.Image:
    """Open encoded image bytes as a PIL Image."""
    return Image.open(BytesIO(image))
def base64_to_image(base64_string):
    """Decode a base64 string (optionally a ``data:image`` URI) into a PIL Image.

    Args:
        base64_string: Raw base64 text, or a data URI whose payload follows
            the first comma.

    Returns:
        The decoded PIL image, or None when the input is not a string or
        cannot be decoded as an image (the error is printed).
    """
    try:
        if base64_string.startswith('data:image'):
            # Drop the "data:image/...;base64," header; split once so the
            # payload is taken whole.
            base64_string = base64_string.split(',', 1)[1]
        image_bytes = base64.b64decode(base64_string)
        image = Image.open(BytesIO(image_bytes))
        # Image.open only parses the header. Force the pixel data to decode
        # here so truncated/corrupt payloads are reported by this function
        # instead of failing later inside the model call.
        image.load()
        return image
    except Exception as e:
        print(f"Error converting base64 to image: {e}")
        return None
def image_to_base64(image):
    """Encode a PIL image as a ``data:image/jpeg;base64,...`` URI.

    Args:
        image: A PIL image, or None.

    Returns:
        The data-URI string, or None when ``image`` is None.
    """
    if image is None:
        return None
    # JPEG cannot store alpha/palette data; saving such modes raises OSError.
    # Convert only the modes that would otherwise fail.
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    encoded = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return f"data:image/jpeg;base64,{encoded}"
def pil_to_tensor(image):
    """Convert a PIL image to a float32 channel-first tensor scaled to [0, 1].

    Args:
        image: A PIL image; any other object is returned unchanged.

    Returns:
        A ``torch.Tensor`` with layout (3, height, width) for PIL input,
        otherwise ``image`` itself.
    """
    if not isinstance(image, Image.Image):
        return image
    # Force three 8-bit channels: grayscale/palette images would give a 2-D
    # array that cannot be permuted to channel-first, and RGBA would add a
    # fourth channel.
    # NOTE(review): assumes the consumer wants exactly 3 channels — confirm.
    pixels = np.asarray(image.convert("RGB"), dtype=np.float32)
    # 8-bit data is always rescaled. The earlier "max() > 1" guard skipped the
    # division for near-black images whose values are all 0 or 1.
    pixels = pixels / 255.0
    return torch.from_numpy(pixels).permute(2, 0, 1).float()
def tensor_to_pil(tensor):
    """Convert an image tensor to a PIL Image.

    A 4-D tensor has its leading dimension squeezed (only effective when that
    dimension has size one), and a 3-D tensor is permuted from channel-first
    to channel-last. Data whose maximum is <= 1 is treated as normalised and
    multiplied by 255.

    Args:
        tensor: A ``torch.Tensor``; any other object is returned unchanged.

    Returns:
        A PIL image for tensor input, otherwise ``tensor`` itself.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor

    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    if tensor.dim() == 3:
        tensor = tensor.permute(1, 2, 0)

    pixels = tensor.detach().cpu().numpy()
    if pixels.max() <= 1:
        pixels = pixels * 255
    # Clamp, then always cast: out-of-range values (negatives, slight
    # overshoot) would wrap around in uint8, and a float array left uncast
    # is rejected by Image.fromarray.
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)
    # Image.fromarray wants HxW for single-channel data, not HxWx1.
    if pixels.ndim == 3 and pixels.shape[2] == 1:
        pixels = pixels[:, :, 0]
    return Image.fromarray(pixels)
def center_crop(img):
    """Crop the largest centred square out of a PIL image.

    Args:
        img: A PIL image.

    Returns:
        A new square image whose side is ``min(width, height)``.
    """
    width, height = img.size
    side = min(width, height)
    # Use integer offsets and derive the far edges from them. With float
    # halves (e.g. 0.5 and 3.5 for a 4x3 image) Pillow rounds each coordinate
    # independently, which can yield a box one pixel wider than `side` and
    # therefore a non-square crop.
    left = (width - side) // 2
    top = (height - side) // 2
    return img.crop((left, top, left + side, top + side))
def resize(name):
    """Create an upload handler that normalises an image for the model.

    Args:
        name: Label of the image slot this handler serves. Face alignment is
            attempted only when this label is among the ``align`` options
            passed to the handler.

    Returns:
        A function ``resize_inner(img, align)`` that returns the aligned image
        (cached by MD5 of its JPEG bytes), or a centre-cropped 1024x1024
        image when alignment is not requested or fails.
    """
    target_size = (1024, 1024)

    def _crop_and_scale(img):
        # Square centre crop followed by a Lanczos resize to the target size.
        return center_crop(img).resize(target_size, Image.Resampling.LANCZOS)

    def resize_inner(img, align):
        # Nothing to process (e.g. the image was cleared); hashing None would
        # raise a TypeError.
        if img is None:
            return None

        # Treat a missing selection as "no alignment requested".
        if name in (align or ()):
            img_hash = hashlib.md5(get_bytes(img)).hexdigest()
            cached = align_cache.get(img_hash)
            if cached is not None:
                return cached
            try:
                aligned = align_face(img, return_tensors=False)[0]
            except Exception as e:
                # Best effort: fall back to a plain crop when alignment fails.
                print(f"Face alignment failed: {e}")
                return _crop_and_scale(img)
            align_cache[img_hash] = aligned
            return aligned

        if img.size != target_size:
            img = _crop_and_scale(img)
        return img

    return resize_inner
def swap_hair_direct(face, shape, color, blending, poisson_iters, poisson_erosion):
    """Run HairFast in-process and return Gradio updates for the UI.

    Args:
        face: Image of the person whose hair is replaced, or None.
        shape: Image supplying the hair shape, or None.
        color: Image supplying the hair colour, or None.
        blending: Value of the advanced "Color Encoder Version" option.
        poisson_iters: Value of the "Poisson Iterations" option.
        poisson_erosion: Value of the "Poisson Erosion" option.
            NOTE(review): the three advanced options are accepted to match the
            click handler's inputs but are not forwarded to the model here.

    Returns:
        A pair ``(image_update, message_update)``. On success the image is
        shown and the message hidden; on any failure the image is hidden and
        the message shown.
    """
    def _failure(message):
        # Hide the result image and surface the message instead.
        return gr.update(visible=False), gr.update(value=message, visible=True)

    # Validate the cheap things first so an incomplete request never triggers
    # the expensive model load.
    if face is None and shape is None and color is None:
        return _failure("Need to upload a face and at least a shape or color ❗")
    if face is None:
        return _failure("Need to upload a face ❗")
    if shape is None and color is None:
        return _failure("Need to upload at least a shape or color ❗")

    # Lazily create the model on first use.
    if hair_fast_model is None and not initialize_hairfast():
        return _failure(
            "❌ HairFast model not available. Please check if all model files are uploaded."
        )

    try:
        print("🔄 Starting hair transfer...")
        # A single reference image serves as both shape and colour.
        if color is None:
            color = shape
        if shape is None:
            shape = color

        result = hair_fast_model.swap(
            face_img=face,
            shape_img=shape,
            color_img=color,
            benchmark=False,
            align=True,
            seed=3407
        )
        # NOTE(review): with align=True, swap() may return the final image
        # followed by the aligned inputs; keep only the first element in that
        # case — confirm against the installed HairFast version.
        if isinstance(result, (tuple, list)):
            result = result[0]

        result_image = tensor_to_pil(result)
        print("✅ Hair transfer completed successfully!")
        return gr.update(value=result_image, visible=True), gr.update(visible=False)
    except Exception as e:
        error_msg = f"❌ Hair transfer failed: {str(e)}"
        print(error_msg)
        return _failure(error_msg)
def hair_transfer_api(source_image, shape_image=None, color_image=None,
                      blending="Article", poisson_iters=0, poisson_erosion=15):
    """Programmatic hair-transfer entry point (base64 in, base64 out).

    Args:
        source_image: Image of the person, as a PIL image or a base64 /
            data-URI string.
        shape_image: Optional hair-shape reference, same accepted forms.
        color_image: Optional hair-colour reference, same accepted forms.
        blending: Accepted for interface compatibility.
        poisson_iters: Accepted for interface compatibility.
        poisson_erosion: Accepted for interface compatibility.
            NOTE(review): these three options are not forwarded to the model.

    Returns:
        ``(result, message)`` where ``result`` is a JPEG data-URI string on
        success and None on failure, and ``message`` describes the outcome.
    """
    try:
        # String inputs are base64; a string that fails to decode becomes None.
        if isinstance(source_image, str):
            source_image = base64_to_image(source_image)
        if isinstance(shape_image, str):
            shape_image = base64_to_image(shape_image)
        if isinstance(color_image, str):
            color_image = base64_to_image(color_image)

        # Validate before touching the model so an incomplete request never
        # triggers the expensive model load.
        if source_image is None:
            return None, "❌ Source image is required"
        if shape_image is None and color_image is None:
            return None, "❌ At least shape or color reference image is required"

        # Lazily create the model on first use.
        if hair_fast_model is None and not initialize_hairfast():
            return None, "❌ HairFast model not available"

        # A single reference image serves as both shape and colour.
        if color_image is None:
            color_image = shape_image
        if shape_image is None:
            shape_image = color_image

        result = hair_fast_model.swap(
            face_img=source_image,
            shape_img=shape_image,
            color_img=color_image,
            benchmark=False,
            align=True,
            seed=3407
        )
        # NOTE(review): with align=True, swap() may return the final image
        # followed by the aligned inputs; keep only the first element in that
        # case — confirm against the installed HairFast version.
        if isinstance(result, (tuple, list)):
            result = result[0]

        result_base64 = image_to_base64(tensor_to_pil(result))
        return result_base64, "✅ Hair transfer completed successfully!"
    except Exception as e:
        error_msg = f"❌ API Error: {str(e)}"
        print(error_msg)
        return None, error_msg
# Static markup shown at the top and bottom of the demo page.
_DEMO_HEADER_HTML = """
<div style="text-align: center; padding: 20px;">
<h1>🎨 HairFastGAN - Direct Model Inference</h1>
<p>High-quality hair transfer without gRPC dependency</p>
</div>
"""

_DEMO_CITATION_MD = '''
### 📖 Citation
```bibtex
@article{nikolaev2024hairfastgan,
title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},
author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},
journal={arXiv preprint arXiv:2404.01094},
year={2024}
}
```
'''


def get_demo():
    """Assemble the Gradio Blocks UI and wire up its event handlers.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks(title="HairFastGAN Direct API", theme=gr.themes.Soft()) as demo:
        gr.HTML(_DEMO_HEADER_HTML)

        with gr.Row():
            # Input side: the three images, advanced options and the button.
            with gr.Column():
                gr.HTML("<h3>📤 Input Images</h3>")
                source = gr.Image(label="Source Photo (Person's Face)", type="pil")

                with gr.Row():
                    shape = gr.Image(label="Hair Shape Reference (Optional)", type="pil")
                    color = gr.Image(label="Hair Color Reference (Optional)", type="pil")

                with gr.Accordion("🔧 Advanced Options", open=False):
                    blending = gr.Radio(
                        ["Article", "Alternative_v1", "Alternative_v2"],
                        value='Article',
                        label="Color Encoder Version",
                    )
                    poisson_iters = gr.Slider(
                        0, 2500, value=0, step=1,
                        label="Poisson Iterations",
                        info="Detail recovery strength",
                    )
                    poisson_erosion = gr.Slider(
                        1, 100, value=15, step=1,
                        label="Poisson Erosion",
                        info="Blending smoothness",
                    )
                    align = gr.CheckboxGroup(
                        ["Face", "Shape", "Color"],
                        value=["Face", "Shape", "Color"],
                        label="Face Alignment [Recommended]",
                    )

                btn = gr.Button("🎨 Transfer Hair Style", variant="primary", size="lg")

            # Output side: the result image and a status box that starts hidden.
            with gr.Column():
                gr.HTML("<h3>📥 Result</h3>")
                output = gr.Image(label="Result Image", type="pil")
                error_message = gr.Textbox(
                    label="⚠️ Status",
                    visible=False,
                    elem_classes="error-message",
                )

        gr.HTML("<h3>💡 Examples</h3>")
        example_rows = [
            ["input/0.png", "input/1.png", "input/2.png"],
            ["input/6.png", "input/7.png", None],
            ["input/10.jpg", None, "input/11.jpg"],
        ]
        gr.Examples(examples=example_rows, inputs=[source, shape, color], outputs=output)

        # Each uploaded image is normalised by the handler built for its slot.
        for component, slot in ((source, 'Face'), (shape, 'Shape'), (color, 'Color')):
            component.upload(fn=resize(slot), inputs=[component, align], outputs=component)

        # The button is exposed under the "predict" API name for the React client.
        btn.click(
            fn=swap_hair_direct,
            inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
            outputs=[output, error_message],
            api_name="predict",
        )

        gr.Markdown(_DEMO_CITATION_MD)

    return demo
if __name__ == '__main__':
    # Start this process with a fresh alignment cache.
    align_cache = LRUCache(maxsize=10)

    demo = get_demo()

    # Listen on all interfaces at port 7860 with the API page enabled and no
    # public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_api": True,
        "share": False,
    }
    demo.launch(**launch_options)