import hashlib
import os
from io import BytesIO
import base64

import gradio as gr
from PIL import Image
from cachetools import LRUCache
import torch
import numpy as np

# Direct HairFast imports (no gRPC needed!)
try:
    from hair_swap import HairFast, get_parser
    HAIRFAST_AVAILABLE = True
    print("✅ HairFast successfully imported!")
except ImportError as e:
    print(f"❌ HairFast import failed: {e}")
    HAIRFAST_AVAILABLE = False

from utils.shape_predictor import align_face

# Global variables
hair_fast_model = None
align_cache = LRUCache(maxsize=10)

def initialize_hairfast():
    """Initialize HairFast model"""
    global hair_fast_model

    if not HAIRFAST_AVAILABLE:
        print("❌ HairFast not available")
        return False

    try:
        print("🔄 Initializing HairFast model...")

        # Get default arguments
        parser = get_parser()
        args = parser.parse_args([])  # Use default arguments

        # Override some settings for HF Spaces
        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        args.batch_size = 1  # Keep small for HF Spaces

        # Initialize HairFast
        hair_fast_model = HairFast(args)
        print(f"✅ HairFast initialized successfully on {args.device}!")
        return True
    except Exception as e:
        print(f"❌ HairFast initialization failed: {e}")
        hair_fast_model = None
        return False

def get_bytes(img):
    """Convert PIL Image to bytes"""
    if img is None:
        return img

    buffered = BytesIO()
    img.convert("RGB").save(buffered, format="JPEG")  # drop alpha so the JPEG encoder cannot fail
    return buffered.getvalue()


def bytes_to_image(image: bytes) -> Image.Image:
    """Convert bytes to PIL Image"""
    image = Image.open(BytesIO(image))
    return image

def base64_to_image(base64_string):
    """Convert base64 string to PIL Image"""
    try:
        if base64_string.startswith('data:image'):
            base64_string = base64_string.split(',')[1]
        image_bytes = base64.b64decode(base64_string)
        return Image.open(BytesIO(image_bytes))
    except Exception as e:
        print(f"Error converting base64 to image: {e}")
        return None


def image_to_base64(image):
    """Convert PIL Image to a base64 data URL"""
    if image is None:
        return None

    buffered = BytesIO()
    image.convert("RGB").save(buffered, format="JPEG")  # drop alpha so the JPEG encoder cannot fail
    img_bytes = buffered.getvalue()
    img_base64 = base64.b64encode(img_bytes).decode('utf-8')
    return f"data:image/jpeg;base64,{img_base64}"

def pil_to_tensor(image):
    """Convert PIL Image to the CHW float tensor format expected by HairFast"""
    if isinstance(image, Image.Image):
        # Force RGB so the array is always HxWx3 before permuting to CxHxW
        image_array = np.array(image.convert("RGB"))
        if image_array.max() > 1:
            image_array = image_array / 255.0
        tensor = torch.from_numpy(image_array).permute(2, 0, 1).float()
        return tensor
    return image


def tensor_to_pil(tensor):
    """Convert tensor to PIL Image"""
    if isinstance(tensor, torch.Tensor):
        if tensor.dim() == 4:
            tensor = tensor.squeeze(0)
        if tensor.dim() == 3:
            tensor = tensor.permute(1, 2, 0)
        tensor = tensor.detach().cpu().numpy()
        if tensor.max() <= 1:
            tensor = tensor * 255
        # Image.fromarray needs uint8 data, so clamp and cast in every branch
        tensor = np.clip(tensor, 0, 255).astype(np.uint8)
        return Image.fromarray(tensor)
    return tensor

def center_crop(img):
    """Center crop image to square"""
    width, height = img.size
    side = min(width, height)

    left = (width - side) / 2
    top = (height - side) / 2
    right = (width + side) / 2
    bottom = (height + side) / 2

    img = img.crop((left, top, right, bottom))
    return img

def resize(name):
    """Image resize function with face alignment"""
    def resize_inner(img, align):
        global align_cache

        if name in align:
            img_hash = hashlib.md5(get_bytes(img)).hexdigest()
            if img_hash not in align_cache:
                try:
                    img = align_face(img, return_tensors=False)[0]
                    align_cache[img_hash] = img
                except Exception as e:
                    print(f"Face alignment failed: {e}")
                    img = center_crop(img)
                    img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
            else:
                img = align_cache[img_hash]
        elif img.size != (1024, 1024):
            img = center_crop(img)
            img = img.resize((1024, 1024), Image.Resampling.LANCZOS)

        return img

    return resize_inner

def swap_hair_direct(face, shape, color, blending, poisson_iters, poisson_erosion):
    """Direct hair swapping using HairFast (no gRPC).

    Note: blending, poisson_iters and poisson_erosion are accepted from the UI
    but are not forwarded to the model in this version.
    """
    global hair_fast_model

    # Initialize the model if needed
    if hair_fast_model is None:
        if not initialize_hairfast():
            return gr.update(visible=False), gr.update(
                value="❌ HairFast model not available. Please check if all model files are uploaded.",
                visible=True
            )

    # Validation
    if not face and not shape and not color:
        return gr.update(visible=False), gr.update(
            value="Need to upload a face and at least a shape or color ❗",
            visible=True
        )
    elif not face:
        return gr.update(visible=False), gr.update(
            value="Need to upload a face ❗",
            visible=True
        )
    elif not shape and not color:
        return gr.update(visible=False), gr.update(
            value="Need to upload at least a shape or color ❗",
            visible=True
        )

    try:
        print("🔄 Starting hair transfer...")

        # Fall back to the other reference when only shape or only color is given
        if color is None:
            color = shape
        if shape is None:
            shape = color

        # Direct HairFast inference
        result_tensor = hair_fast_model.swap(
            face_img=face,
            shape_img=shape,
            color_img=color,
            benchmark=False,
            align=True,  # Use face alignment
            seed=3407
        )

        # Convert the result tensor to a PIL Image
        result_image = tensor_to_pil(result_tensor)

        print("✅ Hair transfer completed successfully!")
        return gr.update(value=result_image, visible=True), gr.update(visible=False)

    except Exception as e:
        error_msg = f"❌ Hair transfer failed: {str(e)}"
        print(error_msg)
        return gr.update(visible=False), gr.update(value=error_msg, visible=True)

def hair_transfer_api(source_image, shape_image=None, color_image=None,
                      blending="Article", poisson_iters=0, poisson_erosion=15):
    """API function for React integration"""
    global hair_fast_model

    try:
        # Handle base64 inputs
        if isinstance(source_image, str):
            source_image = base64_to_image(source_image)
        if isinstance(shape_image, str):
            shape_image = base64_to_image(shape_image)
        if isinstance(color_image, str):
            color_image = base64_to_image(color_image)

        # Initialize the model if needed
        if hair_fast_model is None:
            if not initialize_hairfast():
                return None, "❌ HairFast model not available"

        # Validation
        if source_image is None:
            return None, "❌ Source image is required"
        if shape_image is None and color_image is None:
            return None, "❌ At least shape or color reference image is required"

        # Fall back to the other reference when only shape or only color is given
        if color_image is None:
            color_image = shape_image
        if shape_image is None:
            shape_image = color_image

        # Direct HairFast inference
        result_tensor = hair_fast_model.swap(
            face_img=source_image,
            shape_img=shape_image,
            color_img=color_image,
            benchmark=False,
            align=True,
            seed=3407
        )

        # Convert to PIL and then to a base64 data URL
        result_image = tensor_to_pil(result_tensor)
        result_base64 = image_to_base64(result_image)

        return result_base64, "✅ Hair transfer completed successfully!"

    except Exception as e:
        error_msg = f"❌ API Error: {str(e)}"
        print(error_msg)
        return None, error_msg

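# Usage sketch for hair_transfer_api when called in-process. The file names
# below are placeholder assumptions, not files shipped with this repo:
#
#   src_b64 = image_to_base64(Image.open("face.jpg"))
#   ref_b64 = image_to_base64(Image.open("hairstyle.jpg"))
#   result_b64, status = hair_transfer_api(src_b64, shape_image=ref_b64)
#   if result_b64 is not None:
#       base64_to_image(result_b64).save("result.jpg")
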
def get_demo():
    """Create Gradio interface"""
    with gr.Blocks(
        title="HairFastGAN Direct API",
        theme=gr.themes.Soft()
    ) as demo:
        gr.HTML("""
        <div style="text-align: center; padding: 20px;">
            <h1>🎨 HairFastGAN - Direct Model Inference</h1>
            <p>High-quality hair transfer without gRPC dependency</p>
        </div>
        """)

        with gr.Row():
            with gr.Column():
                gr.HTML("<h3>📤 Input Images</h3>")
                source = gr.Image(
                    label="Source Photo (Person's Face)",
                    type="pil"
                )
                with gr.Row():
                    shape = gr.Image(
                        label="Hair Shape Reference (Optional)",
                        type="pil"
                    )
                    color = gr.Image(
                        label="Hair Color Reference (Optional)",
                        type="pil"
                    )

                with gr.Accordion("🔧 Advanced Options", open=False):
                    blending = gr.Radio(
                        ["Article", "Alternative_v1", "Alternative_v2"],
                        value='Article',
                        label="Color Encoder Version"
                    )
                    poisson_iters = gr.Slider(
                        0, 2500, value=0, step=1,
                        label="Poisson Iterations",
                        info="Detail recovery strength"
                    )
                    poisson_erosion = gr.Slider(
                        1, 100, value=15, step=1,
                        label="Poisson Erosion",
                        info="Blending smoothness"
                    )
                    align = gr.CheckboxGroup(
                        ["Face", "Shape", "Color"],
                        value=["Face", "Shape", "Color"],
                        label="Face Alignment [Recommended]"
                    )

                btn = gr.Button("🎨 Transfer Hair Style", variant="primary", size="lg")

            with gr.Column():
                gr.HTML("<h3>📥 Result</h3>")
                output = gr.Image(label="Result Image", type="pil")
                error_message = gr.Textbox(
                    label="⚠️ Status",
                    visible=False,
                    elem_classes="error-message"
                )

        # Example gallery
        gr.HTML("<h3>💡 Examples</h3>")
        gr.Examples(
            examples=[
                ["input/0.png", "input/1.png", "input/2.png"],
                ["input/6.png", "input/7.png", None],
                ["input/10.jpg", None, "input/11.jpg"]
            ],
            inputs=[source, shape, color],
            outputs=output
        )

        # Event handlers
        source.upload(fn=resize('Face'), inputs=[source, align], outputs=source)
        shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape)
        color.upload(fn=resize('Color'), inputs=[color, align], outputs=color)

        btn.click(
            fn=swap_hair_direct,
            inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
            outputs=[output, error_message],
            api_name="predict"  # For React integration
        )

        # Citation
        gr.Markdown('''
        ### 📖 Citation
        ```bibtex
        @article{nikolaev2024hairfastgan,
          title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},
          author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},
          journal={arXiv preprint arXiv:2404.01094},
          year={2024}
        }
        ```
        ''')

    return demo

if __name__ == '__main__':
    # Initialize cache
    align_cache = LRUCache(maxsize=10)

    # Create demo
    demo = get_demo()

    # Launch with API enabled
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_api=True,
        share=False
    )