Spaces:
Running
on
T4
Running
on
T4
| import gradio as gr | |
| import zombie | |
| from huggingface_hub import hf_hub_download | |
| import onnxruntime as ort | |
| import numpy as np | |
| from PIL import Image | |
| from faceparsing import get_face_mask | |
| # import torch | |
| # from your_pix2pixhd_code import YourPix2PixHDModel, load_image, tensor2im # Adapt these imports | |
| # # --- 1. Load your pix2pixHD model --- | |
| # # You'll need to adapt this part to your specific model loading logic | |
| # # This is a simplified example | |
| # model = YourPix2PixHDModel() | |
| # model.load_state_dict(torch.load('models/your_pix2pixhd_model.pth')) | |
| # model.eval() | |
# Location of the ONNX generator weights on the Hugging Face Hub.
MODEL_REPO_ID = "jbrownkramer/makemeazombie"
MODEL_FILENAME = "smaller512x512_32bit.onnx"

# Resolve the model file from the Hub and open the ONNX Runtime session that
# predict() hands to the zombie module. Only the CUDA provider is requested.
model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
ort_session = ort.InferenceSession(
    model_path,
    providers=["CUDAExecutionProvider"],
)
| # --- 2. Define the prediction function --- | |
| # def predict(input_image): | |
| # return input_image[..., ::-1] | |
| # # # Pre-process the input image | |
| # # processed_image = load_image(input_image) | |
| # # # Run inference | |
| # # with torch.no_grad(): | |
| # # generated_image_tensor = model(processed_image) | |
| # # # Post-process the output tensor to an image | |
| # # output_image = tensor2im(generated_image_tensor) | |
| # # return output_image | |
def predict(input_image, mode):
    """Run the selected transformation on the uploaded image.

    Args:
        input_image: Image delivered by the Gradio image input (the interface
            configures it with ``type="pil"``), or ``None`` when the user
            submitted without uploading anything.
        mode: Name of the transformation, ``"Classic"`` or ``"In Place"``.

    Returns:
        The image to display in the output component: the result of
        ``zombie.transition_onnx`` in "Classic" mode, or the result of
        ``get_face_mask`` in "In Place" mode.

    Raises:
        gr.Error: If no image was supplied, if ``zombie.transition_onnx``
            returns ``None`` (reported as "No face found"), or if ``mode`` is
            not one of the supported values.
    """
    # The output component is an image, so failures are reported through
    # gr.Error rather than by returning a message string as if it were an
    # image value.
    if input_image is None:
        raise gr.Error("Please upload an image")

    if mode == "Classic":
        # Use the transition_onnx function for side-by-side comparison.
        zombie_image = zombie.transition_onnx(input_image, ort_session)
        if zombie_image is None:
            raise gr.Error("No face found")
        return zombie_image

    if mode == "In Place":
        # NOTE(review): the in-place transformation
        # (zombie.make_faces_zombie_from_array) is currently disabled; this
        # mode returns the face mask produced by get_face_mask instead.
        return get_face_mask(input_image)

    raise gr.Error("Invalid mode selected")
# --- 3. Create the Gradio Interface ---
# Static text displayed on the page.
title = "Make Me A Zombie"
description = "Upload an image to see the pix2pixHD model in action."
article = """<p style='text-align: center'>Model based on the <a href='https://github.com/NVIDIA/pix2pixHD' target='_blank'>pix2pixHD repository</a>.
More details at <a href='https://makemeazombie.com' target='_blank'>makemeazombie.com</a>.</p>"""

# Components are created in the order predict() receives its arguments:
# the image first, then the mode selector.
image_input = gr.Image(type="pil", label="Input Image")
mode_input = gr.Dropdown(
    choices=["Classic", "In Place"],
    value="Classic",
    label="Mode",
)
image_output = gr.Image(type="pil", label="Output Image")

demo = gr.Interface(
    fn=predict,
    inputs=[image_input, mode_input],
    outputs=image_output,
    title=title,
    description=description,
    article=article,
)

# Start the app with Gradio's debug flag turned on.
demo.launch(debug=True)