Spaces:
Paused
Paused
| import os | |
| import subprocess | |
| import tempfile | |
| from pathlib import Path | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from torchvision.transforms import transforms | |
| from torchvision.utils import save_image | |
| from models.Net import get_segmentation | |
def equal_replacer(images: list[torch.Tensor]) -> list[torch.Tensor]:
    """Rescale uint8 tensors and make numerically-equal entries share one object.

    The list is modified in place and also returned.

    Args:
        images: Tensors to process. Any ``torch.uint8`` entry is replaced by
            itself divided by 255.

    Returns:
        The same list object. Whenever a later entry is ``torch.allclose`` to
        an earlier one, the later slot is rebound to the earlier tensor.
    """
    # Pass 1: bring uint8 entries onto the floating-point scale.
    for idx, img in enumerate(images):
        if img.dtype == torch.uint8:
            images[idx] = img / 255

    # Pass 2: compare every pair once; later duplicates alias the earlier one.
    for anchor_idx, anchor in enumerate(images):
        for other_idx in range(anchor_idx + 1, len(images)):
            if torch.allclose(anchor, images[other_idx]):
                images[other_idx] = anchor

    return images
class DilateErosion:
    """Binary dilation and erosion of masks with a 3x3 cross-shaped kernel.

    Both operations are implemented as repeated ``conv2d`` calls followed by
    thresholding, so they run on whatever device the tensors live on.
    """

    def __init__(self, dilate_erosion=5, device='cuda'):
        """Store the iteration count and build the convolution kernel.

        Args:
            dilate_erosion: Number of dilation/erosion iterations applied by
                :meth:`mask`.
            device: Device on which the kernel is initially allocated.
        """
        self.dilate_erosion = dilate_erosion
        # Cross-shaped (4-neighbourhood plus centre) kernel, shaped
        # (out_channels=1, in_channels=1, 3, 3) as conv2d expects.
        self.weight = torch.tensor(
            [
                [0, 1, 0],
                [1, 1, 1],
                [0, 1, 0],
            ],
            dtype=torch.float32,
            device=device,
        )[None, None, ...]

    def hair_from_mask(self, mask):
        """Extract the label-13 region, resize it to 256x256, dilate and erode it.

        Args:
            mask: Label map; entries equal to 13 are treated as foreground
                (presumably the hair class of the segmentation model - verify
                against the segmenter). Must be laid out so that
                ``F.interpolate`` with a 2-D ``size`` accepts it.

        Returns:
            ``(dilated, eroded)`` float masks as produced by :meth:`mask`.
        """
        # Build the binary mask as float up front: nearest-neighbour
        # interpolation is not implemented for integer label dtypes such as
        # int64, and `mask()` converts to float anyway, so results are
        # unchanged for inputs that already worked.
        binary = (mask == 13).float()
        binary = F.interpolate(binary, size=(256, 256), mode='nearest')
        return self.mask(binary)

    def mask(self, mask):
        """Dilate and erode ``mask`` for ``self.dilate_erosion`` iterations.

        Args:
            mask: Batched single-channel mask (the kernel has one input
                channel). The input tensor is not modified.

        Returns:
            ``(dilated, eroded)``: two float tensors with the same shape as
            ``mask`` containing only 0.0 and 1.0 once at least one iteration
            has run.
        """
        n = len(mask)
        # Stack two copies along the batch axis so that a single convolution
        # per iteration serves both the dilation and the erosion half.
        stacked = torch.cat([mask, mask], dim=0).float()
        # Follow the input's device; this is a no-op when they already match.
        weight = self.weight.to(stacked.device)
        full_support = weight.sum().item()

        for _ in range(self.dilate_erosion):
            response = F.conv2d(stacked, weight, bias=None, stride=1,
                                padding='same', dilation=1, groups=1)
            # Dilation: any neighbour set. Erosion: every neighbour set.
            dilated = (response[:n] > 0).float()
            eroded = (response[n:] == full_support).float()
            stacked = torch.cat([dilated, eroded], dim=0)

        return stacked[:n], stacked[n:]
