ouclxy committed
Commit 8ca3766 · verified · 1 Parent(s): bd4a200

Upload 6 files

Files changed (6)
  1. README.md +51 -14
  2. dataset_mv.py +2236 -0
  3. download.py +8 -0
  4. gradio_app.py +379 -0
  5. requirements.txt +67 -0
  6. test_stablehairv2.py +320 -0
README.md CHANGED
@@ -1,14 +1,51 @@
- ---
- title: Stablehairv2 Demo
- emoji: 📈
- colorFrom: pink
- colorTo: indigo
- sdk: gradio
- sdk_version: 5.44.1
- app_file: app.py
- pinned: false
- license: mit
- short_description: This is a simple demo showing our Huawei cup work
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # StableHair v2
+ **Stable-Hair v2: Real-World Hair Transfer via Multiple-View Diffusion Model**
+ Kuiyuan Sun*, [Yuxuan Zhang*](https://xiaojiu-z.github.io/YuxuanZhang.github.io/), [Jichao Zhang*](https://zhangqianhui.github.io/), [Jiaming Liu](https://scholar.google.com/citations?user=SmL7oMQAAAAJ&hl=en),
+ [Wei Wang](https://weiwangtrento.github.io/), [Nicu Sebe](http://disi.unitn.it/~sebe/), [Yao Zhao](https://scholar.google.com/citations?user=474TbQYAAAAJ&hl=en&oi=ao)<br>
+ *Equal Contribution <br>
+ Beijing Jiaotong University, Shanghai Jiaotong University, Ocean University of China, Tiamat AI, University of Trento <br>
+ [Arxiv](https://arxiv.org/abs/2507.07591), [Project]()<br>
+
+ Bald | Reference | Multiple View | Original Video
+ ![](./imgs/multiview1.gif)
+ Bald | Reference | Multiple View | Original Video
+ ![](./imgs/multiview2.gif)
+
+ ## Environments
+
+ ```
+ conda create -n stablehairv2 python=3.10
+ ```
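Activate the new environment before installing the requirements (a standard conda step that the README leaves implicit):

```
conda activate stablehairv2
```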
+ ```
+ pip install -r requirements.txt
+ ```
+
+ ## Results
+
+ <img src="./imgs/teaser.jpg" width="800">
+
+ ## Pretrained Model
+ | Name | Model |
+ |----------------------------|:---------:|
+ | motion_module-41400000.pth | [:link:](https://drive.google.com/file/d/1AZMhui9jNRF3Z0N72VDPOwDd0JafLQ3B/view?usp=drive_link) |
+ | pytorch_model_1.bin | [:link:](https://drive.google.com/file/d/1FwKPZI8lvdlZqu8R1aJ-QbE55kxHPHjU/view?usp=drive_link) |
+ | pytorch_model_2.bin | [:link:](https://drive.google.com/file/d/1h3dXlo8lhZN3ee5aN0shZmpLfn5itVou/view?usp=drive_link) |
+ | pytorch_model_3.bin | [:link:](https://drive.google.com/file/d/1jARfXaU6wiur85Vm1JxZ_xye0FfrUiqb/view?usp=drive_link) |
+ | pytorch_model.bin | [:link:](https://drive.google.com/file/d/1zXXf13pV5IOn2vrV6DGI9hliEFvuPrYf/view?usp=drive_link) |
+
+ ### Multiple View Hair Transfer
+
+ Please use `gdown` to download the pretrained models and place them in your model path (passed below as `--model_path`).
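The README does not spell out the download commands, so here is a minimal sketch; the file IDs are simply taken from the Google Drive links in the table above, and the output names/paths are assumptions you should adapt to whatever you pass as `--model_path`:

```
pip install gdown
gdown "https://drive.google.com/uc?id=1AZMhui9jNRF3Z0N72VDPOwDd0JafLQ3B" -O [Your_model_path]/motion_module-41400000.pth
gdown "https://drive.google.com/uc?id=1zXXf13pV5IOn2vrV6DGI9hliEFvuPrYf" -O [Your_model_path]/pytorch_model.bin
# repeat for pytorch_model_1.bin, pytorch_model_2.bin and pytorch_model_3.bin with their IDs
```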
+ ```
+ python test_stablehairv2.py --pretrained_model_name_or_path "stable-diffusion-v1-5/stable-diffusion-v1-5" \
+     --image_encoder "openai/clip-vit-large-patch14" --output_dir [Your_output_dir] \
+     --num_validation_images 1 --validation_ids ./test_imgs/bald.jpg \
+     --validation_hairs ./test_imgs/ref1.jpg --model_path [Your_model_path]
+ ```
+
+ ## Our V1 version
+
+ StableHair v2 is an improved version of [StableHair](https://github.com/Xiaojiu-z/Stable-Hair) (AAAI 2025).
dataset_mv.py ADDED
@@ -0,0 +1,2236 @@
+ from torch.utils import data
+ import os
+ import torch
+ import numpy as np
+ import cv2
+ import random
+ import albumentations as A
+
+ pixel_transform = A.Compose([
+     A.SmallestMaxSize(max_size=512),
+     A.CenterCrop(512, 512),
+     A.Affine(scale=(0.5, 1), translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)}, rotate=(-10, 10), p=0.8),
+ ], additional_targets={'image0': 'image', 'image1': 'image'})
+
+ hair_transform = A.Compose([
+     A.SmallestMaxSize(max_size=512),
+     A.CenterCrop(512, 512),
+     A.Affine(scale=(0.9, 1.2), rotate=(-10, 10), p=0.7)]
+ )
+
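As a side note, this is a minimal sketch of how an albumentations `Compose` with `additional_targets`, like `pixel_transform` above, is applied to aligned images; the arrays and names here are placeholders for illustration, not code from dataset_mv.py:

```python
import numpy as np

# Three aligned dummy frames; additional_targets maps 'image0'/'image1'
# onto the same pipeline, so all three receive identical random parameters.
frame = np.zeros((720, 720, 3), dtype=np.uint8)
bald = np.zeros((720, 720, 3), dtype=np.uint8)
ref = np.zeros((720, 720, 3), dtype=np.uint8)

out = pixel_transform(image=frame, image0=bald, image1=ref)
frame_t, bald_t, ref_t = out["image"], out["image0"], out["image1"]  # each 512x512x3
```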
21
+ # class myDataset(data.Dataset):
22
+ # """Custom data.Dataset compatible with data.DataLoader."""
23
+
24
+ # def __init__(self, train_data_dir):
25
+ # self.img_path = os.path.join(train_data_dir, "hair")
26
+ # # self.pose_path = os.path.join(train_data_dir, "pose.npy")
27
+ # # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
28
+ # # self.ref_path = os.path.join(train_data_dir, "ref_hair")
29
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
30
+ # self.non_hair_path = os.path.join(train_data_dir, "non-hair")
31
+ # self.ref_path = os.path.join(train_data_dir, "reference")
32
+ # self.lists = os.listdir(self.img_path)
33
+ # self.len = len(self.lists)-10
34
+ # self.pose = np.load(self.pose_path)
35
+ # #self.pose = np.random.randn(12, 4)
36
+
37
+ # def __getitem__(self, index):
38
+ # """Returns one data pair (source and target)."""
39
+ # # seq_len, fea_dim
40
+ # random_number1 = random.randrange(0, 21)
41
+ # random_number2 = random.randrange(0, 21)
42
+
43
+ # while random_number2 == random_number1:
44
+ # random_number2 = random.randrange(0, 21)
45
+ # name = self.lists[index]
46
+
47
+ # random_number1 = random_number1
48
+ # #* 10
49
+ # #random_number2 = random_number2 * 10
50
+
51
+ # random_number2 = random_number1
52
+
53
+
54
+ # non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
55
+ # ref_folder = os.path.join(self.ref_path, name)
56
+
57
+ # files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
58
+ # ref_path = os.path.join(ref_folder, files[0])
59
+
60
+ # img_non_hair = cv2.imread(non_hair_path)
61
+ # ref_hair = cv2.imread(ref_path)
62
+
63
+
64
+ # img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
65
+ # ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
66
+
67
+
68
+ # img_non_hair = cv2.resize(img_non_hair, (512, 512))
69
+ # ref_hair = cv2.resize(ref_hair, (512, 512))
70
+
71
+
72
+ # img_non_hair = (img_non_hair / 255.0) * 2 - 1
73
+ # ref_hair = (ref_hair / 255.0) * 2 - 1
74
+
75
+
76
+ # img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
77
+ # ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)
78
+
79
+ # pose1 = self.pose[random_number1]
80
+ # pose1 = torch.tensor(pose1)
81
+ # pose2 = self.pose[random_number2]
82
+ # pose2 = torch.tensor(pose2)
83
+ # hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
84
+ # hair_num = [0, 2, 6, 14, 18, 21]
85
+ # img_hair_stack = []
86
+ # for i in hair_num:
87
+ # img_hair = cv2.imread(hair_path)
88
+ # img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
89
+ # img_hair = cv2.resize(img_hair, (512, 512))
90
+ # img_hair = (img_hair / 255.0) * 2 - 1
91
+ # img_hair = torch.tensor(img_hair).permute(2, 0, 1)
92
+ # img_hair_stack.append(img_hair)
93
+ # img_hair = torch.stack(img_hair_stack)
94
+
95
+ # return {
96
+ # 'hair_pose': pose1,
97
+ # 'img_hair': img_hair,
98
+ # 'bald_pose': pose2,
99
+ # 'img_non_hair': img_non_hair,
100
+ # 'ref_hair': ref_hair
101
+ # }
102
+
103
+ # def __len__(self):
104
+ # return self.len
105
+
+ class myDataset(data.Dataset):
+     """Custom data.Dataset compatible with data.DataLoader."""
+     def __init__(self, train_data_dir):
+         self.img_path = os.path.join(train_data_dir, "hair")
+         self.pose_path = os.path.join(train_data_dir, "pose.npy")
+         self.non_hair_path = os.path.join(train_data_dir, "no_hair")
+         self.ref_path = os.path.join(train_data_dir, "ref_hair")
+
+         self.lists = os.listdir(self.img_path)
+         self.len = len(self.lists)
+         self.pose = np.load(self.pose_path)
+
+     def __getitem__(self, index):
+         """Returns one data pair (source and target)."""
+         # seq_len, fea_dim
+         random_number1 = random.randrange(0, 120)
+         random_number2 = random.randrange(0, 120)
+         while random_number2 == random_number1:
+             random_number2 = random.randrange(0, 120)
+         name = self.lists[index]
+
+         hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
+         non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
+         ref_folder = os.path.join(self.ref_path, name)
+         files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
+         ref_path = os.path.join(ref_folder, files[0])
+         img_hair = cv2.imread(hair_path)
+         img_non_hair = cv2.imread(non_hair_path)
+         ref_hair = cv2.imread(ref_path)
+
+         img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
+         img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
+         ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
+
+         img_hair = cv2.resize(img_hair, (512, 512))
+         img_non_hair = cv2.resize(img_non_hair, (512, 512))
+         ref_hair = cv2.resize(ref_hair, (512, 512))
+         img_hair = (img_hair / 255.0) * 2 - 1
+         img_non_hair = img_non_hair / 255.0
+         ref_hair = (ref_hair / 255.0) * 2 - 1
+
+         img_hair = torch.tensor(img_hair).permute(2, 0, 1)
+         img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
+         ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)
+
+         pose1 = self.pose[random_number1]
+         pose1 = torch.tensor(pose1)
+         pose2 = self.pose[random_number2]
+         pose2 = torch.tensor(pose2)
+
+         return {
+             'hair_pose': pose1,
+             'img_hair': img_hair,
+             'bald_pose': pose2,
+             'img_non_hair': img_non_hair,
+             'ref_hair': ref_hair
+         }
+
+     def __len__(self):
+         return self.len
+
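A minimal loading sketch for the `myDataset` class above; it is illustrative only, and the root directory is hypothetical (it must contain the `hair/`, `no_hair/`, `ref_hair/` folders and `pose.npy` expected by `__init__`):

```python
from torch.utils.data import DataLoader

dataset = myDataset("path/to/train_data")            # hypothetical data root
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)

batch = next(iter(loader))
print(batch["img_hair"].shape)    # torch.Size([2, 3, 512, 512]), values in [-1, 1]
print(batch["hair_pose"].shape)   # pose rows loaded from pose.npy
```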
167
+
168
+ # class myDataset_unet(data.Dataset):
169
+ # """Custom data.Dataset compatible with data.DataLoader."""
170
+
171
+ # class myDataset_unet(data.Dataset):
172
+ # """Custom data.Dataset compatible with data.DataLoader."""
173
+
174
+ # def __init__(self, train_data_dir, frame_num=6):
175
+ # self.img_path = os.path.join(train_data_dir, "hair")
176
+ # # self.pose_path = os.path.join(train_data_dir, "pose.npy")
177
+ # # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
178
+ # # self.ref_path = os.path.join(train_data_dir, "ref_hair")
179
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
180
+ # self.non_hair_path = os.path.join(train_data_dir, "non-hair")
181
+ # self.ref_path = os.path.join(train_data_dir, "reference")
182
+ # self.lists = os.listdir(self.img_path)
183
+ # self.len = len(self.lists)-10
184
+ # self.pose = np.load(self.pose_path)
185
+ # self.frame_num = frame_num
186
+ # #self.pose = np.random.randn(12, 4)
187
+
188
+ # def __getitem__(self, index):
189
+ # """Returns one data pair (source and target)."""
190
+ # # seq_len, fea_dim
191
+ # random_number1 = random.randrange(0, 21)
192
+ # random_number2 = random.randrange(0, 21)
193
+
194
+ # while random_number2 == random_number1:
195
+ # random_number2 = random.randrange(0, 21)
196
+ # name = self.lists[index]
197
+
198
+ # random_number1 = random_number1
199
+ # #* 10
200
+ # #random_number2 = random_number2 * 10
201
+
202
+ # random_number2 = random_number1
203
+
204
+
205
+ # non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
206
+ # ref_folder = os.path.join(self.ref_path, name)
207
+ # ref_folder = os.path.join(self.img_path, name)
208
+
209
+ # files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
210
+ # #ref_path = os.path.join(ref_folder, files[0])
211
+ # ref_path = os.path.join(ref_folder, '0.jpg')
212
+
213
+ # img_non_hair = cv2.imread(non_hair_path)
214
+ # ref_hair = cv2.imread(ref_path)
215
+
216
+
217
+ # img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
218
+ # ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
219
+
220
+
221
+ # img_non_hair = cv2.resize(img_non_hair, (512, 512))
222
+ # ref_hair = cv2.resize(ref_hair, (512, 512))
223
+
224
+
225
+ # img_non_hair = (img_non_hair / 255.0) * 2 - 1
226
+ # ref_hair = (ref_hair / 255.0) * 2 - 1
227
+
228
+
229
+ # img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
230
+ # ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)
231
+
232
+ # pose1 = self.pose[random_number1]
233
+ # pose1 = torch.tensor(pose1)
234
+ # pose2 = self.pose[random_number2]
235
+ # pose2 = torch.tensor(pose2)
236
+ # hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
237
+ # hair_num = [0, 2, 6, 14, 18, 21]
238
+ # img_hair_stack = []
239
+ # # begin = random.randrange(0, 21-self.frame_num)
240
+ # # hair_num = [i+begin for i in range(self.frame_num)]
241
+ # for i in hair_num:
242
+ # img_hair = cv2.imread(hair_path)
243
+ # img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
244
+ # img_hair = cv2.resize(img_hair, (512, 512))
245
+ # img_hair = (img_hair / 255.0) * 2 - 1
246
+ # img_hair = torch.tensor(img_hair).permute(2, 0, 1)
247
+ # img_hair_stack.append(img_hair)
248
+ # img_hair = torch.stack(img_hair_stack)
249
+
250
+ # return {
251
+ # 'hair_pose': pose1,
252
+ # 'img_hair': img_hair,
253
+ # 'bald_pose': pose2,
254
+ # 'img_non_hair': img_non_hair,
255
+ # 'ref_hair': ref_hair
256
+ # }
257
+
258
+ # def __len__(self):
259
+ # return self.len
260
+
+ class myDataset_unet(data.Dataset):
+     """Custom data.Dataset compatible with data.DataLoader."""
+
+     def __init__(self, train_data_dir):
+         self.img_path = os.path.join(train_data_dir, "hair")
+         self.pose_path = os.path.join(train_data_dir, "pose.npy")
+         self.non_hair_path = os.path.join(train_data_dir, "non-hair")
+         self.ref_path = os.path.join(train_data_dir, "reference")
+         self.lists = os.listdir(self.img_path)
+         self.len = len(self.lists)
+         self.pose = np.load(self.pose_path)
+         elevations_deg = [-0.05/2*np.pi*360] * 21
+         azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
+         Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
+         for i in Face_yaws:
+             if i < 0:
+                 i = 2*np.pi + i
+             i = i/2*np.pi*360
+         face_yaws = [Face_yaws[0]]
+         for i in range(20):
+             face_yaws.append(Face_yaws[3*i+2])
+         self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
+         self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
+         self.azimuths_rad[:-1].sort()
+
+     def __getitem__(self, index):
+         """Returns one data pair (source and target)."""
+         # seq_len, fea_dim
+         random_number1 = random.randrange(0, 21)
+         random_number2 = random.randrange(0, 21)
+
+         # while random_number2 == random_number1:
+         #     random_number2 = random.randrange(0, 21)
+         name = self.lists[index]
+
+         #random_number1 = random_number1
+         #random_number2 = random_number2 * 10
+         #random_number2 = random_number1
+
+         hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
+         non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
+         ref_folder = os.path.join(self.img_path, name)
+
+         #files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
+         ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
+         img_hair = cv2.imread(hair_path)
+         img_non_hair = cv2.imread(non_hair_path)
+         ref_hair = cv2.imread(ref_path)
+
+         # img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
+         # img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
+         # ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
+
+         img_hair = cv2.resize(img_hair, (512, 512))
+         img_non_hair = cv2.resize(img_non_hair, (512, 512))
+         ref_hair = cv2.resize(ref_hair, (512, 512))
+
+         img_hair = (img_hair / 255.0) * 2 - 1
+         img_non_hair = (img_non_hair / 255.0) * 2 - 1
+         ref_hair = (ref_hair / 255.0) * 2 - 1
+
+         img_hair = torch.tensor(img_hair).permute(2, 0, 1)
+         img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
+         ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)
+
+         # pose1 = self.pose[random_number1]
+         # pose1 = torch.tensor(pose1)
+         # pose2 = self.pose[random_number2]
+         # pose2 = torch.tensor(pose2)
+         # polars = self.polars_rad[random_number1]
+         # polars = torch.tensor(polars).unsqueeze(0)
+         # azimuths = self.azimuths_rad[random_number1]
+         # azimuths = torch.tensor(azimuths).unsqueeze(0)
+         pose = self.pose[random_number1]
+         pose = torch.tensor(pose)
+
+         return {
+             # 'hair_pose': pose1,
+             'img_hair': img_hair,
+             # 'bald_pose': pose2,
+             # 'img_non_hair': img_non_hair,
+             'img_ref': ref_hair,
+             'pose': pose,
+             # 'polars': polars,
+             # 'azimuths': azimuths,
+         }
+
+     def __len__(self):
+         return self.len-10
+
353
+ class myDataset_sv3d(data.Dataset):
354
+ """Custom data.Dataset compatible with data.DataLoader."""
355
+
356
+ def __init__(self, train_data_dir, frame_num=6):
357
+ self.img_path = os.path.join(train_data_dir, "hair")
358
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
359
+ # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
360
+ # self.ref_path = os.path.join(train_data_dir, "ref_hair")
361
+ self.pose_path = os.path.join(train_data_dir, "pose.npy")
362
+ self.non_hair_path = os.path.join(train_data_dir, "non-hair")
363
+ self.ref_path = os.path.join(train_data_dir, "reference")
364
+ self.lists = os.listdir(self.img_path)
365
+ self.len = len(self.lists)-10
366
+ self.pose = np.load(self.pose_path)
367
+ self.frame_num = frame_num
368
+ #self.pose = np.random.randn(12, 4)
369
+ elevations_deg = [-0.05/2*np.pi*360] * 21
370
+ azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
371
+ Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
372
+ for i in Face_yaws:
373
+ if i<0:
374
+ i = 2*np.pi+i
375
+ i = i/2*np.pi*360
376
+ face_yaws = [Face_yaws[0]]
377
+ for i in range(20):
378
+ face_yaws.append(Face_yaws[3*i+2])
379
+ self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
380
+ self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
381
+ self.azimuths_rad[:-1].sort()
382
+
383
+ def __getitem__(self, index):
384
+ """Returns one data pair (source and target)."""
385
+ # seq_len, fea_dim
386
+ random_number1 = random.randrange(0, 21)
387
+ random_number3 = random.randrange(0, 21)
388
+ random_number2 = random.randrange(0, 21)
389
+
390
+ while random_number3 == random_number1:
391
+ random_number3 = random.randrange(0, 21)
392
+
393
+ # while random_number3 == random_number1:
394
+ # random_number3 = random.randrange(0, 21)
395
+ name = self.lists[index]
396
+
397
+ #random_number1 = random_number1
398
+ #* 10
399
+ #random_number2 = random_number2 * 10
400
+
401
+ #random_number2 = random_number1
402
+
403
+
404
+ hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
405
+ #hair_path2 = os.path.join(self.img_path, name, str(random_number3) + '.jpg')
406
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number1) + '.jpg')
407
+ #non_hair_path2 = os.path.join(self.non_hair_path, name, str(random_number1) + '.jpg')
408
+ #non_hair_path3 = os.path.join(self.non_hair_path, name, str(random_number3) + '.jpg')
409
+ ref_folder = os.path.join(self.ref_path, name)
410
+
411
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
412
+ # ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
413
+ ref_path = os.path.join(ref_folder,files[0])
414
+ # print('________')
415
+ # print(files)
416
+ # print('++++++++')
417
+ # print(ref_folder)
418
+ # print("========")
419
+ # print(name)
420
+ # print("********")
421
+ # print(ref_path)
422
+ img_hair = cv2.imread(hair_path)
423
+ #img_hair2 = cv2.imread(hair_path2)
424
+ img_non_hair = cv2.imread(non_hair_path)
425
+ #img_non_hair2 = cv2.imread(non_hair_path2)
426
+ #img_non_hair3 = cv2.imread(non_hair_path3)
427
+ ref_hair = cv2.imread(ref_path)
428
+
429
+
430
+ img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
431
+ #img_non_hair2 = cv2.cvtColor(img_non_hair2, cv2.COLOR_BGR2RGB)
432
+ #img_non_hair3 = cv2.cvtColor(img_non_hair3, cv2.COLOR_BGR2RGB)
433
+ ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
434
+
435
+
436
+ img_non_hair = cv2.resize(img_non_hair, (512, 512))
437
+ #img_non_hair2 = cv2.resize(img_non_hair2, (512, 512))
438
+ #img_non_hair3 = cv2.resize(img_non_hair3, (512, 512))
439
+ ref_hair = cv2.resize(ref_hair, (512, 512))
440
+
441
+
442
+ img_non_hair = (img_non_hair / 255.0) * 2 - 1
443
+ #img_non_hair2 = (img_non_hair2 / 255.0) * 2 - 1
444
+ #img_non_hair3 = (img_non_hair3 / 255.0) * 2 - 1
445
+ ref_hair = (ref_hair / 255.0) * 2 - 1
446
+
447
+
448
+ img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
449
+ #img_non_hair2 = torch.tensor(img_non_hair2).permute(2, 0, 1)
450
+ #img_non_hair3 = torch.tensor(img_non_hair3).permute(2, 0, 1)
451
+ ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)
452
+
453
+ pose = self.pose[random_number1]
454
+ pose = torch.tensor(pose)
455
+ pose2 = self.pose[random_number3]
456
+ pose2 = torch.tensor(pose2)
457
+ # pose2 = self.pose[random_number2]
458
+ # pose2 = torch.tensor(pose2)
459
+ hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
460
+ # hair_num = [0, 2, 6, 14, 18, 21]
461
+ # img_hair_stack = []
462
+ # polar = self.polars_rad[random_number1]
463
+ # polar = torch.tensor(polar).unsqueeze(0)
464
+ # azimuths = self.azimuths_rad[random_number1]
465
+ # azimuths = torch.tensor(azimuths).unsqueeze(0)
466
+ # img_hair = cv2.imread(hair_path)
467
+ img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
468
+ img_hair = cv2.resize(img_hair, (512, 512))
469
+ img_hair = (img_hair / 255.0) * 2 - 1
470
+ img_hair = torch.tensor(img_hair).permute(2, 0, 1)
471
+
472
+ #img_hair2 = cv2.cvtColor(img_hair2, cv2.COLOR_BGR2RGB)
473
+ #img_hair2 = cv2.resize(img_hair2, (512, 512))
474
+ #img_hair2 = (img_hair2 / 255.0) * 2 - 1
475
+ #img_hair2 = torch.tensor(img_hair2).permute(2, 0, 1)
476
+ # begin = random.randrange(0, 21-self.frame_num)
477
+ # hair_num = [i+begin for i in range(self.frame_num)]
478
+ # for i in hair_num:
479
+ # img_hair = cv2.imread(hair_path)
480
+ # img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
481
+ # img_hair = cv2.resize(img_hair, (512, 512))
482
+ # img_hair = (img_hair / 255.0) * 2 - 1
483
+ # img_hair = torch.tensor(img_hair).permute(2, 0, 1)
484
+ # img_hair_stack.append(img_hair)
485
+ # img_hair = torch.stack(img_hair_stack)
486
+
487
+ return {
488
+ # 'hair_pose': pose1,
489
+ 'img_hair': img_hair,
490
+ #'img_hair2': img_hair2,
491
+ # 'bald_pose': pose2,
492
+ #'pose': pose,
493
+ #'pose2': pose2,
494
+ 'img_non_hair': img_non_hair,
495
+ #'img_non_hair2': img_non_hair2,
496
+ #'img_non_hair3': img_non_hair3,
497
+ 'ref_hair': ref_hair,
498
+ # 'polar': polar,
499
+ # 'azimuths':azimuths,
500
+ }
501
+
502
+ def __len__(self):
503
+ return self.len
504
+
505
+ class myDataset_sv3d2(data.Dataset):
506
+ """Custom data.Dataset compatible with data.DataLoader."""
507
+
508
+ def __init__(self, train_data_dir, frame_num=6):
509
+ self.img_path = os.path.join(train_data_dir, "hair")
510
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
511
+ # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
512
+ # self.ref_path = os.path.join(train_data_dir, "ref_hair")
513
+ self.pose_path = os.path.join(train_data_dir, "pose.npy")
514
+ self.non_hair_path = os.path.join(train_data_dir, "non-hair")
515
+ self.ref_path = os.path.join(train_data_dir, "reference")
516
+ self.lists = os.listdir(self.img_path)
517
+ self.len = len(self.lists)-10
518
+ self.pose = np.load(self.pose_path)
519
+ self.frame_num = frame_num
520
+ #self.pose = np.random.randn(12, 4)
521
+ elevations_deg = [-0.05/2*np.pi*360] * 21
522
+ azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
523
+ Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
524
+ for i in Face_yaws:
525
+ if i<0:
526
+ i = 2*np.pi+i
527
+ i = i/2*np.pi*360
528
+ face_yaws = [Face_yaws[0]]
529
+ for i in range(20):
530
+ face_yaws.append(Face_yaws[3*i+2])
531
+ self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
532
+ self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
533
+ self.azimuths_rad[:-1].sort()
534
+
535
+ def __getitem__(self, index):
536
+ """Returns one data pair (source and target)."""
537
+ # seq_len, fea_dim
538
+ random_number1 = random.randrange(0, 21)
539
+ random_number2 = random.randrange(0, 21)
540
+
541
+ # while random_number2 == random_number1:
542
+ # random_number2 = random.randrange(0, 21)
543
+ name = self.lists[index]
544
+
545
+ #random_number1 = random_number1
546
+ #* 10
547
+ #random_number2 = random_number2 * 10
548
+
549
+ #random_number2 = random_number1
550
+
551
+
552
+ hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
553
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number1) + '.jpg')
554
+ ref_folder = os.path.join(self.ref_path, name)
555
+
556
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
557
+ # ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
558
+ ref_path = os.path.join(ref_folder,files[0])
559
+ # print('________')
560
+ # print(files)
561
+ # print('++++++++')
562
+ # print(ref_folder)
563
+ # print("========")
564
+ # print(name)
565
+ # print("********")
566
+ # print(ref_path)
567
+ img_hair = cv2.imread(hair_path)
568
+ img_non_hair = cv2.imread(non_hair_path)
569
+ ref_hair = cv2.imread(ref_path)
570
+
571
+
572
+ img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
573
+ ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
574
+
575
+
576
+ img_non_hair = cv2.resize(img_non_hair, (512, 512))
577
+ ref_hair = cv2.resize(ref_hair, (512, 512))
578
+
579
+
580
+ img_non_hair = (img_non_hair / 255.0) * 2 - 1
581
+ ref_hair = (ref_hair / 255.0) * 2 - 1
582
+
583
+
584
+ img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
585
+ ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)
586
+
587
+ pose = self.pose[random_number1]
588
+ pose = torch.tensor(pose)
589
+ # pose2 = self.pose[random_number2]
590
+ # pose2 = torch.tensor(pose2)
591
+ hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
592
+ # hair_num = [0, 2, 6, 14, 18, 21]
593
+ # img_hair_stack = []
594
+ # polar = self.polars_rad[random_number1]
595
+ # polar = torch.tensor(polar).unsqueeze(0)
596
+ # azimuths = self.azimuths_rad[random_number1]
597
+ # azimuths = torch.tensor(azimuths).unsqueeze(0)
598
+ img_hair = cv2.imread(hair_path)
599
+ img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
600
+ img_hair = cv2.resize(img_hair, (512, 512))
601
+ img_hair = (img_hair / 255.0) * 2 - 1
602
+ img_hair = torch.tensor(img_hair).permute(2, 0, 1)
603
+ # begin = random.randrange(0, 21-self.frame_num)
604
+ # hair_num = [i+begin for i in range(self.frame_num)]
605
+ # for i in hair_num:
606
+ # img_hair = cv2.imread(hair_path)
607
+ # img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
608
+ # img_hair = cv2.resize(img_hair, (512, 512))
609
+ # img_hair = (img_hair / 255.0) * 2 - 1
610
+ # img_hair = torch.tensor(img_hair).permute(2, 0, 1)
611
+ # img_hair_stack.append(img_hair)
612
+ # img_hair = torch.stack(img_hair_stack)
613
+
614
+ return {
615
+ # 'hair_pose': pose1,
616
+ 'img_hair': img_hair,
617
+ # 'bald_pose': pose2,
618
+ 'pose': pose,
619
+ 'img_non_hair': img_non_hair,
620
+ 'ref_hair': ref_hair,
621
+ # 'polar': polar,
622
+ # 'azimuths':azimuths,
623
+ }
624
+
625
+ def __len__(self):
626
+ return self.len
627
+
628
+ class myDataset_sv3d_temporal(data.Dataset):
629
+ """Custom data.Dataset compatible with data.DataLoader."""
630
+
631
+ def __init__(self, train_data_dir, frame_num=6):
632
+ self.img_path = os.path.join(train_data_dir, "hair")
633
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
634
+ # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
635
+ # self.ref_path = os.path.join(train_data_dir, "ref_hair")
636
+ self.pose_path = os.path.join(train_data_dir, "pose.npy")
637
+ self.non_hair_path = os.path.join(train_data_dir, "non-hair")
638
+ self.ref_path = os.path.join(train_data_dir, "reference")
639
+ self.lists = os.listdir(self.img_path)
640
+ self.len = len(self.lists)-10
641
+ self.pose = np.load(self.pose_path)
642
+ self.frame_num = frame_num
643
+ #self.pose = np.random.randn(12, 4)
644
+ elevations_deg = [-0.05/2*np.pi*360] * 21
645
+ azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
646
+ Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
647
+ for i in Face_yaws:
648
+ if i<0:
649
+ i = 2*np.pi+i
650
+ i = i/2*np.pi*360
651
+ face_yaws = [Face_yaws[0]]
652
+ for i in range(20):
653
+ face_yaws.append(Face_yaws[3*i+2])
654
+ self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
655
+ self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
656
+ self.azimuths_rad[:-1].sort()
657
+
658
+ def read_img(self, path):
659
+ img = cv2.imread(path)
660
+
661
+ img = cv2.resize(img, (512, 512))
662
+ img = (img / 255.0) * 2 - 1
663
+ img = torch.tensor(img).permute(2, 0, 1)
664
+ return img
665
+
666
+
667
+ def __getitem__(self, index):
668
+ """Returns one data pair (source and target)."""
669
+ # seq_len, fea_dim
670
+         random_number1 = random.randrange(0, 21-10)
+         random_number3 = random.randrange(0, 21)
+         random_number2 = random.randrange(0, 21)
+
+         while random_number3 == random_number1:
+             random_number3 = random.randrange(0, 21)
676
+
677
+ # while random_number3 == random_number1:
678
+ # random_number3 = random.randrange(0, 21)
679
+ name = self.lists[index]
680
+ x_stack = []
681
+ y_stack = []
682
+ img_non_hair_stack = []
683
+ img_hair_stack = []
684
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
685
+ ref_folder = os.path.join(self.ref_path, name)
686
+
687
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
688
+ # ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
689
+ ref_path = os.path.join(ref_folder,files[0])
690
+ for i in range(10):
691
+ img_non_hair_stack.append(self.read_img(non_hair_path).unsqueeze(1))
692
+ hair_path = os.path.join(self.img_path, name, str(random_number1+i) + '.jpg')
693
+ img_hair_stack.append(self.read_img(hair_path).unsqueeze(1))
694
+
695
+ #random_number1 = random_number1
696
+ #* 10
697
+ #random_number2 = random_number2 * 10
698
+
699
+ #random_number2 = random_number1
700
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
701
+
702
+
703
+ hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
704
+ hair_path2 = os.path.join(self.img_path, name, str(random_number3) + '.jpg')
705
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
706
+ non_hair_path2 = os.path.join(self.non_hair_path, name, str(random_number1) + '.jpg')
707
+ non_hair_path3 = os.path.join(self.non_hair_path, name, str(random_number3) + '.jpg')
708
+ ref_folder = os.path.join(self.ref_path, name)
709
+
710
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
711
+ # ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
712
+ ref_path = os.path.join(ref_folder,files[0])
713
+ # print('________')
714
+ # print(files)
715
+ # print('++++++++')
716
+ # print(ref_folder)
717
+ # print("========")
718
+ # print(name)
719
+ # print("********")
720
+ # print(ref_path)
721
+ img_hair = cv2.imread(hair_path)
722
+ img_hair2 = cv2.imread(hair_path2)
723
+ img_non_hair = cv2.imread(non_hair_path)
724
+ img_non_hair2 = cv2.imread(non_hair_path2)
725
+ img_non_hair3 = cv2.imread(non_hair_path3)
726
+ ref_hair = cv2.imread(ref_path)
727
+
728
+
729
+ img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
730
+ img_non_hair2 = cv2.cvtColor(img_non_hair2, cv2.COLOR_BGR2RGB)
731
+ img_non_hair3 = cv2.cvtColor(img_non_hair3, cv2.COLOR_BGR2RGB)
732
+ ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
733
+
734
+
735
+ img_non_hair = cv2.resize(img_non_hair, (512, 512))
736
+ img_non_hair2 = cv2.resize(img_non_hair2, (512, 512))
737
+ img_non_hair3 = cv2.resize(img_non_hair3, (512, 512))
738
+ ref_hair = cv2.resize(ref_hair, (512, 512))
739
+
740
+
741
+ img_non_hair = (img_non_hair / 255.0) * 2 - 1
742
+ img_non_hair2 = (img_non_hair2 / 255.0) * 2 - 1
743
+ img_non_hair3 = (img_non_hair3 / 255.0) * 2 - 1
744
+ ref_hair = (ref_hair / 255.0) * 2 - 1
745
+
746
+
747
+ img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
748
+ img_non_hair2 = torch.tensor(img_non_hair2).permute(2, 0, 1)
749
+ img_non_hair3 = torch.tensor(img_non_hair3).permute(2, 0, 1)
750
+ ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)
751
+
752
+ pose = self.pose[random_number1]
753
+ pose = torch.tensor(pose)
754
+ pose2 = self.pose[random_number3]
755
+ pose2 = torch.tensor(pose2)
756
+ # pose2 = self.pose[random_number2]
757
+ # pose2 = torch.tensor(pose2)
758
+ hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
759
+ # hair_num = [0, 2, 6, 14, 18, 21]
760
+ # img_hair_stack = []
761
+ # polar = self.polars_rad[random_number1]
762
+ # polar = torch.tensor(polar).unsqueeze(0)
763
+ # azimuths = self.azimuths_rad[random_number1]
764
+ # azimuths = torch.tensor(azimuths).unsqueeze(0)
765
+ # img_hair = cv2.imread(hair_path)
766
+ img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
767
+ img_hair = cv2.resize(img_hair, (512, 512))
768
+ img_hair = (img_hair / 255.0) * 2 - 1
769
+ img_hair = torch.tensor(img_hair).permute(2, 0, 1)
770
+
771
+ img_hair2 = cv2.cvtColor(img_hair2, cv2.COLOR_BGR2RGB)
772
+ img_hair2 = cv2.resize(img_hair2, (512, 512))
773
+ img_hair2 = (img_hair2 / 255.0) * 2 - 1
774
+ img_hair2 = torch.tensor(img_hair2).permute(2, 0, 1)
775
+ # begin = random.randrange(0, 21-self.frame_num)
776
+ # hair_num = [i+begin for i in range(self.frame_num)]
777
+ # for i in hair_num:
778
+ # img_hair = cv2.imread(hair_path)
779
+ # img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
780
+ # img_hair = cv2.resize(img_hair, (512, 512))
781
+ # img_hair = (img_hair / 255.0) * 2 - 1
782
+ # img_hair = torch.tensor(img_hair).permute(2, 0, 1)
783
+ # img_hair_stack.append(img_hair)
784
+ # img_hair = torch.stack(img_hair_stack)
785
+
786
+ return {
787
+ # 'hair_pose': pose1,
788
+ 'img_hair': img_hair,
789
+ 'img_hair2': img_hair2,
790
+ # 'bald_pose': pose2,
791
+ 'pose': pose,
792
+ 'pose2': pose2,
793
+ 'img_non_hair': img_non_hair,
794
+ 'img_non_hair2': img_non_hair2,
795
+ 'img_non_hair3': img_non_hair3,
796
+ 'ref_hair': ref_hair,
797
+ # 'polar': polar,
798
+ # 'azimuths':azimuths,
799
+ }
800
+
801
+ def __len__(self):
802
+ return self.len
803
+
804
+ class myDataset_sv3d_simple(data.Dataset):
805
+ """Custom data.Dataset compatible with data.DataLoader."""
806
+
807
+ def __init__(self, train_data_dir, frame_num=6):
808
+ self.img_path = os.path.join(train_data_dir, "hair")
809
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
810
+ # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
811
+ # self.ref_path = os.path.join(train_data_dir, "ref_hair")
812
+ self.pose_path = os.path.join(train_data_dir, "pose.npy")
813
+ self.non_hair_path = os.path.join(train_data_dir, "non-hair")
814
+ # self.ref_path = os.path.join(train_data_dir, "reference")
815
+ self.ref_path = os.path.join(train_data_dir, "reference")
816
+ self.lists = os.listdir(self.img_path)
817
+ self.len = len(self.lists)-10
818
+ self.pose = np.load(self.pose_path)
819
+ self.frame_num = frame_num
820
+ #self.pose = np.random.randn(12, 4)
821
+ elevations_deg = [-0.05/2*np.pi*360] * 21
822
+ azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
823
+ Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
824
+ for i in Face_yaws:
825
+ if i<0:
826
+ i = 2*np.pi+i
827
+ i = i/2*np.pi*360
828
+ face_yaws = [Face_yaws[0]]
829
+ for i in range(20):
830
+ face_yaws.append(Face_yaws[3*i+2])
831
+ self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
832
+ self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
833
+ self.azimuths_rad[:-1].sort()
834
+ x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
835
+ y = [- 0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
836
+ self.x = [x[0]]
837
+ self.y = [y[0]]
838
+ for i in range(20):
839
+ self.x.append(x[i*3+2])
840
+ self.y.append(y[i*3+2])
841
+
842
+ def __getitem__(self, index):
843
+ """Returns one data pair (source and target)."""
844
+ # seq_len, fea_dim
845
+ random_number1 = random.randrange(0, 21)
846
+ random_number2 = random.randrange(0, 21)
847
+
848
+ # while random_number2 == random_number1:
849
+ # random_number2 = random.randrange(0, 21)
850
+ name = self.lists[index]
851
+
852
+ #random_number1 = random_number1
853
+ #* 10
854
+ #random_number2 = random_number2 * 10
855
+
856
+ random_number2 = random_number1
857
+
858
+
859
+ hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
860
+ # hair_path = os.path.join(self.non_hair_path, name, str(random_number1) + '.jpg')
861
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
862
+ ref_folder = os.path.join(self.ref_path, name)
863
+
864
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
865
+ # ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
866
+ ref_path = os.path.join(ref_folder,files[0])
867
+ # print('________')
868
+ # print(files)
869
+ # print('++++++++')
870
+ # print(ref_folder)
871
+ # print("========")
872
+ # print(name)
873
+ # print("********")
874
+ # print(ref_path)
875
+ img_hair = cv2.imread(hair_path)
876
+ img_non_hair = cv2.imread(non_hair_path)
877
+ ref_hair = cv2.imread(ref_path)
878
+
879
+
880
+ img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
881
+ ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
882
+
883
+
884
+ img_non_hair = cv2.resize(img_non_hair, (512, 512))
885
+ ref_hair = cv2.resize(ref_hair, (512, 512))
886
+
887
+
888
+ img_non_hair = (img_non_hair / 255.0) * 2 - 1
889
+ ref_hair = (ref_hair / 255.0) * 2 - 1
890
+
891
+
892
+ img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
893
+ ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)
894
+
895
+ pose = self.pose[random_number1]
896
+ pose = torch.tensor(pose)
897
+ # pose2 = self.pose[random_number2]
898
+ # pose2 = torch.tensor(pose2)
899
+ # hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
900
+ # hair_num = [0, 2, 6, 14, 18, 21]
901
+ # img_hair_stack = []
902
+ # polar = self.polars_rad[random_number1]
903
+ # polar = torch.tensor(polar).unsqueeze(0)
904
+ # azimuths = self.azimuths_rad[random_number1]
905
+ # azimuths = torch.tensor(azimuths).unsqueeze(0)
906
+ img_hair = cv2.imread(hair_path)
907
+ img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
908
+ img_hair = cv2.resize(img_hair, (512, 512))
909
+ img_hair = (img_hair / 255.0) * 2 - 1
910
+ img_hair = torch.tensor(img_hair).permute(2, 0, 1)
911
+ x = torch.tensor(self.x[random_number1])
912
+ y = torch.tensor(self.y[random_number1])
913
+ # begin = random.randrange(0, 21-self.frame_num)
914
+ # hair_num = [i+begin for i in range(self.frame_num)]
915
+ # for i in hair_num:
916
+ # img_hair = cv2.imread(hair_path)
917
+ # img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
918
+ # img_hair = cv2.resize(img_hair, (512, 512))
919
+ # img_hair = (img_hair / 255.0) * 2 - 1
920
+ # img_hair = torch.tensor(img_hair).permute(2, 0, 1)
921
+ # img_hair_stack.append(img_hair)
922
+ # img_hair = torch.stack(img_hair_stack)
923
+
924
+ return {
925
+ # 'hair_pose': pose1,
926
+ 'img_hair': img_hair,
927
+ # 'bald_pose': pose2,
928
+ 'pose': pose,
929
+ 'img_non_hair': img_non_hair,
930
+ 'ref_hair': ref_hair,
931
+ 'x': x,
932
+ 'y': y,
933
+ # 'polar': polar,
934
+ # 'azimuths':azimuths,
935
+ }
936
+
937
+ def __len__(self):
938
+ return self.len
939
+
940
+ class myDataset_sv3d_simple_ori(data.Dataset):
941
+ """Custom data.Dataset compatible with data.DataLoader."""
942
+
943
+ def __init__(self, train_data_dir, frame_num=6):
944
+ self.img_path = os.path.join(train_data_dir, "hair")
945
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
946
+ # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
947
+ # self.ref_path = os.path.join(train_data_dir, "ref_hair")
948
+ self.pose_path = os.path.join(train_data_dir, "pose.npy")
949
+ self.non_hair_path = os.path.join(train_data_dir, "non-hair")
950
+ self.ref_path = os.path.join(train_data_dir, "reference")
951
+ self.lists = os.listdir(self.img_path)
952
+ self.len = len(self.lists)-10
953
+ self.pose = np.load(self.pose_path)
954
+ self.frame_num = frame_num
955
+ #self.pose = np.random.randn(12, 4)
956
+ elevations_deg = [-0.05/2*np.pi*360] * 21
957
+ azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
958
+ Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
959
+ for i in Face_yaws:
960
+ if i<0:
961
+ i = 2*np.pi+i
962
+ i = i/2*np.pi*360
963
+ face_yaws = [Face_yaws[0]]
964
+ for i in range(20):
965
+ face_yaws.append(Face_yaws[3*i+2])
966
+ self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
967
+ self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
968
+ self.azimuths_rad[:-1].sort()
969
+ x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
970
+ y = [- 0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
971
+ self.x = [x[0]]
972
+ self.y = [y[0]]
973
+ for i in range(20):
974
+ self.x.append(x[i*3+2])
975
+ self.y.append(y[i*3+2])
976
+
977
+ def __getitem__(self, index):
978
+ """Returns one data pair (source and target)."""
979
+ # seq_len, fea_dim
980
+ random_number1 = random.randrange(0, 21)
981
+ random_number2 = random.randrange(0, 21)
982
+
983
+ # while random_number2 == random_number1:
984
+ # random_number2 = random.randrange(0, 21)
985
+ name = self.lists[index]
986
+
987
+ #random_number1 = random_number1
988
+ #* 10
989
+ #random_number2 = random_number2 * 10
990
+
991
+ #random_number2 = random_number1
992
+
993
+
994
+ hair_path = os.path.join(self.img_path, name, str(random_number2) + '.jpg')
995
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
996
+ ref_folder = os.path.join(self.ref_path, name)
997
+
998
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
999
+ # ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
1000
+ ref_path = os.path.join(ref_folder,files[0])
1001
+ # print('________')
1002
+ # print(files)
1003
+ # print('++++++++')
1004
+ # print(ref_folder)
1005
+ # print("========")
1006
+ # print(name)
1007
+ # print("********")
1008
+ # print(ref_path)
1009
+ img_hair = cv2.imread(hair_path)
1010
+ img_non_hair = cv2.imread(non_hair_path)
1011
+ ref_hair = cv2.imread(ref_path)
1012
+
1013
+
1014
+ img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
1015
+ ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
1016
+
1017
+
1018
+ img_non_hair = cv2.resize(img_non_hair, (512, 512))
1019
+ ref_hair = cv2.resize(ref_hair, (512, 512))
1020
+
1021
+
1022
+ img_non_hair = (img_non_hair / 255.0) * 2 - 1
1023
+ ref_hair = (ref_hair / 255.0) * 2 - 1
1024
+
1025
+
1026
+ img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
1027
+ ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)
1028
+
1029
+ pose = self.pose[random_number2]
1030
+ pose = torch.tensor(pose)
1031
+ # pose2 = self.pose[random_number2]
1032
+ # pose2 = torch.tensor(pose2)
1033
+ hair_path = os.path.join(self.img_path, name, str(random_number2) + '.jpg')
1034
+ # hair_num = [0, 2, 6, 14, 18, 21]
1035
+ # img_hair_stack = []
1036
+ # polar = self.polars_rad[random_number1]
1037
+ # polar = torch.tensor(polar).unsqueeze(0)
1038
+ # azimuths = self.azimuths_rad[random_number1]
1039
+ # azimuths = torch.tensor(azimuths).unsqueeze(0)
1040
+ img_hair = cv2.imread(hair_path)
1041
+ img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
1042
+ img_hair = cv2.resize(img_hair, (512, 512))
1043
+ img_hair = (img_hair / 255.0) * 2 - 1
1044
+ img_hair = torch.tensor(img_hair).permute(2, 0, 1)
1045
+ x = torch.tensor(self.x[random_number2])
1046
+ y = torch.tensor(self.y[random_number2])
1047
+ # begin = random.randrange(0, 21-self.frame_num)
1048
+ # hair_num = [i+begin for i in range(self.frame_num)]
1049
+ # for i in hair_num:
1050
+ # img_hair = cv2.imread(hair_path)
1051
+ # img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
1052
+ # img_hair = cv2.resize(img_hair, (512, 512))
1053
+ # img_hair = (img_hair / 255.0) * 2 - 1
1054
+ # img_hair = torch.tensor(img_hair).permute(2, 0, 1)
1055
+ # img_hair_stack.append(img_hair)
1056
+ # img_hair = torch.stack(img_hair_stack)
1057
+
1058
+ return {
1059
+ # 'hair_pose': pose1,
1060
+ 'img_hair': img_hair,
1061
+ # 'bald_pose': pose2,
1062
+ 'pose': pose,
1063
+ 'img_non_hair': img_non_hair,
1064
+ 'ref_hair': ref_hair,
1065
+ 'x': x,
1066
+ 'y': y,
1067
+ # 'polar': polar,
1068
+ # 'azimuths':azimuths,
1069
+ }
1070
+
1071
+ def __len__(self):
1072
+ return self.len
1073
+
1074
+ class myDataset_sv3d_simple_temporal(data.Dataset):
1075
+ """Custom data.Dataset compatible with data.DataLoader."""
1076
+
1077
+ def __init__(self, train_data_dir, frame_num=6):
1078
+ self.img_path = os.path.join(train_data_dir, "hair")
1079
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
1080
+ # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
1081
+ # self.ref_path = os.path.join(train_data_dir, "ref_hair")
1082
+ self.pose_path = os.path.join(train_data_dir, "pose.npy")
1083
+ self.non_hair_path = os.path.join(train_data_dir, "non-hair")
1084
+ self.ref_path = os.path.join(train_data_dir, "reference")
1085
+ self.lists = os.listdir(self.img_path)
1086
+ self.len = len(self.lists)-10
1087
+ self.pose = np.load(self.pose_path)
1088
+ self.frame_num = frame_num
1089
+ #self.pose = np.random.randn(12, 4)
1090
+ elevations_deg = [-0.05/2*np.pi*360] * 21
1091
+ azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
1092
+ Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
1093
+ for i in Face_yaws:
1094
+ if i<0:
1095
+ i = 2*np.pi+i
1096
+ i = i/2*np.pi*360
1097
+ face_yaws = [Face_yaws[0]]
1098
+ for i in range(20):
1099
+ face_yaws.append(Face_yaws[3*i+2])
1100
+ self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
1101
+ self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
1102
+ self.azimuths_rad[:-1].sort()
1103
+ x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
1104
+ y = [- 0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
1105
+ self.x = [x[0]]
1106
+ self.y = [y[0]]
1107
+ for i in range(20):
1108
+ self.x.append(x[i*3+2])
1109
+ self.y.append(y[i*3+2])
1110
+
1111
+ def read_img(self, path):
1112
+ img = cv2.imread(path)
1113
+
1114
+ img = cv2.resize(img, (512, 512))
1115
+ img = (img / 255.0) * 2 - 1
1116
+ img = torch.tensor(img).permute(2, 0, 1)
1117
+ return img
1118
+
1119
+ def __getitem__(self, index):
1120
+ """Returns one data pair (source and target)."""
1121
+ # seq_len, fea_dim
1122
+ random_number1 = random.randrange(0, 21-12)
1123
+ random_number2 = random.randrange(0, 21)
1124
+
1125
+ name = self.lists[index]
1126
+ x_stack = []
1127
+ y_stack = []
1128
+ img_non_hair_stack = []
1129
+ img_hair_stack = []
1130
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
1131
+ ref_folder = os.path.join(self.ref_path, name)
1132
+
1133
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
1134
+ # ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
1135
+ ref_path = os.path.join(ref_folder,files[0])
1136
+ ref_hair = self.read_img(ref_path)
1137
+ for i in range(12):
1138
+ img_non_hair_stack.append(self.read_img(non_hair_path).unsqueeze(0))
1139
+ #hair_path = os.path.join(self.img_path, name, str(random_number1+i) + '.jpg')
1140
+ hair_path = os.path.join(self.non_hair_path, name, str(random_number1+i) + '.jpg')
1141
+ img_hair_stack.append(self.read_img(hair_path).unsqueeze(0))
1142
+ x_stack.append(torch.tensor(self.x[random_number1+i]).unsqueeze(0))
1143
+ y_stack.append(torch.tensor(self.y[random_number1+i]).unsqueeze(0))
1144
+
1145
+ img_non_hair = torch.cat(img_non_hair_stack, axis=0)
1146
+ img_hair = torch.cat(img_hair_stack, axis=0)
1147
+ x = torch.cat(x_stack, axis=0)
1148
+ y = torch.cat(y_stack, axis=0)
1149
+
1150
+ return {
1151
+ 'img_hair': img_hair,
1152
+ 'img_non_hair': img_non_hair,
1153
+ 'ref_hair': ref_hair,
1154
+ 'x': x,
1155
+ 'y': y,
1156
+
1157
+ }
1158
+
1159
+ def __len__(self):
1160
+ return self.len
1161
+
1162
+ class myDataset_sv3d_simple_temporal2(data.Dataset):
1163
+ """Custom data.Dataset compatible with data.DataLoader."""
1164
+
1165
+ def __init__(self, train_data_dir, frame_num=6):
1166
+ train_data_dir2 = '/opt/liblibai-models/user-workspace/zyx/sky/3dhair/data/segement'
1167
+ self.img_path = os.path.join(train_data_dir, "hair")
1168
+ self.img_path2 = os.path.join(train_data_dir, "hair_good")
1169
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
1170
+ # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
1171
+ # self.ref_path = os.path.join(train_data_dir, "ref_hair")
1172
+ self.pose_path = os.path.join(train_data_dir, "pose.npy")
1173
+ self.non_hair_path = os.path.join(train_data_dir, "non-hair")
1174
+ self.ref_path = os.path.join(train_data_dir, "multi_reference2")
1175
+
1176
+ self.pose_path2 = os.path.join(train_data_dir2, "pose.npy")
1177
+ self.non_hair_path2 = os.path.join(train_data_dir2, "non-hair")
1178
+ self.ref_path2 = os.path.join(train_data_dir2, "reference")
1179
+
1180
+ self.lists = os.listdir(self.img_path2)
1181
+ self.len = len(self.lists)-10
1182
+ self.pose = np.load(self.pose_path)
1183
+ self.frame_num = frame_num
1184
+ #self.pose = np.random.randn(12, 4)
1185
+ elevations_deg = [-0.05/2*np.pi*360] * 21
1186
+ azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
1187
+ Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
1188
+ for i in Face_yaws:
1189
+ if i<0:
1190
+ i = 2*np.pi+i
1191
+ i = i/2*np.pi*360
1192
+ face_yaws = [Face_yaws[0]]
1193
+ for i in range(20):
1194
+ face_yaws.append(Face_yaws[3*i+2])
1195
+ self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
1196
+ self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
1197
+ self.azimuths_rad[:-1].sort()
1198
+ x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
1199
+ y = [- 0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
1200
+ self.x = [x[0]]
1201
+ self.y = [y[0]]
1202
+ for i in range(20):
1203
+ self.x.append(x[i*3+2])
1204
+ self.y.append(y[i*3+2])
1205
+
1206
+ def read_img(self, path):
1207
+ img = cv2.imread(path)
1208
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
1209
+ img = cv2.resize(img, (512, 512))
1210
+ img = (img / 255.0) * 2 - 1
1211
+ img = torch.tensor(img).permute(2, 0, 1)
1212
+ return img
1213
+
1214
+ def read_ref_img(self, path):
1215
+ img = cv2.imread(path)
1216
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
1217
+ img = hair_transform(image=img)['image']
1218
+ # img = cv2.resize(img, (512, 512))
1219
+ img = (img / 255.0) * 2 - 1
1220
+ img = torch.tensor(img).permute(2, 0, 1)
1221
+ return img
1222
+
1223
+ def reference_lists(self, reference_num, root):
1224
+ stacks = []
1225
+ invalid = []
1226
+ for i in range(12):
1227
+ if (reference_num-6+i)<0:
1228
+ invalid.append(reference_num-5+i+21)
1229
+ else:
1230
+ invalid.append((reference_num-5+i)%21)
1231
+ for i in range(21):
1232
+ if i in invalid:
1233
+ continue
1234
+ else:
1235
+ stacks.append(os.path.join(root, str(i)+'.jpg'))
1236
+ return stacks
1237
+
1238
+ def __getitem__(self, index):
1239
+ """Returns one data pair (source and target)."""
1240
+ # seq_len, fea_dim
1241
+ random_number = random.uniform(0, 1)
1242
+ if random_number<0.5:
1243
+ non_hair_root = self.non_hair_path2
1244
+ img_path = self.img_path2
1245
+ else:
1246
+ non_hair_root = self.non_hair_path
1247
+ img_path = self.img_path
1248
+ random_number1 = random.randrange(0, 21-12)
1249
+ random_number2 = random.randrange(0, 21)
1250
+
1251
+ name = self.lists[index].split('.')[0]
1252
+ x_stack = []
1253
+ y_stack = []
1254
+ img_non_hair_stack = []
1255
+ img_hair_stack = []
1256
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
1257
+ # non_hair_path = os.path.join(img_path, name, str(random_number2) + '.jpg')
1258
+ ref_folder = os.path.join(self.ref_path, name)
1259
+
1260
+ # files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')][:3] + self.reference_lists(random_number2, os.path.join(self.img_path, name))[:5]
1261
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')][:3]
1262
+ # ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
1263
+ ref_path = os.path.join(ref_folder,random.choice(files))
1264
+ ref_hair = self.read_ref_img(ref_path)
1265
+ for i in range(12):
1266
+ #non_hair_path = os.path.join(img_path, name, str(random_number1+i) + '.jpg')
1267
+ img_non_hair_stack.append(self.read_img(non_hair_path).unsqueeze(0))
1268
+ hair_path = os.path.join(self.img_path, name, str(random_number1+i) + '.jpg')
1269
+ # hair_path = os.path.join(img_path, name, str(random_number1+i) + '.jpg')
1270
+ img_hair_stack.append(self.read_img(hair_path).unsqueeze(0))
1271
+ x_stack.append(torch.tensor(self.x[random_number1+i]).unsqueeze(0))
1272
+ y_stack.append(torch.tensor(self.y[random_number1+i]).unsqueeze(0))
1273
+
1274
+ img_non_hair = torch.cat(img_non_hair_stack, axis=0)
1275
+ img_hair = torch.cat(img_hair_stack, axis=0)
1276
+ x = torch.cat(x_stack, axis=0)
1277
+ y = torch.cat(y_stack, axis=0)
1278
+
1279
+ return {
1280
+ 'img_hair': img_hair,
1281
+ 'img_non_hair': img_non_hair,
1282
+ 'ref_hair': ref_hair,
1283
+ 'x': x,
1284
+ 'y': y,
1285
+
1286
+ }
1287
+
1288
+ def __len__(self):
1289
+ return self.len
1290
+
1291
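+ # 12-frame variant: hair and non-hair frames are sampled along the same trajectory window, with a randomly picked reference image and per-frame (x, y) camera offsets.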
+ class myDataset_sv3d_simple_temporal3(data.Dataset):
1292
+ """Custom data.Dataset compatible with data.DataLoader."""
1293
+
1294
+ def __init__(self, train_data_dir, frame_num=6):
1295
+ train_data_dir2 = '/opt/liblibai-models/user-workspace/zyx/sky/3dhair/data/segement'
1296
+ self.img_path = os.path.join(train_data_dir, "hair")
1297
+ self.img_path2 = os.path.join(train_data_dir, "non-hair")
1298
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
1299
+ # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
1300
+ # self.ref_path = os.path.join(train_data_dir, "ref_hair")
1301
+ self.pose_path = os.path.join(train_data_dir, "pose.npy")
1302
+ self.non_hair_path = os.path.join(train_data_dir, "non-hair")
1303
+ self.ref_path = os.path.join(train_data_dir, "multi_reference2")
1304
+
1305
+ self.pose_path2 = os.path.join(train_data_dir2, "pose.npy")
1306
+ self.non_hair_path2 = os.path.join(train_data_dir2, "non-hair")
1307
+ self.ref_path2 = os.path.join(train_data_dir2, "reference")
1308
+
1309
+ self.lists = os.listdir(self.img_path)
1310
+ self.len = len(self.lists)-10
1311
+ self.pose = np.load(self.pose_path)
1312
+ self.frame_num = frame_num
1313
+ #self.pose = np.random.randn(12, 4)
1314
+ elevations_deg = [-0.05/2*np.pi*360] * 21
1315
+ azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
1316
+ Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
1317
+ for i in Face_yaws:
1318
+ if i<0:
1319
+ i = 2*np.pi+i
1320
+ i = i/2*np.pi*360
1321
+ face_yaws = [Face_yaws[0]]
1322
+ for i in range(20):
1323
+ face_yaws.append(Face_yaws[3*i+2])
1324
+ self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
1325
+ self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
1326
+ self.azimuths_rad[:-1].sort()
1327
+ x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
1328
+ y = [- 0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
1329
+ self.x = [x[0]]
1330
+ self.y = [y[0]]
1331
+ for i in range(20):
1332
+ self.x.append(x[i*3+2])
1333
+ self.y.append(y[i*3+2])
1334
+
1335
+ def read_img(self, path):
1336
+ img = cv2.imread(path)
1337
+
1338
+ img = cv2.resize(img, (512, 512))
1339
+ img = (img / 255.0) * 2 - 1
1340
+ img = torch.tensor(img).permute(2, 0, 1)
1341
+ return img
1342
+
1343
+ def __getitem__(self, index):
1344
+ """Returns one data pair (source and target)."""
1345
+ # seq_len, fea_dim
1346
+ random_number = random.uniform(0, 1)
1347
+ if random_number<0.5:
1348
+ non_hair_root = self.non_hair_path2
1349
+ img_path = self.img_path2
1350
+ else:
1351
+ non_hair_root = self.non_hair_path
1352
+ img_path = self.img_path
1353
+ random_number1 = random.randrange(0, 21-12)
1354
+ random_number2 = random.randrange(0, 21)
1355
+
1356
+ name = self.lists[index]
1357
+ x_stack = []
1358
+ y_stack = []
1359
+ img_non_hair_stack = []
1360
+ img_hair_stack = []
1361
+ # non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
1362
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
1363
+ ref_folder = os.path.join(self.ref_path, name)
1364
+
1365
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
1366
+ # ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
1367
+ ref_path = os.path.join(ref_folder,random.choice(files))
1368
+ ref_hair = self.read_img(ref_path)
1369
+ for i in range(12):
1370
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number1+i) + '.jpg')
1371
+ img_non_hair_stack.append(self.read_img(non_hair_path).unsqueeze(0))
1372
+ # hair_path = os.path.join(self.img_path, name, str(random_number1+i) + '.jpg')
1373
+ hair_path = os.path.join(self.img_path, name, str(random_number1+i) + '.jpg')
1374
+ img_hair_stack.append(self.read_img(hair_path).unsqueeze(0))
1375
+ x_stack.append(torch.tensor(self.x[random_number1+i]).unsqueeze(0))
1376
+ y_stack.append(torch.tensor(self.y[random_number1+i]).unsqueeze(0))
1377
+
1378
+ img_non_hair = torch.cat(img_non_hair_stack, axis=0)
1379
+ img_hair = torch.cat(img_hair_stack, axis=0)
1380
+ x = torch.cat(x_stack, axis=0)
1381
+ y = torch.cat(y_stack, axis=0)
1382
+
1383
+ return {
1384
+ 'img_hair': img_hair,
1385
+ 'img_non_hair': img_non_hair,
1386
+ 'ref_hair': ref_hair,
1387
+ 'x': x,
1388
+ 'y': y,
1389
+
1390
+ }
1391
+
1392
+ def __len__(self):
1393
+ return self.len
1394
+
1395
+
1396
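+ # Same 12-frame sampling as above; kept as a separate class, apparently for ControlNet training without pose conditioning.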
+ class myDataset_sv3d_simple_temporal_controlnet_without_pose(data.Dataset):
1397
+ """Custom data.Dataset compatible with data.DataLoader."""
1398
+
1399
+ def __init__(self, train_data_dir, frame_num=6):
1400
+ train_data_dir2 = '/opt/liblibai-models/user-workspace/zyx/sky/3dhair/data/segement'
1401
+ self.img_path = os.path.join(train_data_dir, "hair")
1402
+ self.img_path2 = os.path.join(train_data_dir, "non-hair")
1403
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
1404
+ # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
1405
+ # self.ref_path = os.path.join(train_data_dir, "ref_hair")
1406
+ self.pose_path = os.path.join(train_data_dir, "pose.npy")
1407
+ self.non_hair_path = os.path.join(train_data_dir, "non-hair")
1408
+ self.ref_path = os.path.join(train_data_dir, "multi_reference2")
1409
+
1410
+ self.pose_path2 = os.path.join(train_data_dir2, "pose.npy")
1411
+ self.non_hair_path2 = os.path.join(train_data_dir2, "non-hair")
1412
+ self.ref_path2 = os.path.join(train_data_dir2, "reference")
1413
+
1414
+ self.lists = os.listdir(self.img_path)
1415
+ self.len = len(self.lists)-10
1416
+ self.pose = np.load(self.pose_path)
1417
+ self.frame_num = frame_num
1418
+ #self.pose = np.random.randn(12, 4)
1419
+ elevations_deg = [-0.05/2*np.pi*360] * 21
1420
+ azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
1421
+ Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
1422
+ for i in Face_yaws:
1423
+ if i<0:
1424
+ i = 2*np.pi+i
1425
+ i = i/2*np.pi*360
1426
+ face_yaws = [Face_yaws[0]]
1427
+ for i in range(20):
1428
+ face_yaws.append(Face_yaws[3*i+2])
1429
+ self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
1430
+ self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
1431
+ self.azimuths_rad[:-1].sort()
1432
+ x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
1433
+ y = [- 0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
1434
+ self.x = [x[0]]
1435
+ self.y = [y[0]]
1436
+ for i in range(20):
1437
+ self.x.append(x[i*3+2])
1438
+ self.y.append(y[i*3+2])
1439
+
1440
+ def read_img(self, path):
1441
+ img = cv2.imread(path)
1442
+
1443
+ img = cv2.resize(img, (512, 512))
1444
+ img = (img / 255.0) * 2 - 1
1445
+ img = torch.tensor(img).permute(2, 0, 1)
1446
+ return img
1447
+
1448
+ def __getitem__(self, index):
1449
+ """Returns one data pair (source and target)."""
1450
+ # seq_len, fea_dim
1451
+ random_number = random.uniform(0, 1)
1452
+ if random_number<0.5:
1453
+ non_hair_root = self.non_hair_path2
1454
+ img_path = self.img_path2
1455
+ else:
1456
+ non_hair_root = self.non_hair_path
1457
+ img_path = self.img_path
1458
+ random_number1 = random.randrange(0, 21-12)
1459
+ random_number2 = random.randrange(0, 21)
1460
+
1461
+ name = self.lists[index]
1462
+ x_stack = []
1463
+ y_stack = []
1464
+ img_non_hair_stack = []
1465
+ img_hair_stack = []
1466
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
1467
+ ref_folder = os.path.join(self.ref_path, name)
1468
+
1469
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
1470
+ # ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
1471
+ ref_path = os.path.join(ref_folder,random.choice(files))
1472
+ ref_hair = self.read_img(ref_path)
1473
+ for i in range(12):
1474
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number1+i) + '.jpg')
1475
+ img_non_hair_stack.append(self.read_img(non_hair_path).unsqueeze(0))
1476
+ hair_path = os.path.join(self.img_path, name, str(random_number1+i) + '.jpg')
1477
+ img_hair_stack.append(self.read_img(hair_path).unsqueeze(0))
1478
+ x_stack.append(torch.tensor(self.x[random_number1+i]).unsqueeze(0))
1479
+ y_stack.append(torch.tensor(self.y[random_number1+i]).unsqueeze(0))
1480
+
1481
+ img_non_hair = torch.cat(img_non_hair_stack, axis=0)
1482
+ img_hair = torch.cat(img_hair_stack, axis=0)
1483
+ x = torch.cat(x_stack, axis=0)
1484
+ y = torch.cat(y_stack, axis=0)
1485
+
1486
+ return {
1487
+ 'img_hair': img_hair,
1488
+ 'img_non_hair': img_non_hair,
1489
+ 'ref_hair': ref_hair,
1490
+ 'x': x,
1491
+ 'y': y,
1492
+
1493
+ }
1494
+
1495
+ def __len__(self):
1496
+ return self.len
1497
+
1498
+
1499
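+ # Single-frame ControlNet variant: one target view, one source view and one reference image, with scalar (x, y) camera offsets.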
+ class myDataset_sv3d_simple_temporal_controlnet(data.Dataset):
1500
+ """Custom data.Dataset compatible with data.DataLoader."""
1501
+
1502
+ def __init__(self, train_data_dir, frame_num=6):
1503
+ train_data_dir2 = '/opt/liblibai-models/user-workspace/zyx/sky/3dhair/data/segement'
1504
+ self.img_path = os.path.join(train_data_dir, "hair")
1505
+ self.img_path2 = os.path.join(train_data_dir, "non-hair")
1506
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
1507
+ # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
1508
+ # self.ref_path = os.path.join(train_data_dir, "ref_hair")
1509
+ self.pose_path = os.path.join(train_data_dir, "pose.npy")
1510
+ self.non_hair_path = os.path.join(train_data_dir, "non-hair")
1511
+ self.ref_path = os.path.join(train_data_dir, "multi_reference2")
1512
+
1513
+ self.pose_path2 = os.path.join(train_data_dir2, "pose.npy")
1514
+ self.non_hair_path2 = os.path.join(train_data_dir2, "non-hair")
1515
+ self.ref_path2 = os.path.join(train_data_dir2, "reference")
1516
+
1517
+ self.lists = os.listdir(self.img_path)
1518
+ self.len = len(self.lists)-10
1519
+ self.pose = np.load(self.pose_path)
1520
+ self.frame_num = frame_num
1521
+ #self.pose = np.random.randn(12, 4)
1522
+ elevations_deg = [-0.05/2*np.pi*360] * 21
1523
+ azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
1524
+ Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
1525
+ for i in Face_yaws:
1526
+ if i<0:
1527
+ i = 2*np.pi+i
1528
+ i = i/2*np.pi*360
1529
+ face_yaws = [Face_yaws[0]]
1530
+ for i in range(20):
1531
+ face_yaws.append(Face_yaws[3*i+2])
1532
+ self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
1533
+ self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
1534
+ self.azimuths_rad[:-1].sort()
1535
+ x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
1536
+ y = [- 0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
1537
+ self.x = [x[0]]
1538
+ self.y = [y[0]]
1539
+ for i in range(20):
1540
+ self.x.append(x[i*3+2])
1541
+ self.y.append(y[i*3+2])
1542
+
1543
+ def read_img(self, path):
1544
+ img = cv2.imread(path)
1545
+
1546
+ img = cv2.resize(img, (512, 512))
1547
+ img = (img / 255.0) * 2 - 1
1548
+ img = torch.tensor(img).permute(2, 0, 1)
1549
+ return img
1550
+
1551
+ def __getitem__(self, index):
1552
+ """Returns one data pair (source and target)."""
1553
+ # seq_len, fea_dim
1554
+ random_number = random.uniform(0, 1)
1555
+ if random_number<0.5:
1556
+ non_hair_root = self.non_hair_path2
1557
+ img_path = self.img_path2
1558
+ else:
1559
+ non_hair_root = self.non_hair_path
1560
+ img_path = self.img_path
1561
+ random_number1 = random.randrange(0, 21)
1562
+ random_number2 = random.randrange(0, 21)
1563
+
1564
+ name = self.lists[index]
1565
+ x_stack = []
1566
+ y_stack = []
1567
+ img_non_hair_stack = []
1568
+ img_hair_stack = []
1569
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
1570
+ ref_folder = os.path.join(self.ref_path, name)
1571
+
1572
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
1573
+ # ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
1574
+ ref_path = os.path.join(ref_folder,random.choice(files))
1575
+ ref_hair = self.read_img(ref_path)
1576
+ non_hair_path = os.path.join(img_path, name, str(random_number2) + '.jpg')
1577
+ img_non_hair = self.read_img(non_hair_path)
1578
+ hair_path = os.path.join(img_path, name, str(random_number1) + '.jpg')
1579
+ img_hair= self.read_img(hair_path)
1580
+ x = self.x[random_number1]
1581
+ y = self.y[random_number1]
1582
+
1583
+
1584
+
1585
+ return {
1586
+ 'img_hair': img_hair,
1587
+ 'img_non_hair': img_non_hair,
1588
+ 'ref_hair': ref_hair,
1589
+ 'x': x,
1590
+ 'y': y,
1591
+
1592
+ }
1593
+
1594
+ def __len__(self):
1595
+ return self.len
1596
+
1597
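+ # 12-frame variant that randomly trains on either the hair or the non-hair sequence; the source view is fixed and the reference is always the first image in the reference folder.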
+ class myDataset_sv3d_simple_temporal_pose(data.Dataset):
1598
+ """Custom data.Dataset compatible with data.DataLoader."""
1599
+
1600
+ def __init__(self, train_data_dir, frame_num=6):
1601
+ self.img_path = os.path.join(train_data_dir, "hair")
1602
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
1603
+ # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
1604
+ # self.ref_path = os.path.join(train_data_dir, "ref_hair")
1605
+ self.pose_path = os.path.join(train_data_dir, "pose.npy")
1606
+ self.non_hair_path = os.path.join(train_data_dir, "non-hair")
1607
+ self.ref_path = os.path.join(train_data_dir, "reference")
1608
+ self.lists = os.listdir(self.img_path)
1609
+ self.len = len(self.lists)-10
1610
+ self.pose = np.load(self.pose_path)
1611
+ self.frame_num = frame_num
1612
+ #self.pose = np.random.randn(12, 4)
1613
+ elevations_deg = [-0.05/2*np.pi*360] * 21
1614
+ azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
1615
+ Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
1616
+ for i in Face_yaws:
1617
+ if i<0:
1618
+ i = 2*np.pi+i
1619
+ i = i/2*np.pi*360
1620
+ face_yaws = [Face_yaws[0]]
1621
+ for i in range(20):
1622
+ face_yaws.append(Face_yaws[3*i+2])
1623
+ self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
1624
+ self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
1625
+ self.azimuths_rad[:-1].sort()
1626
+ x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
1627
+ y = [- 0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
1628
+ self.x = [x[0]]
1629
+ self.y = [y[0]]
1630
+ for i in range(20):
1631
+ self.x.append(x[i*3+2])
1632
+ self.y.append(y[i*3+2])
1633
+
1634
+ def read_img(self, path):
1635
+ img = cv2.imread(path)
1636
+
1637
+ img = cv2.resize(img, (512, 512))
1638
+ img = (img / 255.0) * 2 - 1
1639
+ img = torch.tensor(img).permute(2, 0, 1)
1640
+ return img
1641
+
1642
+ def __getitem__(self, index):
1643
+ """Returns one data pair (source and target)."""
1644
+ # seq_len, fea_dim
1645
+ random_number1 = random.randrange(0, 21-12)
1646
+ random_number2 = random.randrange(0, 21)
1647
+
1648
+ name = self.lists[index]
1649
+ x_stack = []
1650
+ y_stack = []
1651
+ img_non_hair_stack = []
1652
+ img_hair_stack = []
1653
+ random_number = random.randint(0, 1)
1654
+ if random_number==0:
1655
+ img_path = self.img_path
1656
+ else:
1657
+ img_path = self.non_hair_path
1658
+ non_hair_path = os.path.join(img_path, name, str(random_number2) + '.jpg')
1659
+ ref_folder = os.path.join(self.ref_path, name)
1660
+
1661
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
1662
+ # ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
1663
+ ref_path = os.path.join(ref_folder,files[0])
1664
+ ref_hair = self.read_img(ref_path)
1665
+
1666
+
1667
+ for i in range(12):
1668
+ img_non_hair_stack.append(self.read_img(non_hair_path).unsqueeze(0))
1669
+ hair_path = os.path.join(img_path, name, str(random_number1+i) + '.jpg')
1670
+ img_hair_stack.append(self.read_img(hair_path).unsqueeze(0))
1671
+ x_stack.append(torch.tensor(self.x[random_number1+i]).unsqueeze(0))
1672
+ y_stack.append(torch.tensor(self.y[random_number1+i]).unsqueeze(0))
1673
+
1674
+ img_non_hair = torch.cat(img_non_hair_stack, axis=0)
1675
+ img_hair = torch.cat(img_hair_stack, axis=0)
1676
+ x = torch.cat(x_stack, axis=0)
1677
+ y = torch.cat(y_stack, axis=0)
1678
+
1679
+ return {
1680
+ 'img_hair': img_hair,
1681
+ 'img_non_hair': img_non_hair,
1682
+ 'ref_hair': ref_hair,
1683
+ 'x': x,
1684
+ 'y': y,
1685
+
1686
+ }
1687
+
1688
+ def __len__(self):
1689
+ return self.len
1690
+
1691
+
1692
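+ # 12-frame variant with a fixed source view and a reference drawn at random from multi_reference2.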
+ class myDataset_sv3d_simple_temporal_random_reference(data.Dataset):
1693
+ """Custom data.Dataset compatible with data.DataLoader."""
1694
+
1695
+ def __init__(self, train_data_dir, frame_num=6):
1696
+ self.img_path = os.path.join(train_data_dir, "hair")
1697
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
1698
+ # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
1699
+ # self.ref_path = os.path.join(train_data_dir, "ref_hair")
1700
+ self.pose_path = os.path.join(train_data_dir, "pose.npy")
1701
+ self.non_hair_path = os.path.join(train_data_dir, "non-hair")
1702
+ self.ref_path = os.path.join(train_data_dir, "multi_reference2")
1703
+ self.lists = os.listdir(self.img_path)
1704
+ self.len = len(self.lists)-10
1705
+ self.pose = np.load(self.pose_path)
1706
+ self.frame_num = frame_num
1707
+ #self.pose = np.random.randn(12, 4)
1708
+ elevations_deg = [-0.05/2*np.pi*360] * 21
1709
+ azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
1710
+ Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
1711
+ for i in Face_yaws:
1712
+ if i<0:
1713
+ i = 2*np.pi+i
1714
+ i = i/2*np.pi*360
1715
+ face_yaws = [Face_yaws[0]]
1716
+ for i in range(20):
1717
+ face_yaws.append(Face_yaws[3*i+2])
1718
+ self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
1719
+ self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
1720
+ self.azimuths_rad[:-1].sort()
1721
+ x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
1722
+ y = [- 0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
1723
+ self.x = [x[0]]
1724
+ self.y = [y[0]]
1725
+ for i in range(20):
1726
+ self.x.append(x[i*3+2])
1727
+ self.y.append(y[i*3+2])
1728
+
1729
+ def read_img(self, path):
1730
+ img = cv2.imread(path)
1731
+
1732
+ img = cv2.resize(img, (512, 512))
1733
+ img = (img / 255.0) * 2 - 1
1734
+ img = torch.tensor(img).permute(2, 0, 1)
1735
+ return img
1736
+
1737
+ def __getitem__(self, index):
1738
+ """Returns one data pair (source and target)."""
1739
+ # seq_len, fea_dim
1740
+ random_number1 = random.randrange(0, 21-12)
1741
+ random_number2 = random.randrange(0, 21)
1742
+
1743
+ name = self.lists[index]
1744
+ x_stack = []
1745
+ y_stack = []
1746
+ img_non_hair_stack = []
1747
+ img_hair_stack = []
1748
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
1749
+ ref_folder = os.path.join(self.ref_path, name)
1750
+
1751
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
1752
+ # ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
1753
+ ref_path = os.path.join(ref_folder,random.choice(files))
1754
+ ref_hair = self.read_img(ref_path)
1755
+ for i in range(12):
1756
+ img_non_hair_stack.append(self.read_img(non_hair_path).unsqueeze(0))
1757
+ hair_path = os.path.join(self.img_path, name, str(random_number1+i) + '.jpg')
1758
+ img_hair_stack.append(self.read_img(hair_path).unsqueeze(0))
1759
+ x_stack.append(torch.tensor(self.x[random_number1+i]).unsqueeze(0))
1760
+ y_stack.append(torch.tensor(self.y[random_number1+i]).unsqueeze(0))
1761
+
1762
+ img_non_hair = torch.cat(img_non_hair_stack, axis=0)
1763
+ img_hair = torch.cat(img_hair_stack, axis=0)
1764
+ x = torch.cat(x_stack, axis=0)
1765
+ y = torch.cat(y_stack, axis=0)
1766
+
1767
+ return {
1768
+ 'img_hair': img_hair,
1769
+ 'img_non_hair': img_non_hair,
1770
+ 'ref_hair': ref_hair,
1771
+ 'x': x,
1772
+ 'y': y,
1773
+
1774
+ }
1775
+
1776
+ def __len__(self):
1777
+ return self.len
1778
+
1779
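+ # Single-frame variant: returns one hair / non-hair pair, the camera pose, a random reference (from the first three files) and the (x, y) offsets.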
+ class myDataset_sv3d_simple_random_reference(data.Dataset):
1780
+ """Custom data.Dataset compatible with data.DataLoader."""
1781
+
1782
+ def __init__(self, train_data_dir, frame_num=6):
1783
+ self.img_path = os.path.join(train_data_dir, "hair")
1784
+ self.img_path2 = os.path.join(train_data_dir, "hair_good")
1785
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
1786
+ # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
1787
+ # self.ref_path = os.path.join(train_data_dir, "ref_hair")
1788
+ self.pose_path = os.path.join(train_data_dir, "pose.npy")
1789
+ self.non_hair_path = os.path.join(train_data_dir, "non-hair")
1790
+ self.ref_path = os.path.join(train_data_dir, "multi_reference2")
1791
+ # self.lists = os.listdir(self.img_path2)
1792
+ self.lists = os.listdir(self.img_path)
1793
+ self.len = len(self.lists)-10
1794
+ self.pose = np.load(self.pose_path)
1795
+ self.frame_num = frame_num
1796
+ #self.pose = np.random.randn(12, 4)
1797
+ elevations_deg = [-0.05/2*np.pi*360] * 21
1798
+ azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
1799
+ Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
1800
+ for i in Face_yaws:
1801
+ if i<0:
1802
+ i = 2*np.pi+i
1803
+ i = i/2*np.pi*360
1804
+ face_yaws = [Face_yaws[0]]
1805
+ for i in range(20):
1806
+ face_yaws.append(Face_yaws[3*i+2])
1807
+ self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
1808
+ self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
1809
+ self.azimuths_rad[:-1].sort()
1810
+ x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
1811
+ y = [- 0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
1812
+ self.x = [x[0]]
1813
+ self.y = [y[0]]
1814
+ for i in range(20):
1815
+ self.x.append(x[i*3+2])
1816
+ self.y.append(y[i*3+2])
1817
+
1818
+ def __getitem__(self, index):
1819
+ """Returns one data pair (source and target)."""
1820
+ # seq_len, fea_dim
1821
+ random_number1 = random.randrange(0, 21)
1822
+ random_number2 = random.randrange(0, 21)
1823
+
1824
+ # while random_number2 == random_number1:
1825
+ # random_number2 = random.randrange(0, 21)
1826
+ name = self.lists[index].split('.')[0]
1827
+
1828
+ #random_number1 = random_number1
1829
+ #* 10
1830
+ #random_number2 = random_number2 * 10
1831
+
1832
+ #random_number2 = random_number1
1833
+
1834
+
1835
+ hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
1836
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
1837
+ ref_folder = os.path.join(self.ref_path, name)
1838
+
1839
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')][:3]
1840
+ # ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
1841
+ ref_path = os.path.join(ref_folder,random.choice(files))
1842
+ # print('________')
1843
+ # print(files)
1844
+ # print('++++++++')
1845
+ # print(ref_folder)
1846
+ # print("========")
1847
+ # print(name)
1848
+ # print("********")
1849
+ # print(ref_path)
1850
+ img_hair = cv2.imread(hair_path)
1851
+ img_non_hair = cv2.imread(non_hair_path)
1852
+ ref_hair = cv2.imread(ref_path)
1853
+
1854
+
1855
+ img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
1856
+ ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
1857
+
1858
+
1859
+ img_non_hair = cv2.resize(img_non_hair, (512, 512))
1860
+ ref_hair = cv2.resize(ref_hair, (512, 512))
1861
+
1862
+
1863
+ img_non_hair = (img_non_hair / 255.0) * 2 - 1
1864
+ ref_hair = (ref_hair / 255.0) * 2 - 1
1865
+
1866
+
1867
+ img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
1868
+ ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)
1869
+
1870
+ pose = self.pose[random_number1]
1871
+ pose = torch.tensor(pose)
1872
+ # pose2 = self.pose[random_number2]
1873
+ # pose2 = torch.tensor(pose2)
1874
+ hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
1875
+ # hair_num = [0, 2, 6, 14, 18, 21]
1876
+ # img_hair_stack = []
1877
+ # polar = self.polars_rad[random_number1]
1878
+ # polar = torch.tensor(polar).unsqueeze(0)
1879
+ # azimuths = self.azimuths_rad[random_number1]
1880
+ # azimuths = torch.tensor(azimuths).unsqueeze(0)
1881
+ img_hair = cv2.imread(hair_path)
1882
+ img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
1883
+ img_hair = cv2.resize(img_hair, (512, 512))
1884
+ img_hair = (img_hair / 255.0) * 2 - 1
1885
+ img_hair = torch.tensor(img_hair).permute(2, 0, 1)
1886
+ x = torch.tensor(self.x[random_number1])
1887
+ y = torch.tensor(self.y[random_number1])
1888
+ x2 = torch.tensor(self.x[random_number2])
1889
+ y2 = torch.tensor(self.y[random_number2])
1890
+ # begin = random.randrange(0, 21-self.frame_num)
1891
+ # hair_num = [i+begin for i in range(self.frame_num)]
1892
+ # for i in hair_num:
1893
+ # img_hair = cv2.imread(hair_path)
1894
+ # img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
1895
+ # img_hair = cv2.resize(img_hair, (512, 512))
1896
+ # img_hair = (img_hair / 255.0) * 2 - 1
1897
+ # img_hair = torch.tensor(img_hair).permute(2, 0, 1)
1898
+ # img_hair_stack.append(img_hair)
1899
+ # img_hair = torch.stack(img_hair_stack)
1900
+
1901
+ return {
1902
+ # 'hair_pose': pose1,
1903
+ 'img_hair': img_hair,
1904
+ # 'bald_pose': pose2,
1905
+ 'pose': pose,
1906
+ 'img_non_hair': img_non_hair,
1907
+ 'ref_hair': ref_hair,
1908
+ 'x': x,
1909
+ 'y': y,
1910
+ # 'x2': x2,
1911
+ # 'y2': y2,
1912
+ # 'polar': polar,
1913
+ # 'azimuths':azimuths,
1914
+ }
1915
+
1916
+ def __len__(self):
1917
+ return self.len
1918
+
1919
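+ # Single-frame ControlNet variant: the reference image is augmented with hair_transform, and the target frame is read from the non-hair folder.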
+ class myDataset_sv3d_simple_random_reference_controlnet(data.Dataset):
1920
+ """Custom data.Dataset compatible with data.DataLoader."""
1921
+
1922
+ def __init__(self, train_data_dir, frame_num=6):
1923
+ self.img_path = os.path.join(train_data_dir, "hair")
1924
+ self.img_path2 = os.path.join(train_data_dir, "hair_good")
1925
+ self.pose_path = os.path.join(train_data_dir, "pose.npy")
1926
+ self.non_hair_path = os.path.join(train_data_dir, "non-hair")
1927
+ self.ref_path = os.path.join(train_data_dir, "multi_reference2")
1928
+ # self.ref_path = os.path.join(train_data_dir, "reference")
1929
+ self.lists = os.listdir(self.img_path)
1930
+ self.len = len(self.lists)-10
1931
+ self.pose = np.load(self.pose_path)
1932
+ self.frame_num = frame_num
1933
+ elevations_deg = [-0.05/2*np.pi*360] * 21
1934
+ azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
1935
+ Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
1936
+ for i in Face_yaws:
1937
+ if i<0:
1938
+ i = 2*np.pi+i
1939
+ i = i/2*np.pi*360
1940
+ face_yaws = [Face_yaws[0]]
1941
+ for i in range(20):
1942
+ face_yaws.append(Face_yaws[3*i+2])
1943
+ self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
1944
+ self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
1945
+ self.azimuths_rad[:-1].sort()
1946
+ x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
1947
+ y = [- 0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
1948
+ self.x = [x[0]]
1949
+ self.y = [y[0]]
1950
+ for i in range(20):
1951
+ self.x.append(x[i*3+2])
1952
+ self.y.append(y[i*3+2])
1953
+
1954
+ def read_ref_img(self, path):
1955
+ img = cv2.imread(path)
1956
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
1957
+ img = hair_transform(image=img)['image']
1958
+ # img = cv2.resize(img, (512, 512))
1959
+ img = (img / 255.0) * 2 - 1
1960
+ img = torch.tensor(img).permute(2, 0, 1)
1961
+ return img
1962
+
1963
+ def __getitem__(self, index):
1964
+ """Returns one data pair (source and target)."""
1965
+ # seq_len, fea_dim
1966
+ random_number1 = random.randrange(0, 21)
1967
+ random_number2 = random.randrange(0, 21)
1968
+ name = self.lists[index].split('.')[0]
1969
+
1970
+ # hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
1971
+ hair_path = os.path.join(self.non_hair_path, name, str(random_number1) + '.jpg')
1972
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
1973
+ ref_folder = os.path.join(self.ref_path, name)
1974
+
1975
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')][:3]
1976
+ ref_path = os.path.join(ref_folder,random.choice(files))
1977
+ img_hair = cv2.imread(hair_path)
1978
+ img_non_hair = cv2.imread(non_hair_path)
1979
+ ref_hair = cv2.imread(ref_path)
1980
+
1981
+
1982
+ img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
1983
+ ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
1984
+
1985
+ ref_hair = hair_transform(image=ref_hair)['image']
1986
+ # print(type(ref_hair))
1987
+ # print(ref_hair.keys())
1988
+ # ref_hair = self.read_ref_img(ref_path)
1989
+
1990
+
1991
+ img_non_hair = cv2.resize(img_non_hair, (512, 512))
1992
+ ref_hair = cv2.resize(ref_hair, (512, 512))
1993
+
1994
+
1995
+ img_non_hair = (img_non_hair / 255.0) * 2 - 1
1996
+ ref_hair = (ref_hair / 255.0) * 2 - 1
1997
+
1998
+
1999
+ img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
2000
+ ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)
2001
+
2002
+ pose = self.pose[random_number1]
2003
+ pose = torch.tensor(pose)
2004
+ #hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
2005
+ # hair_path = os.path.join(self.non_hair_path, name, str(random_number1) + '.jpg')
2006
+ img_hair = cv2.imread(hair_path)
2007
+ img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
2008
+ img_hair = cv2.resize(img_hair, (512, 512))
2009
+ img_hair = (img_hair / 255.0) * 2 - 1
2010
+ img_hair = torch.tensor(img_hair).permute(2, 0, 1)
2011
+ x = torch.tensor(self.x[random_number1])
2012
+ y = torch.tensor(self.y[random_number1])
2013
+ x2 = torch.tensor(self.x[random_number2])
2014
+ y2 = torch.tensor(self.y[random_number2])
2015
+
2016
+ return {
2017
+ # 'hair_pose': pose1,
2018
+ 'img_hair': img_hair,
2019
+ # 'bald_pose': pose2,
2020
+ 'pose': pose,
2021
+ 'img_non_hair': img_non_hair,
2022
+ 'ref_hair': ref_hair,
2023
+ 'x': x,
2024
+ 'y': y,
2025
+ }
2026
+
2027
+ def __len__(self):
2028
+ return self.len
2029
+
2030
+
2031
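+ # Stable-Hair-style single-frame variant: the hair and non-hair images share the same view index (random_number1 is forced equal to random_number2).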
+ class myDataset_sv3d_simple_random_reference_stable_hair(data.Dataset):
2032
+ """Custom data.Dataset compatible with data.DataLoader."""
2033
+
2034
+ def __init__(self, train_data_dir, frame_num=6):
2035
+ self.img_path = os.path.join(train_data_dir, "hair")
2036
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
2037
+ # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
2038
+ # self.ref_path = os.path.join(train_data_dir, "ref_hair")
2039
+ self.pose_path = os.path.join(train_data_dir, "pose.npy")
2040
+ self.non_hair_path = os.path.join(train_data_dir, "non-hair")
2041
+ self.ref_path = os.path.join(train_data_dir, "reference")
2042
+ self.lists = os.listdir(self.img_path)
2043
+ self.len = len(self.lists)-10
2044
+ self.pose = np.load(self.pose_path)
2045
+ self.frame_num = frame_num
2046
+ #self.pose = np.random.randn(12, 4)
2047
+ elevations_deg = [-0.05/2*np.pi*360] * 21
2048
+ azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
2049
+ Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
2050
+ for i in Face_yaws:
2051
+ if i<0:
2052
+ i = 2*np.pi+i
2053
+ i = i/2*np.pi*360
2054
+ face_yaws = [Face_yaws[0]]
2055
+ for i in range(20):
2056
+ face_yaws.append(Face_yaws[3*i+2])
2057
+ self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
2058
+ self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
2059
+ self.azimuths_rad[:-1].sort()
2060
+ x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
2061
+ y = [- 0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
2062
+ self.x = [x[0]]
2063
+ self.y = [y[0]]
2064
+ for i in range(20):
2065
+ self.x.append(x[i*3+2])
2066
+ self.y.append(y[i*3+2])
2067
+
2068
+ def __getitem__(self, index):
2069
+ """Returns one data pair (source and target)."""
2070
+ # seq_len, fea_dim
2071
+ random_number1 = random.randrange(0, 21)
2072
+ random_number2 = random.randrange(0, 21)
2073
+ random_number1 = random_number2
2074
+
2075
+ name = self.lists[index]
2076
+
2077
+
2078
+
2079
+ hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
2080
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
2081
+ ref_folder = os.path.join(self.ref_path, name)
2082
+
2083
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
2084
+ ref_path = os.path.join(ref_folder,random.choice(files))
2085
+ img_hair = cv2.imread(hair_path)
2086
+ img_non_hair = cv2.imread(non_hair_path)
2087
+ ref_hair = cv2.imread(ref_path)
2088
+
2089
+
2090
+ img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
2091
+ ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
2092
+
2093
+
2094
+ img_non_hair = cv2.resize(img_non_hair, (512, 512))
2095
+ ref_hair = cv2.resize(ref_hair, (512, 512))
2096
+
2097
+
2098
+ img_non_hair = (img_non_hair / 255.0) * 2 - 1
2099
+ ref_hair = (ref_hair / 255.0) * 2 - 1
2100
+
2101
+
2102
+ img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
2103
+ ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)
2104
+
2105
+ pose = self.pose[random_number1]
2106
+ pose = torch.tensor(pose)
2107
+ hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
2108
+ img_hair = cv2.imread(hair_path)
2109
+ img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
2110
+ img_hair = cv2.resize(img_hair, (512, 512))
2111
+ img_hair = (img_hair / 255.0) * 2 - 1
2112
+ img_hair = torch.tensor(img_hair).permute(2, 0, 1)
2113
+ x = torch.tensor(self.x[random_number1])
2114
+ y = torch.tensor(self.y[random_number1])
2115
+
2116
+ return {
2117
+ 'img_hair': img_hair,
2118
+ 'pose': pose,
2119
+ 'img_non_hair': img_non_hair,
2120
+ 'ref_hair': ref_hair,
2121
+ 'x': x,
2122
+ 'y': y,
2123
+ }
2124
+
2125
+ def __len__(self):
2126
+ return self.len
2127
+
2128
+
2129
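+ # Shorter 6-frame clip version of the temporal dataset.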
+ class myDataset_sv3d_simple_temporal_small_squence(data.Dataset):
2130
+ """Custom data.Dataset compatible with data.DataLoader."""
2131
+
2132
+ def __init__(self, train_data_dir, frame_num=6):
2133
+ self.img_path = os.path.join(train_data_dir, "hair")
2134
+ # self.pose_path = os.path.join(train_data_dir, "pose.npy")
2135
+ # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
2136
+ # self.ref_path = os.path.join(train_data_dir, "ref_hair")
2137
+ self.pose_path = os.path.join(train_data_dir, "pose.npy")
2138
+ self.non_hair_path = os.path.join(train_data_dir, "non-hair")
2139
+ self.ref_path = os.path.join(train_data_dir, "reference")
2140
+ self.lists = os.listdir(self.img_path)
2141
+ self.len = len(self.lists)-10
2142
+ self.pose = np.load(self.pose_path)
2143
+ self.frame_num = frame_num
2144
+ #self.pose = np.random.randn(12, 4)
2145
+ elevations_deg = [-0.05/2*np.pi*360] * 21
2146
+ azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
2147
+ Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
2148
+ for i in Face_yaws:
2149
+ if i<0:
2150
+ i = 2*np.pi+i
2151
+ i = i/2*np.pi*360
2152
+ face_yaws = [Face_yaws[0]]
2153
+ for i in range(20):
2154
+ face_yaws.append(Face_yaws[3*i+2])
2155
+ self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
2156
+ self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
2157
+ self.azimuths_rad[:-1].sort()
2158
+ x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
2159
+ y = [- 0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
2160
+ self.x = [x[0]]
2161
+ self.y = [y[0]]
2162
+ for i in range(20):
2163
+ self.x.append(x[i*3+2])
2164
+ self.y.append(y[i*3+2])
2165
+
2166
+ def read_img(self, path):
2167
+ img = cv2.imread(path)
2168
+
2169
+ img = cv2.resize(img, (512, 512))
2170
+ img = (img / 255.0) * 2 - 1
2171
+ img = torch.tensor(img).permute(2, 0, 1)
2172
+ return img
2173
+
2174
+ def __getitem__(self, index):
2175
+ """Returns one data pair (source and target)."""
2176
+ # seq_len, fea_dim
2177
+ random_number1 = random.randrange(0, 21-6)
2178
+ random_number2 = random.randrange(0, 21)
2179
+
2180
+ name = self.lists[index]
2181
+ x_stack = []
2182
+ y_stack = []
2183
+ img_non_hair_stack = []
2184
+ img_hair_stack = []
2185
+ non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
2186
+ ref_folder = os.path.join(self.ref_path, name)
2187
+
2188
+ files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
2189
+ # ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
2190
+ ref_path = os.path.join(ref_folder,files[0])
2191
+ ref_hair = self.read_img(ref_path)
2192
+ for i in range(6):
2193
+ img_non_hair_stack.append(self.read_img(non_hair_path).unsqueeze(0))
2194
+ hair_path = os.path.join(self.img_path, name, str(random_number1+i) + '.jpg')
2195
+ img_hair_stack.append(self.read_img(hair_path).unsqueeze(0))
2196
+ x_stack.append(torch.tensor(self.x[random_number1+i]).unsqueeze(0))
2197
+ y_stack.append(torch.tensor(self.y[random_number1+i]).unsqueeze(0))
2198
+
2199
+ img_non_hair = torch.cat(img_non_hair_stack, axis=0)
2200
+ img_hair = torch.cat(img_hair_stack, axis=0)
2201
+ x = torch.cat(x_stack, axis=0)
2202
+ y = torch.cat(y_stack, axis=0)
2203
+
2204
+ return {
2205
+ 'img_hair': img_hair,
2206
+ 'img_non_hair': img_non_hair,
2207
+ 'ref_hair': ref_hair,
2208
+ 'x': x,
2209
+ 'y': y,
2210
+
2211
+ }
2212
+
2213
+ def __len__(self):
2214
+ return self.len
2215
+
2216
+
2217
+ if __name__ == "__main__":
2218
+
2219
+ train_dataset = myDataset("./data")
2220
+ train_dataloader = torch.utils.data.DataLoader(
2221
+ train_dataset,
2222
+ batch_size=1,
2223
+ num_workers=1,
2224
+ )
2225
+
2226
+ for epoch in range(0, len(train_dataset) + 1):
2227
+ for step, batch in enumerate(train_dataloader):
2228
+ print("batch[hair_pose]:", batch["hair_pose"])
2229
+ print("batch[img_hair]:", batch["img_hair"])
2230
+ print("batch[bald_pose]:", batch["bald_pose"])
2231
+ print("batch[img_non_hair]:", batch["img_non_hair"])
2232
+ print("batch[ref_hair]:", batch["ref_hair"])
2233
+
2234
+
2235
+
2236
+
download.py ADDED
@@ -0,0 +1,8 @@
1
+ from huggingface_hub import snapshot_download
2
+ snapshot_download(
3
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
4
+ local_dir="stable-diffusion-v1-5/stable-diffusion-v1-5",
5
+ local_dir_use_symlinks=False,
6
+ resume_download=True,
7
+ allow_patterns=["unet/*","vae/*","tokenizer/*","scheduler/*","model_index.json"]
8
+ )
gradio_app.py ADDED
@@ -0,0 +1,379 @@
1
+ import os
2
+ os.environ.setdefault("GRADIO_TEMP_DIR", "/data2/lzliu/tmp/gradio")
3
+ os.environ.setdefault("TMPDIR", "/data2/lzliu/tmp")
4
+ os.makedirs("/data2/lzliu/tmp/gradio", exist_ok=True)
5
+ os.makedirs("/data2/lzliu/tmp", exist_ok=True)
6
+
7
+
8
+ # Everything below is unchanged
9
+
10
+
11
+ import logging
12
+ import gradio as gr
13
+ import torch
14
+ import os
15
+ import uuid
16
+ from test_stablehairv2 import log_validation
17
+ from test_stablehairv2 import UNet3DConditionModel, ControlNetModel, CCProjection
18
+ from test_stablehairv2 import AutoTokenizer, CLIPVisionModelWithProjection, AutoencoderKL, UNet2DConditionModel
19
+ from omegaconf import OmegaConf
20
+ import numpy as np
21
+ import cv2
22
+ from test_stablehairv2 import _maybe_align_image
23
+ from HairMapper.hair_mapper_run import bald_head
24
+
25
+ import base64
26
+
27
+ with open("imgs/background.jpg", "rb") as f:
28
+ b64_img = base64.b64encode(f.read()).decode()
29
+
30
+
31
+ def inference(id_image, hair_image):
32
+ os.makedirs("gradio_inputs", exist_ok=True)
33
+ os.makedirs("gradio_outputs", exist_ok=True)
34
+
35
+ id_path = "gradio_inputs/id.png"
36
+ hair_path = "gradio_inputs/hair.png"
37
+ id_image.save(id_path)
38
+ hair_image.save(hair_path)
39
+
40
+ # ===== Image alignment =====
41
+ aligned_id = _maybe_align_image(id_path, output_size=1024, prefer_cuda=True)
42
+ aligned_hair = _maybe_align_image(hair_path, output_size=1024, prefer_cuda=True)
43
+
44
+ # Save the aligned results (for the Gradio outputs)
45
+ aligned_id_path = "gradio_outputs/aligned_id.png"
46
+ aligned_hair_path = "gradio_outputs/aligned_hair.png"
47
+ cv2.imwrite(aligned_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
48
+ cv2.imwrite(aligned_hair_path, cv2.cvtColor(aligned_hair, cv2.COLOR_RGB2BGR))
49
+
50
+ # ===== Run HairMapper to bald the ID image =====
51
+ bald_id_path = "gradio_outputs/bald_id.png"
52
+ cv2.imwrite(bald_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
53
+ bald_head(bald_id_path, bald_id_path)
54
+
55
+ # ===== Original Args =====
56
+ class Args:
57
+ pretrained_model_name_or_path = "./stable-diffusion-v1-5/stable-diffusion-v1-5"
58
+ model_path = "./trained_model"
59
+ image_encoder = "openai/clip-vit-large-patch14"
60
+ controlnet_model_name_or_path = None
61
+ revision = None
62
+ output_dir = "gradio_outputs"
63
+ seed = 42
64
+ num_validation_images = 1
65
+ validation_ids = [aligned_id_path] # use the aligned image
66
+ validation_hairs = [aligned_hair_path] # use the aligned image
67
+ use_fp16 = False
68
+
69
+ args = Args()
70
+
71
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72
+
73
+ # Initialize the logger
74
+ logging.basicConfig(
75
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
76
+ datefmt="%m/%d/%Y %H:%M:%S",
77
+ level=logging.INFO,
78
+ )
79
+ logger = logging.getLogger(__name__)
80
+
81
+ # ===== Model loading (kept consistent with main()) =====
82
+ tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer",
83
+ revision=args.revision)
84
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder, revision=args.revision).to(device)
85
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision).to(
86
+ device, dtype=torch.float32)
87
+
88
+ infer_config = OmegaConf.load('./configs/inference/inference_v2.yaml')
89
+
90
+ unet2 = UNet2DConditionModel.from_pretrained(
91
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, torch_dtype=torch.float32
92
+ ).to(device)
93
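+ # Widen conv_in from 4 to 8 latent channels: zero-init the new weights, then copy the original SD conv_in weights into the first 4 channels.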
+ conv_in_8 = torch.nn.Conv2d(8, unet2.conv_in.out_channels, kernel_size=unet2.conv_in.kernel_size,
94
+ padding=unet2.conv_in.padding)
95
+ conv_in_8.requires_grad_(False)
96
+ unet2.conv_in.requires_grad_(False)
97
+ torch.nn.init.zeros_(conv_in_8.weight)
98
+ conv_in_8.weight[:, :4, :, :].copy_(unet2.conv_in.weight)
99
+ conv_in_8.bias.copy_(unet2.conv_in.bias)
100
+ unet2.conv_in = conv_in_8
101
+
102
+ controlnet = ControlNetModel.from_unet(unet2).to(device)
103
+ state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model.bin"), map_location="cpu")
104
+ controlnet.load_state_dict(state_dict2, strict=False)
105
+
106
+ prefix = "motion_module"
107
+ ckpt_num = "4140000"
108
+ save_path = os.path.join(args.model_path, f"{prefix}-{ckpt_num}.pth")
109
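+ # Build the 3D denoising UNet from the 2D SD weights and load the temporal motion-module checkpoint.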
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
110
+ args.pretrained_model_name_or_path,
111
+ save_path,
112
+ subfolder="unet",
113
+ unet_additional_kwargs=infer_config.unet_additional_kwargs,
114
+ ).to(device)
115
+
116
+ cc_projection = CCProjection().to(device)
117
+ state_dict3 = torch.load(os.path.join(args.model_path, "pytorch_model_1.bin"), map_location="cpu")
118
+ cc_projection.load_state_dict(state_dict3, strict=False)
119
+
120
+ from ref_encoder.reference_unet import ref_unet
121
+ Hair_Encoder = ref_unet.from_pretrained(
122
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False,
123
+ device_map=None, ignore_mismatched_sizes=True
124
+ ).to(device)
125
+ state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"), map_location="cpu")
126
+ Hair_Encoder.load_state_dict(state_dict2, strict=False)
127
+
128
+ # Run inference
129
+ log_validation(
130
+ vae, tokenizer, image_encoder, denoising_unet,
131
+ args, device, logger,
132
+ cc_projection, controlnet, Hair_Encoder
133
+ )
134
+
135
+ output_video = os.path.join(args.output_dir, "validation", "generated_video_0.mp4")
136
+
137
+ # Extract video frames for the draggable frame preview
138
+ frames_dir = os.path.join(args.output_dir, "frames", uuid.uuid4().hex)
139
+ os.makedirs(frames_dir, exist_ok=True)
140
+ cap = cv2.VideoCapture(output_video)
141
+ frames_list = []
142
+ idx = 0
143
+ while True:
144
+ ret, frame = cap.read()
145
+ if not ret:
146
+ break
147
+ fp = os.path.join(frames_dir, f"{idx:03d}.png")
148
+ cv2.imwrite(fp, frame)
149
+ frames_list.append(fp)
150
+ idx += 1
151
+ cap.release()
152
+
153
+ max_frames = len(frames_list) if frames_list else 1
154
+ first_frame = frames_list[0] if frames_list else None
155
+
156
+ return aligned_id_path, aligned_hair_path, bald_id_path, output_video, frames_list, gr.update(minimum=1,
157
+ maximum=max_frames,
158
+ value=1,
159
+ step=1), first_frame
160
+
161
+
162
+ # Gradio front end
163
+ # Original gr.Interface version (kept for easy rollback)
164
+ # demo = gr.Interface(
165
+ # fn=inference,
166
+ # inputs=[
167
+ # gr.Image(type="pil", label="上传身份图(ID Image)"),
168
+ # gr.Image(type="pil", label="上传发型图(Hair Reference Image)")
169
+ # ],
170
+ # outputs=[
171
+ # gr.Image(type="filepath", label="对齐后的身份图"),
172
+ # gr.Image(type="filepath", label="对齐后的发型图"),
173
+ # gr.Image(type="filepath", label="秃头化后的身份图"),
174
+ # gr.Video(label="生成的视频")
175
+ # ],
176
+ # title="StableHairV2 多视角发型迁移",
177
+ # description="上传身份图和发型参考图,查看对齐结果并生成多视角视频"
178
+ # )
179
+ # if __name__ == "__main__":
180
+ # demo.launch(server_name="0.0.0.0", server_port=7860)
181
+
182
+ # Polished gr.Blocks version
183
+ css = f"""
184
+ html, body {{
185
+ height: 100%;
186
+ margin: 0;
187
+ padding: 0;
188
+ }}
189
+ .gradio-container {{
190
+ width: 100% !important;
191
+ height: 100% !important;
192
+ margin: 0 !important;
193
+ padding: 0 !important;
194
+ background-image: url("data:image/jpeg;base64,{b64_img}");
195
+ background-size: cover;
196
+ background-position: center;
197
+ background-attachment: fixed; /* keep the background fixed */
198
+ }}
199
+ #title-card {{
200
+ background: rgba(255, 255, 255, 0.8);
201
+ border-radius: 12px;
202
+ padding: 16px 24px;
203
+ box-shadow: 0 2px 8px rgba(0,0,0,0.15);
204
+ margin-bottom: 20px;
205
+ }}
206
+ #title-card h2 {{
207
+ text-align: center;
208
+ margin: 4px 0 12px 0;
209
+ font-size: 28px;
210
+ }}
211
+ #title-card p {{
212
+ text-align: center;
213
+ font-size: 16px;
214
+ color: #374151;
215
+ }}
216
+ .out-card {{
217
+ border:1px solid #e5e7eb; border-radius:10px; padding:10px;
218
+ background: rgba(255,255,255,0.85);
219
+ }}
220
+ .two-col {{
221
+ display:grid !important; grid-template-columns: 360px minmax(680px, 1fr); gap:16px
222
+ }}
223
+ .left-pane {{min-width: 360px}}
224
+ .right-pane {{min-width: 680px}}
225
+ /* Tab styling */
226
+ .tabs {{
227
+ background: rgba(255,255,255,0.88);
228
+ border-radius: 12px;
229
+ box-shadow: 0 8px 24px rgba(0,0,0,0.08);
230
+ padding: 8px;
231
+ border: 1px solid #e5e7eb;
232
+ }}
233
+ .tab-nav {{
234
+ display: flex; gap: 8px; margin-bottom: 8px;
235
+ background: transparent;
236
+ border-bottom: 1px solid #e5e7eb;
237
+ padding-bottom: 6px;
238
+ }}
239
+ .tab-nav button {{
240
+ background: rgba(255,255,255,0.7);
241
+ border: 1px solid #e5e7eb;
242
+ backdrop-filter: blur(6px);
243
+ border-radius: 8px;
244
+ padding: 6px 12px;
245
+ color: #111827;
246
+ transition: all .2s ease;
247
+ }}
248
+ .tab-nav button:hover {{
249
+ transform: translateY(-1px);
250
+ box-shadow: 0 4px 10px rgba(0,0,0,0.06);
251
+ }}
252
+ .tab-nav button[aria-selected="true"] {{
253
+ background: #4f46e5;
254
+ color: #fff;
255
+ border-color: #4f46e5;
256
+ box-shadow: 0 6px 14px rgba(79,70,229,0.25);
257
+ }}
258
+ .tabitem {{
259
+ background: rgba(255,255,255,0.88);
260
+ border-radius: 10px;
261
+ padding: 8px;
262
+ }}
263
+ /* Hair-gallery scroll container: fixed 260px height, scrolls internally */
264
+ #hair_gallery_wrap {{
265
+ height: 260px !important;
266
+ overflow-y: scroll !important;
267
+ overflow-x: auto !important;
268
+ }}
269
+ #hair_gallery_wrap .grid, #hair_gallery_wrap .wrap {{
270
+ height: 100% !important;
271
+ overflow-y: scroll !important;
272
+ }}
273
+ /* Make the gallery itself fill the container height so the scrollbar stays inside it instead of at the page bottom */
274
+ #hair_gallery {{
275
+ height: 100% !important;
276
+ }}
277
+ """
278
+
279
+ with gr.Blocks(
280
+ theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
281
+ css=css
282
+ ) as demo:
283
+ # ==== Top panel ====
284
+ with gr.Group(elem_id="title-card"):
285
+ gr.Markdown("""
286
+ <h2 id='title'>StableHairV2 多视角发型迁移</h2>
287
+ <p>上传身份图与发型参考图,系统将自动完成 <b>对齐 → 秃头化 → 视频生成</b>。</p>
288
+ """)
289
+
290
+ with gr.Row(elem_classes=["two-col"]):
291
+ with gr.Column(scale=5, min_width=260, elem_classes=["left-pane"]):
292
+ id_input = gr.Image(type="pil", label="身份图", height=200)
293
+ hair_input = gr.Image(type="pil", label="发型参考图", height=200)
294
+
295
+ with gr.Row():
296
+ run_btn = gr.Button("开始生成", variant="primary")
297
+ clear_btn = gr.Button("清空")
298
+
299
+ # ========= Hair gallery (click a thumbnail to fill the hair reference input) =========
300
+ def _list_imgs(dir_path: str):
301
+ exts = (".png", ".jpg", ".jpeg", ".webp")
302
+ # exts = (".jpg")
303
+ try:
304
+ files = [os.path.join(dir_path, f) for f in sorted(os.listdir(dir_path))
305
+ if f.lower().endswith(exts)]
306
+ return files
307
+ except Exception:
308
+ return []
309
+
310
+ hair_list = _list_imgs("hair_resposity")
311
+
312
+ with gr.Accordion("发型库(点击选择后自动填充)", open=True):
313
+ with gr.Group(elem_id="hair_gallery_wrap"):
314
+ gallery = gr.Gallery(
315
+ value=hair_list,
316
+ columns=4, rows=2, allow_preview=True, label="发型库",
317
+ elem_id="hair_gallery"
318
+ )
319
+
320
+ def _pick_hair(evt: gr.SelectData): # type: ignore[name-defined]
321
+ i = evt.index if hasattr(evt, 'index') else 0
322
+ i = 0 if i is None else int(i)
323
+ if 0 <= i < len(hair_list):
324
+ return gr.update(value=hair_list[i])
325
+ return gr.update()
326
+
327
+ gallery.select(_pick_hair, inputs=None, outputs=hair_input)
328
+
329
+ with gr.Column(scale=7, min_width=520, elem_classes=["right-pane"]):
330
+ with gr.Tabs():
331
+ with gr.TabItem("生成视频"):
332
+ with gr.Group(elem_classes=["out-card"]):
333
+ video_out = gr.Video(label="生成的视频", height=340)
334
+ with gr.Row():
335
+ frame_slider = gr.Slider(1, 21, value=1, step=1, label="多视角预览(拖动查看帧)")
336
+ frame_preview = gr.Image(type="filepath", label="预览帧", height=260)
337
+ frames_state = gr.State([])
338
+
339
+ with gr.TabItem("归一化对齐结果"):
340
+ with gr.Group(elem_classes=["out-card"]):
341
+ with gr.Row():
342
+ aligned_id_out = gr.Image(type="filepath", label="对齐后的身份图", height=240)
343
+ aligned_hair_out = gr.Image(type="filepath", label="对齐后的发型图", height=240)
344
+
345
+ with gr.TabItem("秃头化结果"):
346
+ with gr.Group(elem_classes=["out-card"]):
347
+ bald_id_out = gr.Image(type="filepath", label="秃头化后的身份图", height=260)
348
+
349
+ # Logic unchanged
350
+ run_btn.click(fn=inference,
351
+ inputs=[id_input, hair_input],
352
+ outputs=[aligned_id_out, aligned_hair_out, bald_id_out,
353
+ video_out, frames_state, frame_slider, frame_preview])
354
+
355
+
356
+ def _on_slide(frames, idx):
357
+ if not frames:
358
+ return gr.update()
359
+ i = int(idx) - 1
360
+ i = max(0, min(i, len(frames) - 1))
361
+ return gr.update(value=frames[i])
362
+
363
+
364
+ frame_slider.change(_on_slide, inputs=[frames_state, frame_slider], outputs=frame_preview)
365
+
366
+
367
+ def _clear():
368
+ return None, None, None, None, None
369
+
370
+
371
+ clear_btn.click(_clear, None,
372
+ [id_input, hair_input, aligned_id_out, aligned_hair_out, bald_id_out])
373
+
374
+ if __name__ == "__main__":
375
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860)
376
+
377
+
378
+
379
+
requirements.txt ADDED
@@ -0,0 +1,67 @@
1
+ accelerate
2
+ albucore==0.0.24
3
+ albumentations==2.0.8
4
+ annotated-types==0.7.0
5
+ antlr4-python3-runtime==4.9.3
6
+ av==14.4.0
7
+ Brotli
8
+ certifi
9
+ charset-normalizer
10
+ click==8.2.1
11
+ colorama
12
+ einops==0.8.1
13
+ filelock
14
+ fsspec
15
+ gitdb==4.0.12
16
+ GitPython==3.1.44
17
+ gmpy2
18
+ graphviz==0.20.3
19
+ hf-xet==1.1.5
20
+ huggingface-hub==0.30.0
21
+ idna
22
+ Jinja2
23
+ kornia
24
+ MarkupSafe
25
+ mkl-service==2.4.0
26
+ mkl_fft
27
+ mkl_random
28
+ mpmath
29
+ networkx
30
+ numpy
31
+ omegaconf==2.3.0
32
+ opencv-python==4.11.0.86
33
+ opencv-python-headless==4.11.0.86
34
+ packaging
35
+ peft==0.15.2
36
+ pillow
37
+ platformdirs==4.3.8
38
+ prodigyopt==1.1.2
39
+ protobuf==6.31.1
40
+ psutil
41
+ pydantic==2.11.7
42
+ pydantic_core==2.33.2
43
+ PySocks
44
+ PyYAML
45
+ regex==2024.11.6
46
+ requests
47
+ safetensors
48
+ scipy==1.15.3
49
+ sentencepiece==0.2.0
50
+ sentry-sdk==2.32.0
51
+ setproctitle==1.3.6
52
+ simsimd==6.4.9
53
+ smmap==5.0.2
54
+ stringzilla==3.12.5
55
+ sympy==1.13.1
56
+ tokenizers==0.21.1
57
+ torch==2.5.0
58
+ torchaudio==2.5.0
59
+ torchvision==0.20.0
60
+ torchviz==0.0.3
61
+ tqdm
62
+ transformers==4.52.3
63
+ triton==3.1.0
64
+ typing-inspection==0.4.1
65
+ typing_extensions
66
+ urllib3
67
+ wandb==0.20.1
test_stablehairv2.py ADDED
@@ -0,0 +1,320 @@
+ #!/usr/bin/env python3
+ import argparse
+ import logging
+ import sys
+ import os
+ import random
+ import numpy as np
+ import cv2
+ import torch
+ from PIL import Image
+ from transformers import AutoTokenizer, CLIPVisionModelWithProjection
+ from diffusers import AutoencoderKL, UniPCMultistepScheduler, UNet2DConditionModel
+ from src.models.unet_3d import UNet3DConditionModel
+ from ref_encoder.reference_unet import CCProjection
+ from ref_encoder.latent_controlnet import ControlNetModel
+ from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline as Hair3dPipeline
+ from src.utils.util import save_videos_grid
+ from omegaconf import OmegaConf
+ from HairMapper.hair_mapper_run import bald_head
+
+
+ # face align
+ def _maybe_align_image(image_path: str, output_size: int, prefer_cuda: bool = True):
+     """Align and crop a face image to FFHQ-style using FFHQFaceAlignment if available.
+     Falls back to simple resize if alignment fails.
+     Returns an RGB uint8 numpy array of shape (H, W, 3).
+     """
+     try:
+         ffhq_dir = os.path.join(os.path.dirname(__file__), 'FFHQFaceAlignment')
+         if ffhq_dir not in sys.path:
+             sys.path.insert(0, ffhq_dir)
+         # Lazy imports to avoid hard dependency if user doesn't enable alignment
+         from lib.landmarks_pytorch import LandmarksEstimation
+         from align import align_crop_image
+
+         # Read image as RGB uint8
+         img_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
+         if img_bgr is None:
+             raise RuntimeError(f"Failed to read image: {image_path}")
+         img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype('uint8')
+
+         device = torch.device('cuda' if prefer_cuda and torch.cuda.is_available() else 'cpu')
+         le = LandmarksEstimation(type='2D')
+
+         img_tensor = torch.tensor(np.transpose(img, (2, 0, 1))).float().to(device)
+         with torch.no_grad():
+             landmarks, _ = le.detect_landmarks(img_tensor.unsqueeze(0), detected_faces=None)
+         if len(landmarks) > 0:
+             lm = np.asarray(landmarks[0].detach().cpu().numpy())
+             aligned = align_crop_image(image=img, landmarks=lm, transform_size=output_size)
+             if aligned is None or aligned.size == 0:
+                 return cv2.resize(img, (output_size, output_size))
+             return aligned
+         else:
+             return cv2.resize(img, (output_size, output_size))
+     except Exception:
+         # Silent fallback to simple resize on any failure
+         img_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
+         img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype('uint8') if img_bgr is not None else None
+         if img is None:
+             raise
+         return cv2.resize(img, (output_size, output_size))
+
+
+ def log_validation(
+         vae, tokenizer, image_encoder, denoising_unet,
+         args, device, logger, cc_projection,
+         controlnet, hair_encoder, feature_extractor=None
+ ):
+     """
+     Run inference on validation pairs and save generated videos.
+     """
+     logger.info("Starting validation inference...")
+
+     # Initialize inference pipeline
+     pipeline = Hair3dPipeline.from_pretrained(
+         args.pretrained_model_name_or_path,
+         image_encoder=image_encoder,
+         feature_extractor=feature_extractor,
+         controlnet=controlnet,
+         vae=vae,
+         tokenizer=tokenizer,
+         denoising_unet=denoising_unet,
+         safety_checker=None,
+         revision=args.revision,
+         torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
+     ).to(device)
+     pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+     pipeline.set_progress_bar_config(disable=True)
+
+     # Create output directory
+     output_dir = os.path.join(args.output_dir, "validation")
+     os.makedirs(output_dir, exist_ok=True)
+
+     print(output_dir)
+
+     # Generate camera trajectory
+     x_coords = [0.4 * np.sin(2 * np.pi * i / 120) for i in range(60)]
+     y_coords = [-0.05 + 0.3 * np.cos(2 * np.pi * i / 120) for i in range(60)]
+     X = [x_coords[0]]
+     Y = [y_coords[0]]
+     for i in range(20):
+         X.append(x_coords[i * 3 + 2])
+         Y.append(y_coords[i * 3 + 2])
+     x_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(1).to(device)
+     y_tensor = torch.tensor(Y, dtype=torch.float32).unsqueeze(1).to(device)
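+     # Descriptive note (editor's reading of the code above): 60 points are
+     # sampled on half of a 120-step ellipse, x = 0.4*sin(2*pi*i/120) and
+     # y = -0.05 + 0.3*cos(2*pi*i/120); the first point plus the points at
+     # indices 2, 5, ..., 59 are kept, giving 21 pose offsets in total, one per
+     # generated frame (video_length=21 below). The exact meaning of x and y
+     # (likely azimuth/elevation-style camera offsets) is defined by the
+     # Pose2VideoPipeline.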
+
+     # # Load reference images
+     # id_image = cv2.cvtColor(cv2.imread(args.validation_ids[0]), cv2.COLOR_BGR2RGB)
+     # id_image = cv2.resize(id_image, (512, 512))
+     # Load reference images (optionally align)
+     align_enabled = getattr(args, 'align_before_infer', True)
+     align_size = getattr(args, 'align_size', 1024)
+     prefer_cuda = True if device.type == 'cuda' else False
+     if align_enabled:
+         id_image = _maybe_align_image(args.validation_ids[0], output_size=align_size, prefer_cuda=prefer_cuda)
+     else:
+         id_image = cv2.cvtColor(cv2.imread(args.validation_ids[0]), cv2.COLOR_BGR2RGB)
+         id_image = cv2.resize(id_image, (512, 512))
+
+     # ===== Run HairMapper to obtain a bald identity image =====
+     temp_bald_path = os.path.join(args.output_dir, "bald_id.png")
+     cv2.imwrite(temp_bald_path, cv2.cvtColor(id_image, cv2.COLOR_RGB2BGR))  # save the (aligned) identity image
+     bald_head(temp_bald_path, temp_bald_path)  # remove the hair in place
+     # Reload the bald image (RGB)
+     id_image = cv2.cvtColor(cv2.imread(temp_bald_path), cv2.COLOR_BGR2RGB)
+     id_image = cv2.resize(id_image, (512, 512))
+
+     id_list = [id_image for _ in range(12)]
+     if align_enabled:
+         hair_image = _maybe_align_image(args.validation_hairs[0], output_size=align_size, prefer_cuda=prefer_cuda)
+         prompt_img = _maybe_align_image(args.validation_ids[0], output_size=align_size, prefer_cuda=prefer_cuda)
+     else:
+         hair_image = cv2.cvtColor(cv2.imread(args.validation_hairs[0]), cv2.COLOR_BGR2RGB)
+         hair_image = cv2.resize(hair_image, (512, 512))
+         prompt_img = cv2.cvtColor(cv2.imread(args.validation_ids[0]), cv2.COLOR_BGR2RGB)
+         prompt_img = cv2.resize(prompt_img, (512, 512))
+     hair_image = cv2.resize(hair_image, (512, 512))
+     prompt_img = cv2.resize(prompt_img, (512, 512))
+
+     prompt_img = [prompt_img]
+
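+     # Descriptive note: id_list repeats the bald identity condition 12 times,
+     # which lines up with context_frames=12 in the call below, and the 21
+     # trajectory values built above line up with video_length=21; sampling
+     # uses 30 UniPC steps at guidance_scale=1.5.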
+     # Perform inference and save videos
+     for idx in range(args.num_validation_images):
+         result = pipeline(
+             prompt="",
+             negative_prompt="",
+             num_inference_steps=30,
+             guidance_scale=1.5,
+             width=512,
+             height=512,
+             controlnet_condition=id_list,
+             controlnet_conditioning_scale=1.0,
+             generator=torch.Generator(device).manual_seed(args.seed),
+             ref_image=hair_image,
+             prompt_img=prompt_img,
+             reference_encoder=hair_encoder,
+             poses=None,
+             x=x_tensor,
+             y=y_tensor,
+             video_length=21,
+             context_frames=12,
+         )
+         video = torch.cat([result.videos, result.videos], dim=0)
+         video_path = os.path.join(output_dir, f"generated_video_{idx}.mp4")
+         save_videos_grid(video, video_path, n_rows=5, fps=24)
+         logger.info(f"Saved generated video: {video_path}")
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         description="Inference script for 3D hairstyle generation"
+     )
+     parser.add_argument(
+         "--pretrained_model_name_or_path", type=str, required=True,
+         help="Path or ID of the pretrained pipeline"
+     )
+     parser.add_argument(
+         "--model_path", type=str, required=True,
+         help="Directory containing the downloaded StableHair v2 checkpoints"
+     )
+     parser.add_argument(
+         "--image_encoder", type=str, required=True,
+         help="Path or ID of the CLIP vision encoder"
+     )
+     parser.add_argument(
+         "--controlnet_model_name_or_path", type=str, default=None,
+         help="Path or ID of the ControlNet model"
+     )
+     parser.add_argument(
+         "--revision", type=str, default=None,
+         help="Model revision or Git reference"
+     )
+     parser.add_argument(
+         "--output_dir", type=str, default="inference_output",
+         help="Directory to save inference results"
+     )
+     parser.add_argument(
+         "--seed", type=int, default=42,
+         help="Random seed for reproducibility"
+     )
+     parser.add_argument(
+         "--num_validation_images", type=int, default=3,
+         help="Number of videos to generate per input pair"
+     )
+     parser.add_argument(
+         "--validation_ids", type=str, nargs='+', required=True,
+         help="Path(s) to identity conditioning images"
+     )
+     parser.add_argument(
+         "--validation_hairs", type=str, nargs='+', required=True,
+         help="Path(s) to hairstyle reference images"
+     )
+     parser.add_argument(
+         "--use_fp16", action="store_true",
+         help="Enable fp16 inference"
+     )
+     parser.add_argument(
+         "--align_before_infer", action="store_true", default=True,
+         help="Align and crop input images to FFHQ style before inference"
+     )
+     parser.add_argument(
+         "--align_size", type=int, default=1024,
+         help="Output size for aligned images when alignment is enabled"
+     )
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_args()
+     # Setup device and logger
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+     )
+     logger = logging.getLogger(__name__)
+
+     # Set random seed
+     torch.manual_seed(args.seed)
+     if device.type == "cuda":
+         torch.cuda.manual_seed_all(args.seed)
+
+     # Load models
+     tokenizer = AutoTokenizer.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="tokenizer",
+         revision=args.revision
+     )
+     image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+         args.image_encoder,
+         revision=args.revision
+     ).to(device)
+     vae = AutoencoderKL.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="vae",
+         revision=args.revision
+     ).to(device)
+
+     infer_config = OmegaConf.load('./configs/inference/inference_v2.yaml')
+
+     unet2 = UNet2DConditionModel.from_pretrained(
+         args.pretrained_model_name_or_path, subfolder="unet", use_safetensors=True, revision=args.revision,
+         torch_dtype=torch.float16
+     ).to(device)
+     conv_in_8 = torch.nn.Conv2d(8, unet2.conv_in.out_channels, kernel_size=unet2.conv_in.kernel_size,
+                                 padding=unet2.conv_in.padding)
+     conv_in_8.requires_grad_(False)
+     unet2.conv_in.requires_grad_(False)
+     torch.nn.init.zeros_(conv_in_8.weight)
+     conv_in_8.weight[:, :4, :, :].copy_(unet2.conv_in.weight)
+     conv_in_8.bias.copy_(unet2.conv_in.bias)
+     unet2.conv_in = conv_in_8
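+     # Descriptive note: the pretrained SD UNet's conv_in accepts 4 latent
+     # channels; the replacement layer above accepts 8. Its weights are
+     # zero-initialized, the original 4-channel weights and bias are copied
+     # into the first 4 input channels, and the remaining 4 channels stay at
+     # zero, so the extra conditioning latent (presumably the encoded bald
+     # identity image consumed by the latent ControlNet built below)
+     # contributes nothing until the trained weights are loaded.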
+
+     # Load or initialize ControlNet
+     controlnet = ControlNetModel.from_unet(unet2).to(device)
+     # state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model.bin"), map_location=torch.device('cpu'))
+     # state_dict2 = torch.load(args.model_path, map_location=torch.device('cpu'))
+     state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model.bin"), map_location=torch.device('cpu'))
+     controlnet.load_state_dict(state_dict2, strict=False)
+
+     # Load 3D UNet motion module
+     prefix = "motion_module"
+     ckpt_num = "41400000"
+     save_path = os.path.join(args.model_path, f"{prefix}-{ckpt_num}.pth")
+
+     denoising_unet = UNet3DConditionModel.from_pretrained_2d(
+         args.pretrained_model_name_or_path,
+         save_path,
+         subfolder="unet",
+         unet_additional_kwargs=infer_config.unet_additional_kwargs,
+     ).to(device)
+
+     # Load projection and hair encoder
+     cc_projection = CCProjection().to(device)
+     state_dict3 = torch.load(os.path.join(args.model_path, "pytorch_model_1.bin"), map_location=torch.device('cpu'))
+     cc_projection.load_state_dict(state_dict3, strict=False)
+
+     from ref_encoder.reference_unet import ref_unet
+     Hair_Encoder = ref_unet.from_pretrained(
+         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False,
+         device_map=None, ignore_mismatched_sizes=True
+     ).to(device)
+
+     state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"), map_location=torch.device('cpu'))
+     Hair_Encoder.load_state_dict(state_dict2, strict=False)
+
+     # Run validation inference
+     log_validation(
+         vae, tokenizer, image_encoder, denoising_unet,
+         args, device, logger,
+         cc_projection, controlnet, Hair_Encoder
+     )
+
+
+ if __name__ == "__main__":
+     main()