Spaces:
Running
on
Zero
Running
on
Zero
| # Optional face enhance nodes | |
| # region imports | |
| import sys | |
| from pathlib import Path | |
| import comfy.model_management as model_management | |
| import cv2 | |
| import insightface | |
| import numpy as np | |
| import onnxruntime | |
| import torch | |
| from insightface.model_zoo.inswapper import INSwapper | |
| from PIL import Image | |
| from ..errors import ModelNotFound | |
| from ..log import NullWriter, mklog | |
| from ..utils import download_antelopev2, get_model_path, pil2tensor, tensor2pil | |
| # endregion | |
# Module-level logger, created by the project's mklog helper from this module's name.
log = mklog(__name__)
class MTB_LoadFaceAnalysisModel:
    """Loads a face analysis model.

    Node class exposing a single choice input (the insightface model pack
    name) and returning the constructed ``insightface.app.FaceAnalysis``
    object as a ``FACE_ANALYSIS_MODEL``.
    """

    # Not read anywhere in this class; kept so the class attribute stays
    # available to any external code that looks it up.
    models = []

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification: one required model-pack choice.

        Declared as a classmethod so it can be called on the class itself
        (``MTB_LoadFaceAnalysisModel.INPUT_TYPES()``) without an instance.
        """
        return {
            "required": {
                "faceswap_model": (
                    ["antelopev2", "buffalo_l", "buffalo_m", "buffalo_sc"],
                    {"default": "buffalo_l"},
                ),
            },
        }

    RETURN_TYPES = ("FACE_ANALYSIS_MODEL",)
    FUNCTION = "load_model"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    def load_model(self, faceswap_model: str):
        """Build the face analyser for the selected model pack.

        Args:
            faceswap_model: Name of the model pack, one of the choices
                listed by ``INPUT_TYPES``.

        Returns:
            A one-element tuple holding the ``FaceAnalysis`` instance.
        """
        # antelopev2 goes through the project's own download helper first.
        if faceswap_model == "antelopev2":
            download_antelopev2()

        face_analyser = insightface.app.FaceAnalysis(
            name=faceswap_model,
            root=get_model_path("insightface").as_posix(),
        )
        return (face_analyser,)
class MTB_LoadFaceSwapModel:
    """Loads a faceswap model.

    Node class listing the model files found in the "insightface" model
    folder and returning the selected one wrapped in an ``INSwapper``.
    """

    @staticmethod
    def get_models() -> list[Path]:
        """Return the ``.onnx`` / ``.pth`` files in the insightface model folder.

        Declared static because it takes no ``self``/``cls``; this makes it
        callable from both the class and its instances.

        Returns:
            The matching paths, or an empty list when the folder is missing.
        """
        models_path = get_model_path("insightface")
        if not models_path.exists():
            return []
        return [
            x for x in models_path.iterdir() if x.suffix in (".onnx", ".pth")
        ]

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification: one required model-file choice.

        The choices are the file names reported by ``get_models`` at the
        time of the call.
        """
        return {
            "required": {
                "faceswap_model": (
                    [x.name for x in cls.get_models()],
                    {"default": "None"},
                ),
            },
        }

    RETURN_TYPES = ("FACESWAP_MODEL",)
    FUNCTION = "load_model"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    def load_model(self, faceswap_model: str):
        """Load the selected model file into an ``INSwapper``.

        Args:
            faceswap_model: File name of the model inside the insightface
                model folder.

        Returns:
            A one-element tuple holding the ``INSwapper`` instance.

        Raises:
            ModelNotFound: If the resolved path is empty or does not exist.
        """
        model_path = get_model_path("insightface", faceswap_model)
        if not model_path or not model_path.exists():
            raise ModelNotFound(f"{faceswap_model} ({model_path})")

        log.info(f"Loading model {model_path}")
        # The ONNX session is offered every execution provider that
        # onnxruntime reports as available.
        session = onnxruntime.InferenceSession(
            path_or_bytes=model_path,
            providers=onnxruntime.get_available_providers(),
        )
        return (INSwapper(model_path, session),)
