import gradio as gr
import zombie
from huggingface_hub import hf_hub_download
import onnxruntime as ort
import numpy as np
from PIL import Image
from faceparsing import get_face_mask

# --- 1. Load the pix2pixHD generator (exported to ONNX) ---
# The ONNX weights are downloaded from the Hugging Face Hub and loaded into a
# single ONNX Runtime session at import time; every request reuses it.
model_path = hf_hub_download(
    repo_id="jbrownkramer/makemeazombie",
    filename="smaller512x512_32bit.onnx",
)
# Execution providers are listed in priority order: prefer CUDA, but name the
# CPU provider explicitly so the app still runs without a usable GPU.
ort_session = ort.InferenceSession(
    model_path,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)


# --- 2. Define the prediction function ---
def predict(input_image, mode):
    """Apply the selected zombie transformation to an uploaded image.

    Args:
        input_image: The uploaded image. The input component uses
            ``type="pil"``, so this is a PIL image, or ``None`` when the user
            submits without uploading anything.
        mode: The dropdown selection, ``"Classic"`` or ``"In Place"``.

    Returns:
        The image to display in the output ``gr.Image`` component.

    Raises:
        gr.Error: If no image was uploaded, no face was found, or the mode is
            not recognised. Gradio shows the message to the user. Returning a
            plain string instead would be passed to the ``gr.Image`` output,
            which would try to treat it as an image file path.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")

    if mode == "Classic":
        # transition_onnx builds a side-by-side comparison; it returns None
        # when no face is found in the input.
        zombie_image = zombie.transition_onnx(input_image, ort_session)
        if zombie_image is None:
            raise gr.Error("No face found")
        return zombie_image

    if mode == "In Place":
        # TODO: the in-place transformation
        # (zombie.make_faces_zombie_from_array) is not wired up yet; for now
        # this mode only displays the face-parsing mask.
        # NOTE(review): assumes get_face_mask returns something gr.Image can
        # render (PIL image or numpy array) — confirm.
        return get_face_mask(input_image)

    raise gr.Error("Invalid mode selected")


# --- 3. Create the Gradio Interface ---
title = "Make Me A Zombie"
description = "Upload an image to see the pix2pixHD model in action."
article = """
Model based on the pix2pixHD repository. More details at makemeazombie.com.
"""

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=["Classic", "In Place"], value="Classic", label="Mode"),
    ],
    outputs=gr.Image(type="pil", label="Output Image"),
    title=title,
    description=description,
    article=article,
)

# debug=True blocks the main thread and surfaces errors in the console output.
demo.launch(debug=True)