"""Gradio front-end for HairFastGAN that runs the model in-process (no gRPC)."""
import hashlib
import os
from io import BytesIO
import base64

import gradio as gr
from PIL import Image
from cachetools import LRUCache
import torch
import numpy as np

# Direct HairFast imports (no gRPC needed!)
try:
    from hair_swap import HairFast, get_parser
    HAIRFAST_AVAILABLE = True
    print("✅ HairFast successfully imported!")
except ImportError as e:
    print(f"❌ HairFast import failed: {e}")
    HAIRFAST_AVAILABLE = False

# Guarded like the HairFast import so a missing module degrades to the
# center-crop fallback in resize() instead of crashing the app at import time.
try:
    from utils.shape_predictor import align_face
except ImportError as e:
    print(f"❌ Face alignment import failed: {e}")
    align_face = None

# Global variables
hair_fast_model = None  # created lazily by initialize_hairfast()
align_cache = LRUCache(maxsize=10)  # md5 of JPEG bytes -> aligned PIL image

# Image modes Pillow can write as JPEG directly; anything else is converted to RGB.
_JPEG_MODES = frozenset({"1", "L", "RGB", "CMYK"})

_CITATION_MD = '''
### 📖 Citation

```bibtex
@article{nikolaev2024hairfastgan,
  title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},
  author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},
  journal={arXiv preprint arXiv:2404.01094},
  year={2024}
}
```
'''


def initialize_hairfast():
    """Create the global HairFast model.

    Returns:
        True on success, False if HairFast is unavailable or construction failed.
    """
    global hair_fast_model

    if not HAIRFAST_AVAILABLE:
        print("❌ HairFast not available")
        return False

    try:
        print("🔄 Initializing HairFast model...")
        # Start from HairFast's default CLI arguments, then override for HF Spaces.
        args = get_parser().parse_args([])
        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        args.batch_size = 1  # Keep small for HF Spaces

        hair_fast_model = HairFast(args)
        print(f"✅ HairFast initialized successfully on {args.device}!")
        return True
    except Exception as e:
        print(f"❌ HairFast initialization failed: {e}")
        hair_fast_model = None
        return False


def _encode_jpeg(img):
    """Encode a PIL image as JPEG bytes, converting modes JPEG cannot store (e.g. RGBA, P)."""
    if img.mode not in _JPEG_MODES:
        img = img.convert("RGB")
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()


def get_bytes(img):
    """Convert a PIL Image to JPEG bytes; ``None`` is returned unchanged."""
    if img is None:
        return None
    return _encode_jpeg(img)


def bytes_to_image(image: bytes) -> Image.Image:
    """Convert encoded image bytes to a PIL Image."""
    return Image.open(BytesIO(image))


def base64_to_image(base64_string):
    """Convert a base64 string (optionally a ``data:image/...`` URL) to a PIL Image.

    Returns:
        The decoded image, or None if the string cannot be decoded.
    """
    try:
        if base64_string.startswith('data:image'):
            # Strip the "data:image/...;base64," prefix.
            base64_string = base64_string.partition(',')[2]
        image = Image.open(BytesIO(base64.b64decode(base64_string)))
        # Image.open is lazy; force decoding so corrupt data fails inside this try.
        image.load()
        return image
    except Exception as e:
        print(f"Error converting base64 to image: {e}")
        return None


def image_to_base64(image):
    """Convert a PIL Image to a ``data:image/jpeg;base64,...`` string (None passes through)."""
    if image is None:
        return None
    img_base64 = base64.b64encode(_encode_jpeg(image)).decode('utf-8')
    return f"data:image/jpeg;base64,{img_base64}"


def pil_to_tensor(image):
    """Convert a PIL Image to a float CHW tensor; non-PIL inputs pass through.

    Values are scaled to [0, 1] when the array's maximum exceeds 1.
    """
    if not isinstance(image, Image.Image):
        return image
    image_array = np.array(image)
    if image_array.ndim == 2:
        # Single-channel modes give an HxW array; add a channel axis for permute.
        image_array = image_array[:, :, np.newaxis]
    if image_array.max() > 1:
        image_array = image_array / 255.0
    return torch.from_numpy(image_array).permute(2, 0, 1).float()


