import gradio as gr
import torch
import numpy as np
import PIL
import base64
from io import BytesIO
from PIL import Image

# import for face detection
import retinaface

from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers import UniPCMultistepScheduler
from modelconfig import ModelConfig
from spiga.inference.framework import SPIGAFramework
import spiga.demo.analyze.track.retinasort.config as cfg

import matplotlib.pyplot as plt
from matplotlib.path import Path
import matplotlib.patches as patches
# Bounding boxes
config = cfg.cfg_retinasort
face_detector = retinaface.RetinaFaceDetector(
    model=config['retina']['model_name'],
    device='cuda' if torch.cuda.is_available() else 'cpu',
    extra_features=config['retina']['extra_features'],
    cfg_postreat=config['retina']['postreat'],
)

# Landmark extraction
spiga_extractor = SPIGAFramework(ModelConfig("300wpublic", False))
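
# Load the ControlNet trained on face landmarks and attach it to a
# Stable Diffusion 2.1 base pipeline (fp16 for GPU inference)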
uncanny_controlnet = ControlNetModel.from_pretrained(
    "multimodalart/uncannyfaces_25K", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base",
    controlnet=uncanny_controlnet,
    safety_checker=None,
    torch_dtype=torch.float16,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# Fixed generator seed for reproducible outputs
generator = torch.manual_seed(0)
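
# HTML for the custom <face-canvas> web component, which captures webcam
# frames and draws the landmark conditioning image directly in the browser.
# `load_js` fetches the component script once when the app loads, and
# `get_js_image` reads the canvas data on submit so Python receives it
# through the hidden `live_conditioning` JSON component.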
canvas_html = "<face-canvas id='canvas-root' style='display:flex;max-width: 500px;margin: 0 auto;'></face-canvas>"
load_js = """
async () => {
    const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/face-canvas.js"
    fetch(url)
        .then(res => res.text())
        .then(text => {
            const script = document.createElement('script');
            script.type = "module"
            script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
            document.head.appendChild(script);
        });
}
"""
get_js_image = """
async (image_in_img, prompt, image_file_live_opt, live_conditioning) => {
    const canvasEl = document.getElementById("canvas-root");
    const imageData = canvasEl ? canvasEl._data : null;
    return [image_in_img, prompt, image_file_live_opt, imageData]
}
"""
def get_bounding_box(image):
    pil_image = Image.fromarray(image)
    face_detector.set_input_shape(pil_image.size[1], pil_image.size[0])
    features = face_detector.inference(pil_image)
    if features is None or len(features['bbox']) == 0:
        raise Exception("No face detected")
    # Keep only the first face detected
    bbox = features['bbox'][0]
    x1, y1, x2, y2 = bbox[:4]
    bbox_wh = [x1, y1, x2 - x1, y2 - y1]
    return bbox_wh

def get_landmarks(image, bbox):
    features = spiga_extractor.inference(image, [bbox])
    return features['landmarks'][0]
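
# Build a matplotlib PathPatch that connects consecutive landmarks.
# Closed contours (eyes, lips) are filled with the edge color;
# open contours (jaw, eyebrows, nose) stay transparent.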
def get_patch(landmarks, color='lime', closed=False):
    contour = landmarks
    ops = [Path.MOVETO] + [Path.LINETO] * (len(contour) - 1)
    facecolor = (0, 0, 0, 0)  # Transparent fill color, if open
    if closed:
        contour.append(contour[0])
        ops.append(Path.CLOSEPOLY)
        facecolor = color
    path = Path(contour, ops)
    return patches.PathPatch(path, facecolor=facecolor, edgecolor=color, lw=4)
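
# Draw the conditioning image from the 68-point ("300wpublic") layout:
# 0-16 jaw, 17-21 / 22-26 eyebrows, 27-30 nose bridge, 31-35 nostrils,
# 36-41 / 42-47 eyes, 48-59 outer lips, 60-67 inner lips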
def conditioning_from_landmarks(landmarks, size=512):
    # Precisely control output image size
    dpi = 72
    fig, ax = plt.subplots(1, figsize=[size / dpi, size / dpi], tight_layout={'pad': 0})
    fig.set_dpi(dpi)

    black = np.zeros((size, size, 3))
    ax.imshow(black)

    face_patch = get_patch(landmarks[0:17])
    l_eyebrow = get_patch(landmarks[17:22], color='yellow')
    r_eyebrow = get_patch(landmarks[22:27], color='yellow')
    nose_v = get_patch(landmarks[27:31], color='orange')
    nose_h = get_patch(landmarks[31:36], color='orange')
    l_eye = get_patch(landmarks[36:42], color='magenta', closed=True)
    r_eye = get_patch(landmarks[42:48], color='magenta', closed=True)
    outer_lips = get_patch(landmarks[48:60], color='cyan', closed=True)
    inner_lips = get_patch(landmarks[60:68], color='blue', closed=True)

    ax.add_patch(face_patch)
    ax.add_patch(l_eyebrow)
    ax.add_patch(r_eyebrow)
    ax.add_patch(nose_v)
    ax.add_patch(nose_h)
    ax.add_patch(l_eye)
    ax.add_patch(r_eye)
    ax.add_patch(outer_lips)
    ax.add_patch(inner_lips)

    plt.axis('off')
    fig.canvas.draw()
    buffer, (width, height) = fig.canvas.print_to_buffer()
    assert width == height
    assert width == size

    buffer = np.frombuffer(buffer, np.uint8).reshape((height, width, 4))
    buffer = buffer[:, :, 0:3]
    plt.close(fig)
    return PIL.Image.fromarray(buffer)

def get_conditioning(image):
    # Steps: convert to BGR and then:
    # - Retrieve the bounding box using RetinaFace
    # - Obtain landmarks using SPIGA
    # - Create the conditioning image with custom matplotlib code
    # TODO: error if bbox is too small
    image.thumbnail((512, 512))
    image = np.array(image)
    image = image[:, :, ::-1]  # RGB -> BGR
    bbox = get_bounding_box(image)
    landmarks = get_landmarks(image, bbox)
    spiga_seg = conditioning_from_landmarks(landmarks)
    return spiga_seg
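
# Generate three images for the prompt. In 'file' mode the conditioning is
# computed from the uploaded photo; in 'webcam' mode the browser canvas
# already sends a rendered conditioning image as base64.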
def generate_images(image_in_img, prompt, image_file_live_opt='file', live_conditioning=None):
    if image_in_img is None and (live_conditioning is None or 'image' not in live_conditioning):
        raise gr.Error("Please provide an image")
    try:
        if image_file_live_opt == 'file':
            conditioning = get_conditioning(image_in_img)
        elif image_file_live_opt == 'webcam':
            base64_img = live_conditioning['image']
            image_data = base64.b64decode(base64_img.split(',')[1])
            conditioning = Image.open(BytesIO(image_data)).convert('RGB').resize((512, 512))
        output = pipe(
            prompt,
            conditioning,
            generator=generator,
            num_images_per_prompt=3,
            num_inference_steps=20,
        )
        return [conditioning] + output.images
    except Exception as e:
        raise gr.Error(str(e))
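
# Show either the file upload or the live canvas, depending on the radio choice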
def toggle(choice):
    if choice == "file":
        return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
    elif choice == "webcam":
        return gr.update(visible=False, value=None), gr.update(visible=True, value=canvas_html)
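
# Assemble the Gradio UI; the hidden JSON component carries the canvas
# data returned by `get_js_image` into `generate_images`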
with gr.Blocks() as blocks:
    gr.Markdown("""
    ## Generate Uncanny Faces with ControlNet Stable Diffusion
    [Check out our blog to see how this was done (and train your own controlnet)](https://huggingface.co/blog/train-your-controlnet)
    """)
    with gr.Row():
        live_conditioning = gr.JSON(value={}, visible=False)
        with gr.Column():
            image_file_live_opt = gr.Radio(["file", "webcam"], value="file",
                                           label="How would you like to upload your image?")
            image_in_img = gr.Image(source="upload", visible=True, type="pil")
            canvas = gr.HTML(None, elem_id="canvas_html", visible=False)
            image_file_live_opt.change(fn=toggle,
                                       inputs=[image_file_live_opt],
                                       outputs=[image_in_img, canvas],
                                       queue=False)
            prompt = gr.Textbox(
                label="Enter your prompt",
                max_lines=1,
                placeholder="best quality, extremely detailed",
            )
            run_button = gr.Button("Generate")
        with gr.Column():
            gallery = gr.Gallery().style(grid=[2], height="auto")
    run_button.click(fn=generate_images,
                     inputs=[image_in_img, prompt,
                             image_file_live_opt, live_conditioning],
                     outputs=[gallery],
                     _js=get_js_image)
    blocks.load(None, None, None, _js=load_js)
    gr.Examples(fn=generate_images,
                examples=[
                    ["./examples/pedro-512.jpg",
                     "Highly detailed photograph of young woman smiling, with palm trees in the background"],
                    ["./examples/image1.jpg",
                     "Highly detailed photograph of a scary clown"],
                    ["./examples/image0.jpg",
                     "Highly detailed photograph of Madonna"],
                ],
                inputs=[image_in_img, prompt],
                outputs=[gallery],
                cache_examples=True)
    gr.Markdown('''
    The model behind this Space was trained on synthetic 3D faces to learn how to keep a pose. However, it also
    learned that all faces are synthetic 3D faces ([learn more on our blog](https://huggingface.co/blog/train-your-controlnet)).
    It uses a custom visualization based on SPIGA face landmarks for conditioning.
    ''')

blocks.launch()