| # region roop node | |
class MTB_FaceSwap:
    """Face swap using deepinsight/insightface models"""

    model = None
    model_path = None

    def __init__(self) -> None:
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification for the swap node.

        Declared as a classmethod so it can be called on the class itself
        without an instance.
        """
        return {
            "required": {
                "image": ("IMAGE",),
                "reference": ("IMAGE",),
                "faces_index": ("STRING", {"default": "0"}),
                "faceanalysis_model": (
                    "FACE_ANALYSIS_MODEL",
                    {"default": "None"},
                ),
                "faceswap_model": ("FACESWAP_MODEL", {"default": "None"}),
            },
            "optional": {
                "preserve_alpha": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "swap"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    @staticmethod
    def _parse_faces_index(faces_index: str) -> set[int]:
        """Parse a comma-separated index string such as ``"0, 2"`` into ``{0, 2}``.

        Whitespace around each entry is ignored. Entries that are not plain
        decimal integers (empty, negative, fractional, non-decimal digits)
        are skipped instead of raising.
        """
        tokens = (token.strip() for token in faces_index.split(","))
        # isdecimal() accepts exactly the characters int() can convert,
        # unlike isnumeric(), which also admits e.g. superscripts/fractions.
        return {int(token) for token in tokens if token.isdecimal()}

    def swap(
        self,
        image: torch.Tensor,
        reference: torch.Tensor,
        faces_index: str,
        faceanalysis_model,
        faceswap_model,
        preserve_alpha=False,
    ):
        """Swap the reference face onto the selected faces of each image.

        Args:
            image: Batch of images to process; the first dimension is
                treated as the batch size.
            reference: Image providing the source face; its first dimension
                must be 1.
            faces_index: Comma-separated indices of the target faces.
            faceanalysis_model: Face analyser forwarded to ``swap_face``.
            faceswap_model: Swapper model forwarded to ``swap_face``.
            preserve_alpha: When true and a frame is RGBA, its alpha channel
                is put back on the swapped result.

        Returns:
            A one-element tuple holding the processed image tensor.

        Raises:
            ValueError: If ``reference`` does not have batch size 1.
        """
        batch_count = image.size(0)
        log.info(f"Running insightface swap (batch size: {batch_count})")

        if reference.size(0) != 1:
            raise ValueError("Reference image must have batch size 1")

        # Both values are identical for every frame, so compute them once.
        ref = tensor2pil(reference)[0]
        face_ids = self._parse_faces_index(faces_index)

        def do_swap(img):
            model_management.throw_exception_if_processing_interrupted()
            img = tensor2pil(img)[0]

            alpha_channel = None
            if preserve_alpha and img.mode == "RGBA":
                alpha_channel = img.getchannel("A")
                img = img.convert("RGB")

            # Silence stdout during the swap, then restore whatever stream
            # was installed before (not unconditionally sys.__stdout__),
            # even when the swap raises.
            previous_stdout = sys.stdout
            sys.stdout = NullWriter()
            try:
                swapped = swap_face(
                    faceanalysis_model, ref, img, faceswap_model, face_ids
                )
            finally:
                sys.stdout = previous_stdout

            if alpha_channel is not None:
                swapped.putalpha(alpha_channel)
            return pil2tensor(swapped)

        if batch_count == 1:
            image = do_swap(image)
        else:
            image_batch = [do_swap(image[i]) for i in range(batch_count)]
            image = torch.cat(image_batch, dim=0)

        return (image,)
| # endregion | |
| # region face swap utils | |
def get_face_single(
    face_analyser, img_data: np.ndarray, face_index=0, det_size=(640, 640)
):
    """Return one detected face from ``img_data``, or ``None``.

    The analyser is prepared with ``det_size`` and run on the image. While
    nothing is detected and both detection dimensions are still above 320,
    the detection size is halved and detection is retried.

    Args:
        face_analyser: Object exposing ``prepare(ctx_id=..., det_size=...)``
            and ``get(img)``, the latter returning a sequence of faces that
            each carry a ``bbox``.
        img_data: Image array handed to the analyser.
        face_index: Position of the wanted face once the detections are
            sorted by the first ``bbox`` value (presumably the left edge,
            i.e. left-to-right order; confirm against insightface).
        det_size: Initial ``(width, height)`` detection size.

    Returns:
        The selected face, or ``None`` when ``face_index`` is out of range.
    """
    while True:
        face_analyser.prepare(ctx_id=0, det_size=det_size)
        faces = face_analyser.get(img_data)

        can_shrink = det_size[0] > 320 and det_size[1] > 320
        if len(faces) != 0 or not can_shrink:
            break

        log.debug("No face detected, trying again with smaller image")
        det_size = (det_size[0] // 2, det_size[1] // 2)

    try:
        return sorted(faces, key=lambda x: x.bbox[0])[face_index]
    except IndexError:
        return None
def swap_face(
    face_analyser,
    source_img: Image.Image | list[Image.Image],
    target_img: Image.Image | list[Image.Image],
    face_swapper_model,
    faces_index: set[int] | None = None,
) -> Image.Image:
    """Paste the first face of ``source_img`` onto faces of ``target_img``.

    Args:
        face_analyser: Analyser forwarded to ``get_face_single``.
        source_img: Image providing the face to copy (its face 0 is used).
        target_img: Image whose faces are replaced.
        face_swapper_model: Model exposing
            ``get(image, target_face, source_face)``. When ``None`` nothing
            is swapped.
        faces_index: Indices of the target faces to replace; defaults to
            ``{0}``.

    Returns:
        The swapped image, or ``target_img`` unchanged when no swapper
        model is given or no source face is found.
    """
    if faces_index is None:
        faces_index = {0}
    log.debug(f"Swapping faces: {faces_index}")

    if face_swapper_model is None:
        log.error("No face swap model provided")
        return target_img

    # The analyser and swapper are fed BGR arrays, hence the conversion
    # from PIL's RGB ordering.
    cv_source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR)
    cv_target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)

    source_face = get_face_single(face_analyser, cv_source_img, face_index=0)
    if source_face is None:
        log.warning("No source face found")
        return target_img

    result = cv_target_img
    for face_num in faces_index:
        # Faces are always looked up on the untouched target image so the
        # indices stay stable while ``result`` accumulates the swaps.
        target_face = get_face_single(
            face_analyser, cv_target_img, face_index=face_num
        )
        if target_face is None:
            log.warning(f"No target face found for {face_num}")
            continue

        # Silence stdout during the model call, then restore the stream
        # that was active before (not unconditionally sys.__stdout__),
        # even if the model raises.
        previous_stdout = sys.stdout
        sys.stdout = NullWriter()
        try:
            result = face_swapper_model.get(result, target_face, source_face)
        finally:
            sys.stdout = previous_stdout

    return Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
| # endregion face swap utils | |
# Node classes defined in this module, gathered in one list
# (presumably read by the package's node registration code; confirm).
__nodes__ = [MTB_FaceSwap, MTB_LoadFaceSwapModel, MTB_LoadFaceAnalysisModel]