def tensor_to_pil(tensor):
    """Convert an image tensor to a PIL Image; non-tensor inputs pass through.

    Assumes channel-first layout (CHW, or BCHW where the first item is used).
    Float data with max <= 1 is treated as [0, 1] and scaled to [0, 255];
    non-uint8 data is clipped to [0, 255] so out-of-range values do not wrap.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    tensor = tensor.detach().cpu()
    if tensor.dim() == 4:
        tensor = tensor[0]
    if tensor.dim() == 3:
        tensor = tensor.permute(1, 2, 0)
    if tensor.is_floating_point():
        # numpy has no bfloat16; float32 covers every float dtype losslessly enough here.
        tensor = tensor.float()
    array = tensor.numpy()
    if array.dtype != np.uint8:
        scale = 255.0 if array.max() <= 1 else 1.0
        array = np.clip(array * scale, 0, 255).astype(np.uint8)
    if array.ndim == 3 and array.shape[2] == 1:
        # Pillow cannot build an image from an HxWx1 array.
        array = array[:, :, 0]
    return Image.fromarray(np.ascontiguousarray(array))


def center_crop(img):
    """Center crop image to square."""
    width, height = img.size
    side = min(width, height)
    left = (width - side) / 2
    top = (height - side) / 2
    right = (width + side) / 2
    bottom = (height + side) / 2
    return img.crop((left, top, right, bottom))


def _crop_and_resize(img, size=1024):
    """Center-crop *img* to a square and resize it to ``size`` x ``size`` (LANCZOS)."""
    return center_crop(img).resize((size, size), Image.Resampling.LANCZOS)


def resize(name):
    """Build an upload handler that aligns or squares an image.

    Args:
        name: Label ("Face", "Shape" or "Color") checked against the UI's
            alignment checkbox values.

    Returns:
        ``resize_inner(img, align)``: if *name* is in *align*, the image is
        face-aligned (cached by content hash), falling back to center-crop +
        1024px resize when alignment fails; otherwise non-1024x1024 images are
        center-cropped and resized.
    """
    def resize_inner(img, align):
        if img is None:
            return img
        if name in align:
            img_hash = hashlib.md5(get_bytes(img)).hexdigest()
            cached = align_cache.get(img_hash)
            if cached is not None:
                return cached
            try:
                if align_face is None:
                    raise RuntimeError("face alignment module is unavailable")
                aligned = align_face(img, return_tensors=False)[0]
            except Exception as e:
                print(f"Face alignment failed: {e}")
                # Failures are not cached so a later retry can still align.
                return _crop_and_resize(img)
            align_cache[img_hash] = aligned
            return aligned
        if img.size != (1024, 1024):
            img = _crop_and_resize(img)
        return img

    return resize_inner


def _resolve_references(shape, color):
    """Use whichever of shape/color was given for the missing one."""
    if color is None:
        color = shape
    if shape is None:
        shape = color
    return shape, color


def _run_swap(face, shape, color):
    """Run HairFast on the three images and return the result as a PIL Image."""
    result = hair_fast_model.swap(
        face_img=face,
        shape_img=shape,
        color_img=color,
        benchmark=False,
        align=True,  # Use face alignment
        seed=3407,
    )
    # NOTE(review): upstream HairFast appears to return (result, *aligned_inputs)
    # when align=True; unwrap so both return forms work — confirm against hair_swap.py.
    if isinstance(result, (tuple, list)):
        result = result[0]
    return tensor_to_pil(result)


def swap_hair_direct(face, shape, color, blending, poisson_iters, poisson_erosion):
    """Gradio handler: transfer hair from shape/color references onto *face*.

    ``blending``, ``poisson_iters`` and ``poisson_erosion`` are accepted for UI
    compatibility but are not currently forwarded to HairFast.
    TODO: wire them into the model configuration.

    Returns:
        ``(output_update, error_update)``: exactly one of the two is made visible.
    """
    def _error(message):
        return gr.update(visible=False), gr.update(value=message, visible=True)

    # Validate before initializing: cheap, and avoids loading the model just to reject input.
    if face is None and shape is None and color is None:
        return _error("Need to upload a face and at least a shape or color ❗")
    if face is None:
        return _error("Need to upload a face ❗")
    if shape is None and color is None:
        return _error("Need to upload at least a shape or color ❗")

    if hair_fast_model is None and not initialize_hairfast():
        return _error(
            "❌ HairFast model not available. Please check if all model files are uploaded."
        )

    try:
        print("🔄 Starting hair transfer...")
        shape, color = _resolve_references(shape, color)
        result_image = _run_swap(face, shape, color)
        print("✅ Hair transfer completed successfully!")
        return gr.update(value=result_image, visible=True), gr.update(visible=False)
    except Exception as e:
        error_msg = f"❌ Hair transfer failed: {str(e)}"
        print(error_msg)
        return _error(error_msg)


def _decode_api_image(value):
    """Decode a base64 string into a PIL Image; non-string values pass through.

    Returns:
        ``(image, ok)`` where ok is False only for a non-empty string that
        failed to decode; an empty string is treated as "not provided".
    """
    if not isinstance(value, str):
        return value, True
    if not value:
        return None, True
    image = base64_to_image(value)
    return image, image is not None


def hair_transfer_api(source_image, shape_image=None, color_image=None,
                      blending="Article", poisson_iters=0, poisson_erosion=15):
    """API function for React integration.

    Images may be PIL Images or base64 strings (optionally data URLs).
    ``blending`` / ``poisson_*`` are currently not forwarded to HairFast.

    Returns:
        ``(result_base64, status_message)``; result is None on failure.
    """
    try:
        decoded = {}
        for label, value in (("source", source_image),
                             ("shape", shape_image),
                             ("color", color_image)):
            image, ok = _decode_api_image(value)
            if not ok:
                return None, f"❌ Could not decode {label} image"
            decoded[label] = image
        source_image = decoded["source"]
        shape_image = decoded["shape"]
        color_image = decoded["color"]

        # Validation
        if source_image is None:
            return None, "❌ Source image is required"
        # At least one hair reference is needed; the source is never used as one.
        if shape_image is None and color_image is None:
            return None, "❌ At least shape or color reference image is required"

        if hair_fast_model is None and not initialize_hairfast():
            return None, "❌ HairFast model not available"

        shape_image, color_image = _resolve_references(shape_image, color_image)
        result_image = _run_swap(source_image, shape_image, color_image)
        return image_to_base64(result_image), "✅ Hair transfer completed successfully!"
    except Exception as e:
        error_msg = f"❌ API Error: {str(e)}"
        print(error_msg)
        return None, error_msg


def get_demo():
    """Create the Gradio Blocks interface wired to :func:`swap_hair_direct`."""
    with gr.Blocks(
        title="HairFastGAN Direct API",
        theme=gr.themes.Soft()
    ) as demo:
        gr.HTML("""
        🎨 HairFastGAN - Direct Model Inference
        High-quality hair transfer without gRPC dependency
        """)

        with gr.Row():
            with gr.Column():
                gr.HTML("📤 Input Images")
                source = gr.Image(
                    label="Source Photo (Person's Face)",
                    type="pil"
                )
                with gr.Row():
                    shape = gr.Image(
                        label="Hair Shape Reference (Optional)",
                        type="pil"
                    )
                    color = gr.Image(
                        label="Hair Color Reference (Optional)",
                        type="pil"
                    )

                with gr.Accordion("🔧 Advanced Options", open=False):
                    blending = gr.Radio(
                        ["Article", "Alternative_v1", "Alternative_v2"],
                        value='Article',
                        label="Color Encoder Version"
                    )
                    poisson_iters = gr.Slider(
                        0, 2500, value=0, step=1,
                        label="Poisson Iterations",
                        info="Detail recovery strength"
                    )
                    poisson_erosion = gr.Slider(
                        1, 100, value=15, step=1,
                        label="Poisson Erosion",
                        info="Blending smoothness"
                    )
                    align = gr.CheckboxGroup(
                        ["Face", "Shape", "Color"],
                        value=["Face", "Shape", "Color"],
                        label="Face Alignment [Recommended]"
                    )

                btn = gr.Button("🎨 Transfer Hair Style", variant="primary", size="lg")

            with gr.Column():
                gr.HTML("📥 Result")
                output = gr.Image(label="Result Image", type="pil")
                error_message = gr.Textbox(
                    label="⚠️ Status",
                    visible=False,
                    elem_classes="error-message"
                )

        # Example gallery
        gr.HTML("💡 Examples")
        gr.Examples(
            examples=[
                ["input/0.png", "input/1.png", "input/2.png"],
                ["input/6.png", "input/7.png", None],
                ["input/10.jpg", None, "input/11.jpg"]
            ],
            inputs=[source, shape, color],
            outputs=output
        )

        # Event handlers: align/square each image as soon as it is uploaded.
        source.upload(fn=resize('Face'), inputs=[source, align], outputs=source)
        shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape)
        color.upload(fn=resize('Color'), inputs=[color, align], outputs=color)

        btn.click(
            fn=swap_hair_direct,
            inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
            outputs=[output, error_message],
            api_name="predict"  # For React integration
        )

        # Citation
        gr.Markdown(_CITATION_MD)

    return demo


if __name__ == '__main__':
    # align_cache is created at import time above; no re-initialization needed.
    demo = get_demo()

    # Launch with API enabled
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_api=True,
        share=False
    )