def _as_image_tensor(image):
    """Return ``image`` as a tensor, loading from a path or converting from PIL."""
    if isinstance(image, torch.Tensor):
        return image
    to_tensor = transforms.ToTensor()
    if isinstance(image, (str, os.PathLike)):
        # Close the file handle as soon as the pixels have been converted.
        with Image.open(image) as handle:
            return to_tensor(handle)
    return to_tensor(image)


def poisson_image_blending(final_image, face_image, dilate_erosion=30, maxn=115):
    """Blend ``face_image`` into ``final_image`` with the external ``fpie`` tool.

    A blending mask is built from the regions that carry segmentation label 13
    in neither image (presumably "not hair" - verify against the segmenter).
    That mask is upsampled to 1024x1024 and shrunk by dilating its complement.
    Both images and the mask are written to a temporary directory and handed
    to the ``fpie`` command-line program.

    Args:
        final_image: Target image as a tensor, a file path, or anything
            ``torchvision.transforms.ToTensor`` accepts.
        face_image: Source image, in the same accepted forms.
        dilate_erosion: Number of dilation iterations applied to the
            complement of the mask.
        maxn: Value passed to ``fpie -n``.

    Returns:
        ``(blended_image, mask_image)`` as fully loaded PIL images.

    Raises:
        subprocess.CalledProcessError: If ``fpie`` exits with a non-zero code.
        FileNotFoundError: If the ``fpie`` executable cannot be found.
    """
    morphology = DilateErosion(dilate_erosion=dilate_erosion)

    # Accept the same input forms for both images.
    final_image = _as_image_tensor(final_image)
    face_image = _as_image_tensor(face_image)

    final_mask = get_segmentation(final_image.cuda().unsqueeze(0), resize=False)
    face_mask = get_segmentation(face_image.cuda().unsqueeze(0), resize=False)

    hair_target = torch.where(final_mask == 13, torch.ones_like(final_mask),
                              torch.zeros_like(final_mask))
    hair_face = torch.where(face_mask == 13, torch.ones_like(face_mask),
                            torch.zeros_like(face_mask))

    # 1 where neither image has label 13, upsampled to the working resolution.
    keep_mask = F.interpolate(((1 - hair_target) * (1 - hair_face)).float(),
                              size=(1024, 1024), mode='bicubic')
    # Dilating the complement shrinks the kept region away from its border.
    dilation, _ = morphology.mask(1 - keep_mask)
    mask_save = 1 - dilation[0]

    with tempfile.TemporaryDirectory() as temp_dir:
        final_image_path = os.path.join(temp_dir, 'final_image.png')
        face_image_path = os.path.join(temp_dir, 'face_image.png')
        mask_path = os.path.join(temp_dir, 'mask_save.png')
        out_image_path = os.path.join(temp_dir, 'out_image_path.png')

        save_image(final_image, final_image_path)
        save_image(face_image, face_image_path)
        save_image(mask_save, mask_path)

        subprocess.run(
            ["fpie", "-s", face_image_path, "-m", mask_path, "-t", final_image_path, "-o", out_image_path, "-n",
             str(maxn), "-b", "taichi-gpu", "-g", "max"],
            check=True
        )

        # Image.open is lazy: read the pixels (which also releases the file
        # handles) before the temporary directory is deleted.
        blended_image = Image.open(out_image_path)
        blended_image.load()
        mask_image = Image.open(mask_path)
        mask_image.load()

    return blended_image, mask_image
def list_image_files(directory):
    """List the image file names found directly inside ``directory``.

    Args:
        directory: Path of the directory to scan (not recursive).

    Returns:
        Sorted list of entry names (not full paths) that are regular files
        with a ``.jpg``, ``.jpeg`` or ``.png`` suffix, compared
        case-insensitively.
    """
    allowed_suffixes = ['.jpg', '.jpeg', '.png']
    return [
        name
        for name in sorted(os.listdir(directory))
        if os.path.isfile(os.path.join(directory, name))
        and Path(os.path.join(directory, name)).suffix.lower() in allowed_suffixes
    ]