Upload 6 files
- README.md +51 -14
- dataset_mv.py +2236 -0
- download.py +8 -0
- gradio_app.py +379 -0
- requirements.txt +67 -0
- test_stablehairv2.py +320 -0
README.md
CHANGED
@@ -1,14 +1,51 @@
# StableHair v2

**Stable-Hair v2: Real-World Hair Transfer via Multiple-View Diffusion Model**

Kuiyuan Sun*, [Yuxuan Zhang*](https://xiaojiu-z.github.io/YuxuanZhang.github.io/), [Jichao Zhang*](https://zhangqianhui.github.io/), [Jiaming Liu](https://scholar.google.com/citations?user=SmL7oMQAAAAJ&hl=en),
[Wei Wang](https://weiwangtrento.github.io/), [Nicu Sebe](http://disi.unitn.it/~sebe/), [Yao Zhao](https://scholar.google.com/citations?user=474TbQYAAAAJ&hl=en&oi=ao)<br>
*Equal Contribution <br>
Beijing Jiaotong University, Shanghai Jiaotong University, Ocean University of China, Tiamat AI, University of Trento <br>
[Arxiv](https://arxiv.org/abs/2507.07591), [Project]()<br>

Bald | Reference | Multiple View | Original Video

Bald | Reference | Multiple View | Original Video


## Environments

```
conda create -n stablehairv2 python=3.10
conda activate stablehairv2
```
```
pip install -r requirements.txt
```

## Results

<img src="./imgs/teaser.jpg" width="800">

## Pretrained Model

| Name | Model |
|----------------------------|:---------:|
| motion_module-41400000.pth | [:link:](https://drive.google.com/file/d/1AZMhui9jNRF3Z0N72VDPOwDd0JafLQ3B/view?usp=drive_link) |
| pytorch_model_1.bin | [:link:](https://drive.google.com/file/d/1FwKPZI8lvdlZqu8R1aJ-QbE55kxHPHjU/view?usp=drive_link) |
| pytorch_model_2.bin | [:link:](https://drive.google.com/file/d/1h3dXlo8lhZN3ee5aN0shZmpLfn5itVou/view?usp=drive_link) |
| pytorch_model_3.bin | [:link:](https://drive.google.com/file/d/1jARfXaU6wiur85Vm1JxZ_xye0FfrUiqb/view?usp=drive_link) |
| pytorch_model.bin | [:link:](https://drive.google.com/file/d/1zXXf13pV5IOn2vrV6DGI9hliEFvuPrYf/view?usp=drive_link) |

### Multiple View Hair Transfer

Please use `gdown` to download the pretrained models and save them in your model_path.
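For example, the checkpoints in the table above can be fetched by their Google Drive file IDs with gdown's Python API (a minimal sketch; `model_path` is a placeholder for the directory you later pass as `--model_path`):

```
# Minimal download sketch (assumes `pip install gdown`); the file IDs below are
# taken from the Google Drive links in the table above.
import os
import gdown

model_path = "Your_model_path"  # placeholder: the directory passed to --model_path
os.makedirs(model_path, exist_ok=True)

checkpoints = {
    "motion_module-41400000.pth": "1AZMhui9jNRF3Z0N72VDPOwDd0JafLQ3B",
    "pytorch_model_1.bin": "1FwKPZI8lvdlZqu8R1aJ-QbE55kxHPHjU",
    "pytorch_model_2.bin": "1h3dXlo8lhZN3ee5aN0shZmpLfn5itVou",
    "pytorch_model_3.bin": "1jARfXaU6wiur85Vm1JxZ_xye0FfrUiqb",
    "pytorch_model.bin": "1zXXf13pV5IOn2vrV6DGI9hliEFvuPrYf",
}
for filename, file_id in checkpoints.items():
    gdown.download(f"https://drive.google.com/uc?id={file_id}",
                   os.path.join(model_path, filename), quiet=False)
```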
```
python test_stablehairv2.py --pretrained_model_name_or_path "stable-diffusion-v1-5/stable-diffusion-v1-5" \
    --image_encoder "openai/clip-vit-large-patch14" --output_dir [Your_output_dir] \
    --num_validation_images 1 --validation_ids ./test_imgs/bald.jpg \
    --validation_hairs ./test_imgs/ref1.jpg --model_path [Your_model_path]
```

# Our V1 version

StableHair v2 is an improved version of [StableHair](https://github.com/Xiaojiu-z/Stable-Hair) (AAAI 2025).
dataset_mv.py
ADDED
@@ -0,0 +1,2236 @@
from torch.utils import data
import os
import torch
import numpy as np
import cv2
import random
import albumentations as A

pixel_transform = A.Compose([
    A.SmallestMaxSize(max_size=512),
    A.CenterCrop(512, 512),
    A.Affine(scale=(0.5, 1), translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)}, rotate=(-10, 10), p=0.8),
], additional_targets={'image0': 'image', 'image1': 'image'})

hair_transform = A.Compose([
    A.SmallestMaxSize(max_size=512),
    A.CenterCrop(512, 512),
    A.Affine(scale=(0.9, 1.2), rotate=(-10, 10), p=0.7)]
)

# class myDataset(data.Dataset):
#     """Custom data.Dataset compatible with data.DataLoader."""

#     def __init__(self, train_data_dir):
#         self.img_path = os.path.join(train_data_dir, "hair")
#         # self.pose_path = os.path.join(train_data_dir, "pose.npy")
#         # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
#         # self.ref_path = os.path.join(train_data_dir, "ref_hair")
#         self.pose_path = os.path.join(train_data_dir, "pose.npy")
#         self.non_hair_path = os.path.join(train_data_dir, "non-hair")
#         self.ref_path = os.path.join(train_data_dir, "reference")
#         self.lists = os.listdir(self.img_path)
#         self.len = len(self.lists)-10
#         self.pose = np.load(self.pose_path)
#         #self.pose = np.random.randn(12, 4)

#     def __getitem__(self, index):
#         """Returns one data pair (source and target)."""
#         # seq_len, fea_dim
#         random_number1 = random.randrange(0, 21)
#         random_number2 = random.randrange(0, 21)

#         while random_number2 == random_number1:
#             random_number2 = random.randrange(0, 21)
#         name = self.lists[index]

#         random_number1 = random_number1
#         #* 10
#         #random_number2 = random_number2 * 10

#         random_number2 = random_number1

#         non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
#         ref_folder = os.path.join(self.ref_path, name)

#         files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
#         ref_path = os.path.join(ref_folder, files[0])

#         img_non_hair = cv2.imread(non_hair_path)
#         ref_hair = cv2.imread(ref_path)

#         img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
#         ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)

#         img_non_hair = cv2.resize(img_non_hair, (512, 512))
#         ref_hair = cv2.resize(ref_hair, (512, 512))

#         img_non_hair = (img_non_hair / 255.0) * 2 - 1
#         ref_hair = (ref_hair / 255.0) * 2 - 1

#         img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
#         ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)

#         pose1 = self.pose[random_number1]
#         pose1 = torch.tensor(pose1)
#         pose2 = self.pose[random_number2]
#         pose2 = torch.tensor(pose2)
#         hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
#         hair_num = [0, 2, 6, 14, 18, 21]
#         img_hair_stack = []
#         for i in hair_num:
#             img_hair = cv2.imread(hair_path)
#             img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
#             img_hair = cv2.resize(img_hair, (512, 512))
#             img_hair = (img_hair / 255.0) * 2 - 1
#             img_hair = torch.tensor(img_hair).permute(2, 0, 1)
#             img_hair_stack.append(img_hair)
#         img_hair = torch.stack(img_hair_stack)

#         return {
#             'hair_pose': pose1,
#             'img_hair': img_hair,
#             'bald_pose': pose2,
#             'img_non_hair': img_non_hair,
#             'ref_hair': ref_hair
#         }

#     def __len__(self):
#         return self.len

class myDataset(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""
    def __init__(self, train_data_dir):
        self.img_path = os.path.join(train_data_dir, "hair")
        self.pose_path = os.path.join(train_data_dir, "pose.npy")
        self.non_hair_path = os.path.join(train_data_dir, "no_hair")
        self.ref_path = os.path.join(train_data_dir, "ref_hair")

        self.lists = os.listdir(self.img_path)
        self.len = len(self.lists)
        self.pose = np.load(self.pose_path)

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        # seq_len, fea_dim
        random_number1 = random.randrange(0, 120)
        random_number2 = random.randrange(0, 120)
        while random_number2 == random_number1:
            random_number2 = random.randrange(0, 120)
        name = self.lists[index]

        hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
        non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
        ref_folder = os.path.join(self.ref_path, name)
        files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
        ref_path = os.path.join(ref_folder, files[0])
        img_hair = cv2.imread(hair_path)
        img_non_hair = cv2.imread(non_hair_path)
        ref_hair = cv2.imread(ref_path)

        img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
        img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
        ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)

        img_hair = cv2.resize(img_hair, (512, 512))
        img_non_hair = cv2.resize(img_non_hair, (512, 512))
        ref_hair = cv2.resize(ref_hair, (512, 512))
        img_hair = (img_hair / 255.0) * 2 - 1
        img_non_hair = (img_non_hair / 255.0)
        ref_hair = (ref_hair / 255.0) * 2 - 1

        img_hair = torch.tensor(img_hair).permute(2, 0, 1)
        img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
        ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)

        pose1 = self.pose[random_number1]
        pose1 = torch.tensor(pose1)
        pose2 = self.pose[random_number2]
        pose2 = torch.tensor(pose2)

        return {
            'hair_pose': pose1,
            'img_hair': img_hair,
            'bald_pose': pose2,
            'img_non_hair': img_non_hair,
            'ref_hair': ref_hair
        }

    def __len__(self):
        return self.len

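# Usage sketch for myDataset (assumes `train_data_dir` contains the hair/,
# no_hair/ and ref_hair/ folders plus pose.npy referenced above; the
# randrange(0, 120) calls appear to assume up to 120 views per identity):
#
#   from torch.utils.data import DataLoader
#   dataset = myDataset("path/to/train_data_dir")
#   loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
#   batch = next(iter(loader))
#   # batch['img_hair'] and batch['ref_hair'] are (B, 3, 512, 512) tensors in [-1, 1];
#   # batch['img_non_hair'] is kept in [0, 1];
#   # batch['hair_pose'] / batch['bald_pose'] are the poses of the sampled views.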
# class myDataset_unet(data.Dataset):
#     """Custom data.Dataset compatible with data.DataLoader."""

# class myDataset_unet(data.Dataset):
#     """Custom data.Dataset compatible with data.DataLoader."""

#     def __init__(self, train_data_dir, frame_num=6):
#         self.img_path = os.path.join(train_data_dir, "hair")
#         # self.pose_path = os.path.join(train_data_dir, "pose.npy")
#         # self.non_hair_path = os.path.join(train_data_dir, "no_hair")
#         # self.ref_path = os.path.join(train_data_dir, "ref_hair")
#         self.pose_path = os.path.join(train_data_dir, "pose.npy")
#         self.non_hair_path = os.path.join(train_data_dir, "non-hair")
#         self.ref_path = os.path.join(train_data_dir, "reference")
#         self.lists = os.listdir(self.img_path)
#         self.len = len(self.lists)-10
#         self.pose = np.load(self.pose_path)
#         self.frame_num = frame_num
#         #self.pose = np.random.randn(12, 4)

#     def __getitem__(self, index):
#         """Returns one data pair (source and target)."""
#         # seq_len, fea_dim
#         random_number1 = random.randrange(0, 21)
#         random_number2 = random.randrange(0, 21)

#         while random_number2 == random_number1:
#             random_number2 = random.randrange(0, 21)
#         name = self.lists[index]

#         random_number1 = random_number1
#         #* 10
#         #random_number2 = random_number2 * 10

#         random_number2 = random_number1

#         non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
#         ref_folder = os.path.join(self.ref_path, name)
#         ref_folder = os.path.join(self.img_path, name)

#         files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
#         #ref_path = os.path.join(ref_folder, files[0])
#         ref_path = os.path.join(ref_folder, '0.jpg')

#         img_non_hair = cv2.imread(non_hair_path)
#         ref_hair = cv2.imread(ref_path)

#         img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
#         ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)

#         img_non_hair = cv2.resize(img_non_hair, (512, 512))
#         ref_hair = cv2.resize(ref_hair, (512, 512))

#         img_non_hair = (img_non_hair / 255.0) * 2 - 1
#         ref_hair = (ref_hair / 255.0) * 2 - 1

#         img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
#         ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)

#         pose1 = self.pose[random_number1]
#         pose1 = torch.tensor(pose1)
#         pose2 = self.pose[random_number2]
#         pose2 = torch.tensor(pose2)
#         hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
#         hair_num = [0, 2, 6, 14, 18, 21]
#         img_hair_stack = []
#         # begin = random.randrange(0, 21-self.frame_num)
#         # hair_num = [i+begin for i in range(self.frame_num)]
#         for i in hair_num:
#             img_hair = cv2.imread(hair_path)
#             img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
#             img_hair = cv2.resize(img_hair, (512, 512))
#             img_hair = (img_hair / 255.0) * 2 - 1
#             img_hair = torch.tensor(img_hair).permute(2, 0, 1)
#             img_hair_stack.append(img_hair)
#         img_hair = torch.stack(img_hair_stack)

#         return {
#             'hair_pose': pose1,
#             'img_hair': img_hair,
#             'bald_pose': pose2,
#             'img_non_hair': img_non_hair,
#             'ref_hair': ref_hair
#         }

#     def __len__(self):
#         return self.len

class myDataset_unet(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, train_data_dir):
        self.img_path = os.path.join(train_data_dir, "hair")
        self.pose_path = os.path.join(train_data_dir, "pose.npy")
        self.non_hair_path = os.path.join(train_data_dir, "non-hair")
        self.ref_path = os.path.join(train_data_dir, "reference")
        self.lists = os.listdir(self.img_path)
        self.len = len(self.lists)
        self.pose = np.load(self.pose_path)
        # fixed elevation repeated for 21 views, plus 21 evenly spaced azimuth angles
        elevations_deg = [-0.05/2*np.pi*360] * 21
        azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
        Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
        for i in Face_yaws:
            # note: reassigning the loop variable does not modify Face_yaws in place
            if i < 0:
                i = 2*np.pi + i
            i = i/2*np.pi*360
        face_yaws = [Face_yaws[0]]
        for i in range(20):
            face_yaws.append(Face_yaws[3*i+2])
        self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
        self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
        # note: the slice creates a copy, so this sort does not change self.azimuths_rad
        self.azimuths_rad[:-1].sort()

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        # seq_len, fea_dim
        random_number1 = random.randrange(0, 21)
        random_number2 = random.randrange(0, 21)

        # while random_number2 == random_number1:
        #     random_number2 = random.randrange(0, 21)
        name = self.lists[index]

        #random_number1 = random_number1
        #random_number2 = random_number2 * 10

        #random_number2 = random_number1

        hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
        non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
        ref_folder = os.path.join(self.img_path, name)

        #files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
        ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
        img_hair = cv2.imread(hair_path)
        img_non_hair = cv2.imread(non_hair_path)
        ref_hair = cv2.imread(ref_path)

        # img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
        # img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
        # ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)

        img_hair = cv2.resize(img_hair, (512, 512))
        img_non_hair = cv2.resize(img_non_hair, (512, 512))
        ref_hair = cv2.resize(ref_hair, (512, 512))

        img_hair = (img_hair / 255.0) * 2 - 1
        img_non_hair = (img_non_hair / 255.0) * 2 - 1
        ref_hair = (ref_hair / 255.0) * 2 - 1

        img_hair = torch.tensor(img_hair).permute(2, 0, 1)
        img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
        ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)

        # pose1 = self.pose[random_number1]
        # pose1 = torch.tensor(pose1)
        # pose2 = self.pose[random_number2]
        # pose2 = torch.tensor(pose2)
        # polars = self.polars_rad[random_number1]
        # polars = torch.tensor(polars).unsqueeze(0)
        # azimuths = self.azimuths_rad[random_number1]
        # azimuths = torch.tensor(azimuths).unsqueeze(0)
        pose = self.pose[random_number1]
        pose = torch.tensor(pose)

        return {
            # 'hair_pose': pose1,
            'img_hair': img_hair,
            # 'bald_pose': pose2,
            # 'img_non_hair': img_non_hair,
            'img_ref': ref_hair,
            'pose': pose,
            # 'polars': polars,
            # 'azimuths': azimuths,
        }

    def __len__(self):
        return self.len-10

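# Usage sketch for myDataset_unet above (assumes the same root directory with
# hair/, non-hair/ and reference/ folders plus pose.npy):
#
#   unet_dataset = myDataset_unet("path/to/train_data_dir")
#   sample = unet_dataset[0]
#   # sample['img_hair'] and sample['img_ref'] are 3x512x512 tensors in [-1, 1]
#   # (note: no BGR->RGB conversion is applied in this class);
#   # sample['pose'] is the pose of the sampled view from pose.npy.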
353 |
+
class myDataset_sv3d(data.Dataset):
|
354 |
+
"""Custom data.Dataset compatible with data.DataLoader."""
|
355 |
+
|
356 |
+
def __init__(self, train_data_dir, frame_num=6):
|
357 |
+
self.img_path = os.path.join(train_data_dir, "hair")
|
358 |
+
# self.pose_path = os.path.join(train_data_dir, "pose.npy")
|
359 |
+
# self.non_hair_path = os.path.join(train_data_dir, "no_hair")
|
360 |
+
# self.ref_path = os.path.join(train_data_dir, "ref_hair")
|
361 |
+
self.pose_path = os.path.join(train_data_dir, "pose.npy")
|
362 |
+
self.non_hair_path = os.path.join(train_data_dir, "non-hair")
|
363 |
+
self.ref_path = os.path.join(train_data_dir, "reference")
|
364 |
+
self.lists = os.listdir(self.img_path)
|
365 |
+
self.len = len(self.lists)-10
|
366 |
+
self.pose = np.load(self.pose_path)
|
367 |
+
self.frame_num = frame_num
|
368 |
+
#self.pose = np.random.randn(12, 4)
|
369 |
+
elevations_deg = [-0.05/2*np.pi*360] * 21
|
370 |
+
azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
|
371 |
+
Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
|
372 |
+
for i in Face_yaws:
|
373 |
+
if i<0:
|
374 |
+
i = 2*np.pi+i
|
375 |
+
i = i/2*np.pi*360
|
376 |
+
face_yaws = [Face_yaws[0]]
|
377 |
+
for i in range(20):
|
378 |
+
face_yaws.append(Face_yaws[3*i+2])
|
379 |
+
self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
|
380 |
+
self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
|
381 |
+
self.azimuths_rad[:-1].sort()
|
382 |
+
|
383 |
+
def __getitem__(self, index):
|
384 |
+
"""Returns one data pair (source and target)."""
|
385 |
+
# seq_len, fea_dim
|
386 |
+
random_number1 = random.randrange(0, 21)
|
387 |
+
random_number3 = random.randrange(0, 21)
|
388 |
+
random_number2 = random.randrange(0, 21)
|
389 |
+
|
390 |
+
while random_number3 == random_number1:
|
391 |
+
random_number3 = random.randrange(0, 21)
|
392 |
+
|
393 |
+
# while random_number3 == random_number1:
|
394 |
+
# random_number3 = random.randrange(0, 21)
|
395 |
+
name = self.lists[index]
|
396 |
+
|
397 |
+
#random_number1 = random_number1
|
398 |
+
#* 10
|
399 |
+
#random_number2 = random_number2 * 10
|
400 |
+
|
401 |
+
#random_number2 = random_number1
|
402 |
+
|
403 |
+
|
404 |
+
hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
|
405 |
+
#hair_path2 = os.path.join(self.img_path, name, str(random_number3) + '.jpg')
|
406 |
+
non_hair_path = os.path.join(self.non_hair_path, name, str(random_number1) + '.jpg')
|
407 |
+
#non_hair_path2 = os.path.join(self.non_hair_path, name, str(random_number1) + '.jpg')
|
408 |
+
#non_hair_path3 = os.path.join(self.non_hair_path, name, str(random_number3) + '.jpg')
|
409 |
+
ref_folder = os.path.join(self.ref_path, name)
|
410 |
+
|
411 |
+
files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
|
412 |
+
# ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
|
413 |
+
ref_path = os.path.join(ref_folder,files[0])
|
414 |
+
# print('________')
|
415 |
+
# print(files)
|
416 |
+
# print('++++++++')
|
417 |
+
# print(ref_folder)
|
418 |
+
# print("========")
|
419 |
+
# print(name)
|
420 |
+
# print("********")
|
421 |
+
# print(ref_path)
|
422 |
+
img_hair = cv2.imread(hair_path)
|
423 |
+
#img_hair2 = cv2.imread(hair_path2)
|
424 |
+
img_non_hair = cv2.imread(non_hair_path)
|
425 |
+
#img_non_hair2 = cv2.imread(non_hair_path2)
|
426 |
+
#img_non_hair3 = cv2.imread(non_hair_path3)
|
427 |
+
ref_hair = cv2.imread(ref_path)
|
428 |
+
|
429 |
+
|
430 |
+
img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
|
431 |
+
#img_non_hair2 = cv2.cvtColor(img_non_hair2, cv2.COLOR_BGR2RGB)
|
432 |
+
#img_non_hair3 = cv2.cvtColor(img_non_hair3, cv2.COLOR_BGR2RGB)
|
433 |
+
ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
|
434 |
+
|
435 |
+
|
436 |
+
img_non_hair = cv2.resize(img_non_hair, (512, 512))
|
437 |
+
#img_non_hair2 = cv2.resize(img_non_hair2, (512, 512))
|
438 |
+
#img_non_hair3 = cv2.resize(img_non_hair3, (512, 512))
|
439 |
+
ref_hair = cv2.resize(ref_hair, (512, 512))
|
440 |
+
|
441 |
+
|
442 |
+
img_non_hair = (img_non_hair / 255.0) * 2 - 1
|
443 |
+
#img_non_hair2 = (img_non_hair2 / 255.0) * 2 - 1
|
444 |
+
#img_non_hair3 = (img_non_hair3 / 255.0) * 2 - 1
|
445 |
+
ref_hair = (ref_hair / 255.0) * 2 - 1
|
446 |
+
|
447 |
+
|
448 |
+
img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
|
449 |
+
#img_non_hair2 = torch.tensor(img_non_hair2).permute(2, 0, 1)
|
450 |
+
#img_non_hair3 = torch.tensor(img_non_hair3).permute(2, 0, 1)
|
451 |
+
ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)
|
452 |
+
|
453 |
+
pose = self.pose[random_number1]
|
454 |
+
pose = torch.tensor(pose)
|
455 |
+
pose2 = self.pose[random_number3]
|
456 |
+
pose2 = torch.tensor(pose2)
|
457 |
+
# pose2 = self.pose[random_number2]
|
458 |
+
# pose2 = torch.tensor(pose2)
|
459 |
+
hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
|
460 |
+
# hair_num = [0, 2, 6, 14, 18, 21]
|
461 |
+
# img_hair_stack = []
|
462 |
+
# polar = self.polars_rad[random_number1]
|
463 |
+
# polar = torch.tensor(polar).unsqueeze(0)
|
464 |
+
# azimuths = self.azimuths_rad[random_number1]
|
465 |
+
# azimuths = torch.tensor(azimuths).unsqueeze(0)
|
466 |
+
# img_hair = cv2.imread(hair_path)
|
467 |
+
img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
|
468 |
+
img_hair = cv2.resize(img_hair, (512, 512))
|
469 |
+
img_hair = (img_hair / 255.0) * 2 - 1
|
470 |
+
img_hair = torch.tensor(img_hair).permute(2, 0, 1)
|
471 |
+
|
472 |
+
#img_hair2 = cv2.cvtColor(img_hair2, cv2.COLOR_BGR2RGB)
|
473 |
+
#img_hair2 = cv2.resize(img_hair2, (512, 512))
|
474 |
+
#img_hair2 = (img_hair2 / 255.0) * 2 - 1
|
475 |
+
#img_hair2 = torch.tensor(img_hair2).permute(2, 0, 1)
|
476 |
+
# begin = random.randrange(0, 21-self.frame_num)
|
477 |
+
# hair_num = [i+begin for i in range(self.frame_num)]
|
478 |
+
# for i in hair_num:
|
479 |
+
# img_hair = cv2.imread(hair_path)
|
480 |
+
# img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
|
481 |
+
# img_hair = cv2.resize(img_hair, (512, 512))
|
482 |
+
# img_hair = (img_hair / 255.0) * 2 - 1
|
483 |
+
# img_hair = torch.tensor(img_hair).permute(2, 0, 1)
|
484 |
+
# img_hair_stack.append(img_hair)
|
485 |
+
# img_hair = torch.stack(img_hair_stack)
|
486 |
+
|
487 |
+
return {
|
488 |
+
# 'hair_pose': pose1,
|
489 |
+
'img_hair': img_hair,
|
490 |
+
#'img_hair2': img_hair2,
|
491 |
+
# 'bald_pose': pose2,
|
492 |
+
#'pose': pose,
|
493 |
+
#'pose2': pose2,
|
494 |
+
'img_non_hair': img_non_hair,
|
495 |
+
#'img_non_hair2': img_non_hair2,
|
496 |
+
#'img_non_hair3': img_non_hair3,
|
497 |
+
'ref_hair': ref_hair,
|
498 |
+
# 'polar': polar,
|
499 |
+
# 'azimuths':azimuths,
|
500 |
+
}
|
501 |
+
|
502 |
+
def __len__(self):
|
503 |
+
return self.len
|
504 |
+
|
505 |
+
class myDataset_sv3d2(data.Dataset):
|
506 |
+
"""Custom data.Dataset compatible with data.DataLoader."""
|
507 |
+
|
508 |
+
def __init__(self, train_data_dir, frame_num=6):
|
509 |
+
self.img_path = os.path.join(train_data_dir, "hair")
|
510 |
+
# self.pose_path = os.path.join(train_data_dir, "pose.npy")
|
511 |
+
# self.non_hair_path = os.path.join(train_data_dir, "no_hair")
|
512 |
+
# self.ref_path = os.path.join(train_data_dir, "ref_hair")
|
513 |
+
self.pose_path = os.path.join(train_data_dir, "pose.npy")
|
514 |
+
self.non_hair_path = os.path.join(train_data_dir, "non-hair")
|
515 |
+
self.ref_path = os.path.join(train_data_dir, "reference")
|
516 |
+
self.lists = os.listdir(self.img_path)
|
517 |
+
self.len = len(self.lists)-10
|
518 |
+
self.pose = np.load(self.pose_path)
|
519 |
+
self.frame_num = frame_num
|
520 |
+
#self.pose = np.random.randn(12, 4)
|
521 |
+
elevations_deg = [-0.05/2*np.pi*360] * 21
|
522 |
+
azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
|
523 |
+
Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
|
524 |
+
for i in Face_yaws:
|
525 |
+
if i<0:
|
526 |
+
i = 2*np.pi+i
|
527 |
+
i = i/2*np.pi*360
|
528 |
+
face_yaws = [Face_yaws[0]]
|
529 |
+
for i in range(20):
|
530 |
+
face_yaws.append(Face_yaws[3*i+2])
|
531 |
+
self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
|
532 |
+
self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
|
533 |
+
self.azimuths_rad[:-1].sort()
|
534 |
+
|
535 |
+
def __getitem__(self, index):
|
536 |
+
"""Returns one data pair (source and target)."""
|
537 |
+
# seq_len, fea_dim
|
538 |
+
random_number1 = random.randrange(0, 21)
|
539 |
+
random_number2 = random.randrange(0, 21)
|
540 |
+
|
541 |
+
# while random_number2 == random_number1:
|
542 |
+
# random_number2 = random.randrange(0, 21)
|
543 |
+
name = self.lists[index]
|
544 |
+
|
545 |
+
#random_number1 = random_number1
|
546 |
+
#* 10
|
547 |
+
#random_number2 = random_number2 * 10
|
548 |
+
|
549 |
+
#random_number2 = random_number1
|
550 |
+
|
551 |
+
|
552 |
+
hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
|
553 |
+
non_hair_path = os.path.join(self.non_hair_path, name, str(random_number1) + '.jpg')
|
554 |
+
ref_folder = os.path.join(self.ref_path, name)
|
555 |
+
|
556 |
+
files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
|
557 |
+
# ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
|
558 |
+
ref_path = os.path.join(ref_folder,files[0])
|
559 |
+
# print('________')
|
560 |
+
# print(files)
|
561 |
+
# print('++++++++')
|
562 |
+
# print(ref_folder)
|
563 |
+
# print("========")
|
564 |
+
# print(name)
|
565 |
+
# print("********")
|
566 |
+
# print(ref_path)
|
567 |
+
img_hair = cv2.imread(hair_path)
|
568 |
+
img_non_hair = cv2.imread(non_hair_path)
|
569 |
+
ref_hair = cv2.imread(ref_path)
|
570 |
+
|
571 |
+
|
572 |
+
img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
|
573 |
+
ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
|
574 |
+
|
575 |
+
|
576 |
+
img_non_hair = cv2.resize(img_non_hair, (512, 512))
|
577 |
+
ref_hair = cv2.resize(ref_hair, (512, 512))
|
578 |
+
|
579 |
+
|
580 |
+
img_non_hair = (img_non_hair / 255.0) * 2 - 1
|
581 |
+
ref_hair = (ref_hair / 255.0) * 2 - 1
|
582 |
+
|
583 |
+
|
584 |
+
img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
|
585 |
+
ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)
|
586 |
+
|
587 |
+
pose = self.pose[random_number1]
|
588 |
+
pose = torch.tensor(pose)
|
589 |
+
# pose2 = self.pose[random_number2]
|
590 |
+
# pose2 = torch.tensor(pose2)
|
591 |
+
hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
|
592 |
+
# hair_num = [0, 2, 6, 14, 18, 21]
|
593 |
+
# img_hair_stack = []
|
594 |
+
# polar = self.polars_rad[random_number1]
|
595 |
+
# polar = torch.tensor(polar).unsqueeze(0)
|
596 |
+
# azimuths = self.azimuths_rad[random_number1]
|
597 |
+
# azimuths = torch.tensor(azimuths).unsqueeze(0)
|
598 |
+
img_hair = cv2.imread(hair_path)
|
599 |
+
img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
|
600 |
+
img_hair = cv2.resize(img_hair, (512, 512))
|
601 |
+
img_hair = (img_hair / 255.0) * 2 - 1
|
602 |
+
img_hair = torch.tensor(img_hair).permute(2, 0, 1)
|
603 |
+
# begin = random.randrange(0, 21-self.frame_num)
|
604 |
+
# hair_num = [i+begin for i in range(self.frame_num)]
|
605 |
+
# for i in hair_num:
|
606 |
+
# img_hair = cv2.imread(hair_path)
|
607 |
+
# img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
|
608 |
+
# img_hair = cv2.resize(img_hair, (512, 512))
|
609 |
+
# img_hair = (img_hair / 255.0) * 2 - 1
|
610 |
+
# img_hair = torch.tensor(img_hair).permute(2, 0, 1)
|
611 |
+
# img_hair_stack.append(img_hair)
|
612 |
+
# img_hair = torch.stack(img_hair_stack)
|
613 |
+
|
614 |
+
return {
|
615 |
+
# 'hair_pose': pose1,
|
616 |
+
'img_hair': img_hair,
|
617 |
+
# 'bald_pose': pose2,
|
618 |
+
'pose': pose,
|
619 |
+
'img_non_hair': img_non_hair,
|
620 |
+
'ref_hair': ref_hair,
|
621 |
+
# 'polar': polar,
|
622 |
+
# 'azimuths':azimuths,
|
623 |
+
}
|
624 |
+
|
625 |
+
def __len__(self):
|
626 |
+
return self.len
|
627 |
+
|
628 |
+
class myDataset_sv3d_temporal(data.Dataset):
|
629 |
+
"""Custom data.Dataset compatible with data.DataLoader."""
|
630 |
+
|
631 |
+
def __init__(self, train_data_dir, frame_num=6):
|
632 |
+
self.img_path = os.path.join(train_data_dir, "hair")
|
633 |
+
# self.pose_path = os.path.join(train_data_dir, "pose.npy")
|
634 |
+
# self.non_hair_path = os.path.join(train_data_dir, "no_hair")
|
635 |
+
# self.ref_path = os.path.join(train_data_dir, "ref_hair")
|
636 |
+
self.pose_path = os.path.join(train_data_dir, "pose.npy")
|
637 |
+
self.non_hair_path = os.path.join(train_data_dir, "non-hair")
|
638 |
+
self.ref_path = os.path.join(train_data_dir, "reference")
|
639 |
+
self.lists = os.listdir(self.img_path)
|
640 |
+
self.len = len(self.lists)-10
|
641 |
+
self.pose = np.load(self.pose_path)
|
642 |
+
self.frame_num = frame_num
|
643 |
+
#self.pose = np.random.randn(12, 4)
|
644 |
+
elevations_deg = [-0.05/2*np.pi*360] * 21
|
645 |
+
azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
|
646 |
+
Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
|
647 |
+
for i in Face_yaws:
|
648 |
+
if i<0:
|
649 |
+
i = 2*np.pi+i
|
650 |
+
i = i/2*np.pi*360
|
651 |
+
face_yaws = [Face_yaws[0]]
|
652 |
+
for i in range(20):
|
653 |
+
face_yaws.append(Face_yaws[3*i+2])
|
654 |
+
self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
|
655 |
+
self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
|
656 |
+
self.azimuths_rad[:-1].sort()
|
657 |
+
|
658 |
+
def read_img(self, path):
|
659 |
+
img = cv2.imread(path)
|
660 |
+
|
661 |
+
img = cv2.resize(img, (512, 512))
|
662 |
+
img = (img / 255.0) * 2 - 1
|
663 |
+
img = torch.tensor(img).permute(2, 0, 1)
|
664 |
+
return img
|
665 |
+
|
666 |
+
|
667 |
+
def __getitem__(self, index):
|
668 |
+
"""Returns one data pair (source and target)."""
|
669 |
+
# seq_len, fea_dim
|
670 |
+
random_number1 = random.randrange(0, 21-10)
|
671 |
+
# random_number3 = random.randrange(0, 21)
|
672 |
+
random_number2 = random.randrange(0, 21)
|
673 |
+
|
674 |
+
while random_number3 == random_number1:
|
675 |
+
random_number3 = random.randrange(0, 21)
|
676 |
+
|
677 |
+
# while random_number3 == random_number1:
|
678 |
+
# random_number3 = random.randrange(0, 21)
|
679 |
+
name = self.lists[index]
|
680 |
+
x_stack = []
|
681 |
+
y_stack = []
|
682 |
+
img_non_hair_stack = []
|
683 |
+
img_hair_stack = []
|
684 |
+
non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
|
685 |
+
ref_folder = os.path.join(self.ref_path, name)
|
686 |
+
|
687 |
+
files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
|
688 |
+
# ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
|
689 |
+
ref_path = os.path.join(ref_folder,files[0])
|
690 |
+
for i in range(10):
|
691 |
+
img_non_hair_stack.append(self.read_img(non_hair_path).unsqueeze(1))
|
692 |
+
hair_path = os.path.join(self.img_path, name, str(random_number1+i) + '.jpg')
|
693 |
+
img_hair_stack.append(self.read_img(hair_path).unsqueeze(1))
|
694 |
+
|
695 |
+
#random_number1 = random_number1
|
696 |
+
#* 10
|
697 |
+
#random_number2 = random_number2 * 10
|
698 |
+
|
699 |
+
#random_number2 = random_number1
|
700 |
+
non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
|
701 |
+
|
702 |
+
|
703 |
+
hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
|
704 |
+
hair_path2 = os.path.join(self.img_path, name, str(random_number3) + '.jpg')
|
705 |
+
non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
|
706 |
+
non_hair_path2 = os.path.join(self.non_hair_path, name, str(random_number1) + '.jpg')
|
707 |
+
non_hair_path3 = os.path.join(self.non_hair_path, name, str(random_number3) + '.jpg')
|
708 |
+
ref_folder = os.path.join(self.ref_path, name)
|
709 |
+
|
710 |
+
files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
|
711 |
+
# ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
|
712 |
+
ref_path = os.path.join(ref_folder,files[0])
|
713 |
+
# print('________')
|
714 |
+
# print(files)
|
715 |
+
# print('++++++++')
|
716 |
+
# print(ref_folder)
|
717 |
+
# print("========")
|
718 |
+
# print(name)
|
719 |
+
# print("********")
|
720 |
+
# print(ref_path)
|
721 |
+
img_hair = cv2.imread(hair_path)
|
722 |
+
img_hair2 = cv2.imread(hair_path2)
|
723 |
+
img_non_hair = cv2.imread(non_hair_path)
|
724 |
+
img_non_hair2 = cv2.imread(non_hair_path2)
|
725 |
+
img_non_hair3 = cv2.imread(non_hair_path3)
|
726 |
+
ref_hair = cv2.imread(ref_path)
|
727 |
+
|
728 |
+
|
729 |
+
img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
|
730 |
+
img_non_hair2 = cv2.cvtColor(img_non_hair2, cv2.COLOR_BGR2RGB)
|
731 |
+
img_non_hair3 = cv2.cvtColor(img_non_hair3, cv2.COLOR_BGR2RGB)
|
732 |
+
ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
|
733 |
+
|
734 |
+
|
735 |
+
img_non_hair = cv2.resize(img_non_hair, (512, 512))
|
736 |
+
img_non_hair2 = cv2.resize(img_non_hair2, (512, 512))
|
737 |
+
img_non_hair3 = cv2.resize(img_non_hair3, (512, 512))
|
738 |
+
ref_hair = cv2.resize(ref_hair, (512, 512))
|
739 |
+
|
740 |
+
|
741 |
+
img_non_hair = (img_non_hair / 255.0) * 2 - 1
|
742 |
+
img_non_hair2 = (img_non_hair2 / 255.0) * 2 - 1
|
743 |
+
img_non_hair3 = (img_non_hair3 / 255.0) * 2 - 1
|
744 |
+
ref_hair = (ref_hair / 255.0) * 2 - 1
|
745 |
+
|
746 |
+
|
747 |
+
img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
|
748 |
+
img_non_hair2 = torch.tensor(img_non_hair2).permute(2, 0, 1)
|
749 |
+
img_non_hair3 = torch.tensor(img_non_hair3).permute(2, 0, 1)
|
750 |
+
ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)
|
751 |
+
|
752 |
+
pose = self.pose[random_number1]
|
753 |
+
pose = torch.tensor(pose)
|
754 |
+
pose2 = self.pose[random_number3]
|
755 |
+
pose2 = torch.tensor(pose2)
|
756 |
+
# pose2 = self.pose[random_number2]
|
757 |
+
# pose2 = torch.tensor(pose2)
|
758 |
+
hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
|
759 |
+
# hair_num = [0, 2, 6, 14, 18, 21]
|
760 |
+
# img_hair_stack = []
|
761 |
+
# polar = self.polars_rad[random_number1]
|
762 |
+
# polar = torch.tensor(polar).unsqueeze(0)
|
763 |
+
# azimuths = self.azimuths_rad[random_number1]
|
764 |
+
# azimuths = torch.tensor(azimuths).unsqueeze(0)
|
765 |
+
# img_hair = cv2.imread(hair_path)
|
766 |
+
img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
|
767 |
+
img_hair = cv2.resize(img_hair, (512, 512))
|
768 |
+
img_hair = (img_hair / 255.0) * 2 - 1
|
769 |
+
img_hair = torch.tensor(img_hair).permute(2, 0, 1)
|
770 |
+
|
771 |
+
img_hair2 = cv2.cvtColor(img_hair2, cv2.COLOR_BGR2RGB)
|
772 |
+
img_hair2 = cv2.resize(img_hair2, (512, 512))
|
773 |
+
img_hair2 = (img_hair2 / 255.0) * 2 - 1
|
774 |
+
img_hair2 = torch.tensor(img_hair2).permute(2, 0, 1)
|
775 |
+
# begin = random.randrange(0, 21-self.frame_num)
|
776 |
+
# hair_num = [i+begin for i in range(self.frame_num)]
|
777 |
+
# for i in hair_num:
|
778 |
+
# img_hair = cv2.imread(hair_path)
|
779 |
+
# img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
|
780 |
+
# img_hair = cv2.resize(img_hair, (512, 512))
|
781 |
+
# img_hair = (img_hair / 255.0) * 2 - 1
|
782 |
+
# img_hair = torch.tensor(img_hair).permute(2, 0, 1)
|
783 |
+
# img_hair_stack.append(img_hair)
|
784 |
+
# img_hair = torch.stack(img_hair_stack)
|
785 |
+
|
786 |
+
return {
|
787 |
+
# 'hair_pose': pose1,
|
788 |
+
'img_hair': img_hair,
|
789 |
+
'img_hair2': img_hair2,
|
790 |
+
# 'bald_pose': pose2,
|
791 |
+
'pose': pose,
|
792 |
+
'pose2': pose2,
|
793 |
+
'img_non_hair': img_non_hair,
|
794 |
+
'img_non_hair2': img_non_hair2,
|
795 |
+
'img_non_hair3': img_non_hair3,
|
796 |
+
'ref_hair': ref_hair,
|
797 |
+
# 'polar': polar,
|
798 |
+
# 'azimuths':azimuths,
|
799 |
+
}
|
800 |
+
|
801 |
+
def __len__(self):
|
802 |
+
return self.len
|
803 |
+
|
804 |
+
class myDataset_sv3d_simple(data.Dataset):
|
805 |
+
"""Custom data.Dataset compatible with data.DataLoader."""
|
806 |
+
|
807 |
+
def __init__(self, train_data_dir, frame_num=6):
|
808 |
+
self.img_path = os.path.join(train_data_dir, "hair")
|
809 |
+
# self.pose_path = os.path.join(train_data_dir, "pose.npy")
|
810 |
+
# self.non_hair_path = os.path.join(train_data_dir, "no_hair")
|
811 |
+
# self.ref_path = os.path.join(train_data_dir, "ref_hair")
|
812 |
+
self.pose_path = os.path.join(train_data_dir, "pose.npy")
|
813 |
+
self.non_hair_path = os.path.join(train_data_dir, "non-hair")
|
814 |
+
# self.ref_path = os.path.join(train_data_dir, "reference")
|
815 |
+
self.ref_path = os.path.join(train_data_dir, "reference")
|
816 |
+
self.lists = os.listdir(self.img_path)
|
817 |
+
self.len = len(self.lists)-10
|
818 |
+
self.pose = np.load(self.pose_path)
|
819 |
+
self.frame_num = frame_num
|
820 |
+
#self.pose = np.random.randn(12, 4)
|
821 |
+
elevations_deg = [-0.05/2*np.pi*360] * 21
|
822 |
+
azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
|
823 |
+
Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
|
824 |
+
for i in Face_yaws:
|
825 |
+
if i<0:
|
826 |
+
i = 2*np.pi+i
|
827 |
+
i = i/2*np.pi*360
|
828 |
+
face_yaws = [Face_yaws[0]]
|
829 |
+
for i in range(20):
|
830 |
+
face_yaws.append(Face_yaws[3*i+2])
|
831 |
+
self.polars_rad = [np.deg2rad(90-e) for e in elevations_deg]
|
832 |
+
self.azimuths_rad = [np.deg2rad((a) % 360) for a in azimuths_deg]
|
833 |
+
self.azimuths_rad[:-1].sort()
|
834 |
+
x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
|
835 |
+
y = [- 0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
|
836 |
+
self.x = [x[0]]
|
837 |
+
self.y = [y[0]]
|
838 |
+
for i in range(20):
|
839 |
+
self.x.append(x[i*3+2])
|
840 |
+
self.y.append(y[i*3+2])
|
841 |
+
|
842 |
+
def __getitem__(self, index):
|
843 |
+
"""Returns one data pair (source and target)."""
|
844 |
+
# seq_len, fea_dim
|
845 |
+
random_number1 = random.randrange(0, 21)
|
846 |
+
random_number2 = random.randrange(0, 21)
|
847 |
+
|
848 |
+
# while random_number2 == random_number1:
|
849 |
+
# random_number2 = random.randrange(0, 21)
|
850 |
+
name = self.lists[index]
|
851 |
+
|
852 |
+
#random_number1 = random_number1
|
853 |
+
#* 10
|
854 |
+
#random_number2 = random_number2 * 10
|
855 |
+
|
856 |
+
random_number2 = random_number1
|
857 |
+
|
858 |
+
|
859 |
+
hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
|
860 |
+
# hair_path = os.path.join(self.non_hair_path, name, str(random_number1) + '.jpg')
|
861 |
+
non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
|
862 |
+
ref_folder = os.path.join(self.ref_path, name)
|
863 |
+
|
864 |
+
files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
|
865 |
+
# ref_path = os.path.join(ref_folder, str(random_number2) + '.jpg')
|
866 |
+
ref_path = os.path.join(ref_folder,files[0])
|
867 |
+
# print('________')
|
868 |
+
# print(files)
|
869 |
+
# print('++++++++')
|
870 |
+
# print(ref_folder)
|
871 |
+
# print("========")
|
872 |
+
# print(name)
|
873 |
+
# print("********")
|
874 |
+
# print(ref_path)
|
875 |
+
img_hair = cv2.imread(hair_path)
|
876 |
+
img_non_hair = cv2.imread(non_hair_path)
|
877 |
+
ref_hair = cv2.imread(ref_path)
|
878 |
+
|
879 |
+
|
880 |
+
img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
|
881 |
+
ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
|
882 |
+
|
883 |
+
|
884 |
+
img_non_hair = cv2.resize(img_non_hair, (512, 512))
|
885 |
+
ref_hair = cv2.resize(ref_hair, (512, 512))
|
886 |
+
|
887 |
+
|
888 |
+
img_non_hair = (img_non_hair / 255.0) * 2 - 1
|
889 |
+
ref_hair = (ref_hair / 255.0) * 2 - 1
|
890 |
+
|
891 |
+
|
892 |
+
img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
|
893 |
+
ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)
|
894 |
+
|
895 |
+
pose = self.pose[random_number1]
|
896 |
+
pose = torch.tensor(pose)
|
897 |
+
# pose2 = self.pose[random_number2]
|
898 |
+
# pose2 = torch.tensor(pose2)
|
899 |
+
# hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
|
900 |
+
# hair_num = [0, 2, 6, 14, 18, 21]
|
901 |
+
# img_hair_stack = []
|
902 |
+
# polar = self.polars_rad[random_number1]
|
903 |
+
# polar = torch.tensor(polar).unsqueeze(0)
|
904 |
+
# azimuths = self.azimuths_rad[random_number1]
|
905 |
+
# azimuths = torch.tensor(azimuths).unsqueeze(0)
|
906 |
+
img_hair = cv2.imread(hair_path)
|
907 |
+
img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
|
908 |
+
img_hair = cv2.resize(img_hair, (512, 512))
|
909 |
+
img_hair = (img_hair / 255.0) * 2 - 1
|
910 |
+
img_hair = torch.tensor(img_hair).permute(2, 0, 1)
|
911 |
+
x = torch.tensor(self.x[random_number1])
|
912 |
+
y = torch.tensor(self.y[random_number1])
|
913 |
+
# begin = random.randrange(0, 21-self.frame_num)
|
914 |
+
# hair_num = [i+begin for i in range(self.frame_num)]
|
915 |
+
# for i in hair_num:
|
916 |
+
# img_hair = cv2.imread(hair_path)
|
917 |
+
# img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
|
918 |
+
# img_hair = cv2.resize(img_hair, (512, 512))
|
919 |
+
# img_hair = (img_hair / 255.0) * 2 - 1
|
920 |
+
# img_hair = torch.tensor(img_hair).permute(2, 0, 1)
|
921 |
+
# img_hair_stack.append(img_hair)
|
922 |
+
# img_hair = torch.stack(img_hair_stack)
|
923 |
+
|
924 |
+
return {
|
925 |
+
# 'hair_pose': pose1,
|
926 |
+
'img_hair': img_hair,
|
927 |
+
# 'bald_pose': pose2,
|
928 |
+
'pose': pose,
|
929 |
+
'img_non_hair': img_non_hair,
|
930 |
+
'ref_hair': ref_hair,
|
931 |
+
'x': x,
|
932 |
+
'y': y,
|
933 |
+
# 'polar': polar,
|
934 |
+
# 'azimuths':azimuths,
|
935 |
+
}
|
936 |
+
|
937 |
+
def __len__(self):
|
938 |
+
return self.len
|
939 |
+
|
940 |
+
class myDataset_sv3d_simple_ori(data.Dataset):
|
941 |
+
"""Custom data.Dataset compatible with data.DataLoader."""
|
942 |
+
|
943 |
+
def __init__(self, train_data_dir, frame_num=6):
|
944 |
+
self.img_path = os.path.join(train_data_dir, "hair")
|
945 |
+
# self.pose_path = os.path.join(train_data_dir, "pose.npy")
|
946 |
+
# self.non_hair_path = os.path.join(train_data_dir, "no_hair")
|
947 |
+
# self.ref_path = os.path.join(train_data_dir, "ref_hair")
|
948 |
+
self.pose_path = os.path.join(train_data_dir, "pose.npy")
|
949 |
+
self.non_hair_path = os.path.join(train_data_dir, "non-hair")
|
950 |
+
self.ref_path = os.path.join(train_data_dir, "reference")
|
951 |
+
self.lists = os.listdir(self.img_path)
|
952 |
+
self.len = len(self.lists)-10
|
953 |
+
self.pose = np.load(self.pose_path)
|
954 |
+
self.frame_num = frame_num
|
955 |
+
#self.pose = np.random.randn(12, 4)
|
956 |
+
elevations_deg = [-0.05/2*np.pi*360] * 21
|
957 |
+
azimuths_deg = np.linspace(0, 360, 21+1)[1:] % 360
|
958 |
+
        # Head-pose trajectory: 60 raw samples, subsampled to 21 rendered views.
        Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
        for i in Face_yaws:
            if i < 0:
                i = 2 * np.pi + i
            i = i / 2 * np.pi * 360
        face_yaws = [Face_yaws[0]]
        for i in range(20):
            face_yaws.append(Face_yaws[3 * i + 2])
        self.polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
        self.azimuths_rad = [np.deg2rad(a % 360) for a in azimuths_deg]
        self.azimuths_rad[:-1].sort()
        x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
        y = [-0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
        self.x = [x[0]]
        self.y = [y[0]]
        for i in range(20):
            self.x.append(x[i * 3 + 2])
            self.y.append(y[i * 3 + 2])

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        # Both indices are drawn from the 21 views, but only random_number2
        # is used below, so hair target, bald source and pose share a view.
        random_number1 = random.randrange(0, 21)
        random_number2 = random.randrange(0, 21)
        name = self.lists[index]

        hair_path = os.path.join(self.img_path, name, str(random_number2) + '.jpg')
        non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
        ref_folder = os.path.join(self.ref_path, name)

        files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
        ref_path = os.path.join(ref_folder, files[0])
        img_hair = cv2.imread(hair_path)
        img_non_hair = cv2.imread(non_hair_path)
        ref_hair = cv2.imread(ref_path)

        img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
        ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)

        img_non_hair = cv2.resize(img_non_hair, (512, 512))
        ref_hair = cv2.resize(ref_hair, (512, 512))

        img_non_hair = (img_non_hair / 255.0) * 2 - 1
        ref_hair = (ref_hair / 255.0) * 2 - 1

        img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
        ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)

        pose = torch.tensor(self.pose[random_number2])

        img_hair = cv2.imread(hair_path)
        img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
        img_hair = cv2.resize(img_hair, (512, 512))
        img_hair = (img_hair / 255.0) * 2 - 1
        img_hair = torch.tensor(img_hair).permute(2, 0, 1)
        x = torch.tensor(self.x[random_number2])
        y = torch.tensor(self.y[random_number2])

        return {
            'img_hair': img_hair,
            'pose': pose,
            'img_non_hair': img_non_hair,
            'ref_hair': ref_hair,
            'x': x,
            'y': y,
        }

    def __len__(self):
        return self.len

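For orientation, the `x` / `y` values that these datasets attach to every view are not free parameters: they are 21 samples of a fixed sinusoidal head-pose path (60 raw points of `0.4*sin` and `-0.05 + 0.3*cos`, keeping the first point and then every third one). A small standalone sketch, assuming only numpy, reproduces that path:

```
import numpy as np

# Rebuild the 21 per-view (x, y) conditioning values used by the datasets in this file.
x_full = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
y_full = [-0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
xs = [x_full[0]] + [x_full[3 * i + 2] for i in range(20)]
ys = [y_full[0]] + [y_full[3 * i + 2] for i in range(20)]
assert len(xs) == len(ys) == 21  # one (x, y) pair per rendered view
for view_idx, (vx, vy) in enumerate(zip(xs, ys)):
    print(f"view {view_idx:2d}: x={vx:+.3f}, y={vy:+.3f}")
```
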
class myDataset_sv3d_simple_temporal(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, train_data_dir, frame_num=6):
        self.img_path = os.path.join(train_data_dir, "hair")
        self.pose_path = os.path.join(train_data_dir, "pose.npy")
        self.non_hair_path = os.path.join(train_data_dir, "non-hair")
        self.ref_path = os.path.join(train_data_dir, "reference")
        self.lists = os.listdir(self.img_path)
        self.len = len(self.lists) - 10
        self.pose = np.load(self.pose_path)
        self.frame_num = frame_num
        elevations_deg = [-0.05 / 2 * np.pi * 360] * 21
        azimuths_deg = np.linspace(0, 360, 21 + 1)[1:] % 360
        Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
        for i in Face_yaws:
            if i < 0:
                i = 2 * np.pi + i
            i = i / 2 * np.pi * 360
        face_yaws = [Face_yaws[0]]
        for i in range(20):
            face_yaws.append(Face_yaws[3 * i + 2])
        self.polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
        self.azimuths_rad = [np.deg2rad(a % 360) for a in azimuths_deg]
        self.azimuths_rad[:-1].sort()
        x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
        y = [-0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
        self.x = [x[0]]
        self.y = [y[0]]
        for i in range(20):
            self.x.append(x[i * 3 + 2])
            self.y.append(y[i * 3 + 2])

    def read_img(self, path):
        img = cv2.imread(path)
        img = cv2.resize(img, (512, 512))
        img = (img / 255.0) * 2 - 1
        img = torch.tensor(img).permute(2, 0, 1)
        return img

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        random_number1 = random.randrange(0, 21 - 12)
        random_number2 = random.randrange(0, 21)

        name = self.lists[index]
        x_stack = []
        y_stack = []
        img_non_hair_stack = []
        img_hair_stack = []
        non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
        ref_folder = os.path.join(self.ref_path, name)

        files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
        ref_path = os.path.join(ref_folder, files[0])
        ref_hair = self.read_img(ref_path)
        for i in range(12):
            img_non_hair_stack.append(self.read_img(non_hair_path).unsqueeze(0))
            # hair_path = os.path.join(self.img_path, name, str(random_number1 + i) + '.jpg')
            hair_path = os.path.join(self.non_hair_path, name, str(random_number1 + i) + '.jpg')
            img_hair_stack.append(self.read_img(hair_path).unsqueeze(0))
            x_stack.append(torch.tensor(self.x[random_number1 + i]).unsqueeze(0))
            y_stack.append(torch.tensor(self.y[random_number1 + i]).unsqueeze(0))

        img_non_hair = torch.cat(img_non_hair_stack, axis=0)
        img_hair = torch.cat(img_hair_stack, axis=0)
        x = torch.cat(x_stack, axis=0)
        y = torch.cat(y_stack, axis=0)

        return {
            'img_hair': img_hair,
            'img_non_hair': img_non_hair,
            'ref_hair': ref_hair,
            'x': x,
            'y': y,
        }

    def __len__(self):
        return self.len

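A quick way to sanity-check the temporal dataset above is to pull one batch and inspect shapes: each sample is a 12-frame clip, so `img_hair` and `img_non_hair` come out as `[12, 3, 512, 512]`, `ref_hair` is a single `[3, 512, 512]` image, and `x` / `y` carry 12 camera values. A minimal sketch, assuming a `./data` directory with the `hair`, `non-hair`, `reference` folders and `pose.npy` that the constructor expects:

```
import torch

# Smoke test for myDataset_sv3d_simple_temporal (assumes ./data is populated).
dataset = myDataset_sv3d_simple_temporal("./data")
loader = torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=0)

batch = next(iter(loader))
print(batch["img_hair"].shape)      # torch.Size([2, 12, 3, 512, 512])
print(batch["img_non_hair"].shape)  # torch.Size([2, 12, 3, 512, 512])
print(batch["ref_hair"].shape)      # torch.Size([2, 3, 512, 512])
print(batch["x"].shape, batch["y"].shape)  # torch.Size([2, 12]) each
```
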
class myDataset_sv3d_simple_temporal2(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, train_data_dir, frame_num=6):
        train_data_dir2 = '/opt/liblibai-models/user-workspace/zyx/sky/3dhair/data/segement'
        self.img_path = os.path.join(train_data_dir, "hair")
        self.img_path2 = os.path.join(train_data_dir, "hair_good")
        self.pose_path = os.path.join(train_data_dir, "pose.npy")
        self.non_hair_path = os.path.join(train_data_dir, "non-hair")
        self.ref_path = os.path.join(train_data_dir, "multi_reference2")

        self.pose_path2 = os.path.join(train_data_dir2, "pose.npy")
        self.non_hair_path2 = os.path.join(train_data_dir2, "non-hair")
        self.ref_path2 = os.path.join(train_data_dir2, "reference")

        self.lists = os.listdir(self.img_path2)
        self.len = len(self.lists) - 10
        self.pose = np.load(self.pose_path)
        self.frame_num = frame_num
        elevations_deg = [-0.05 / 2 * np.pi * 360] * 21
        azimuths_deg = np.linspace(0, 360, 21 + 1)[1:] % 360
        Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
        for i in Face_yaws:
            if i < 0:
                i = 2 * np.pi + i
            i = i / 2 * np.pi * 360
        face_yaws = [Face_yaws[0]]
        for i in range(20):
            face_yaws.append(Face_yaws[3 * i + 2])
        self.polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
        self.azimuths_rad = [np.deg2rad(a % 360) for a in azimuths_deg]
        self.azimuths_rad[:-1].sort()
        x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
        y = [-0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
        self.x = [x[0]]
        self.y = [y[0]]
        for i in range(20):
            self.x.append(x[i * 3 + 2])
            self.y.append(y[i * 3 + 2])

    def read_img(self, path):
        img = cv2.imread(path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (512, 512))
        img = (img / 255.0) * 2 - 1
        img = torch.tensor(img).permute(2, 0, 1)
        return img

    def read_ref_img(self, path):
        img = cv2.imread(path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = hair_transform(image=img)['image']
        img = (img / 255.0) * 2 - 1
        img = torch.tensor(img).permute(2, 0, 1)
        return img

    def reference_lists(self, reference_num, root):
        stacks = []
        invalid = []
        for i in range(12):
            if (reference_num - 6 + i) < 0:
                invalid.append(reference_num - 5 + i + 21)
            else:
                invalid.append((reference_num - 5 + i) % 21)
        for i in range(21):
            if i in invalid:
                continue
            else:
                stacks.append(os.path.join(root, str(i) + '.jpg'))
        return stacks

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        random_number = random.uniform(0, 1)
        if random_number < 0.5:
            non_hair_root = self.non_hair_path2
            img_path = self.img_path2
        else:
            non_hair_root = self.non_hair_path
            img_path = self.img_path
        random_number1 = random.randrange(0, 21 - 12)
        random_number2 = random.randrange(0, 21)

        name = self.lists[index].split('.')[0]
        x_stack = []
        y_stack = []
        img_non_hair_stack = []
        img_hair_stack = []
        non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
        ref_folder = os.path.join(self.ref_path, name)

        files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')][:3]
        ref_path = os.path.join(ref_folder, random.choice(files))
        ref_hair = self.read_ref_img(ref_path)
        for i in range(12):
            img_non_hair_stack.append(self.read_img(non_hair_path).unsqueeze(0))
            hair_path = os.path.join(self.img_path, name, str(random_number1 + i) + '.jpg')
            img_hair_stack.append(self.read_img(hair_path).unsqueeze(0))
            x_stack.append(torch.tensor(self.x[random_number1 + i]).unsqueeze(0))
            y_stack.append(torch.tensor(self.y[random_number1 + i]).unsqueeze(0))

        img_non_hair = torch.cat(img_non_hair_stack, axis=0)
        img_hair = torch.cat(img_hair_stack, axis=0)
        x = torch.cat(x_stack, axis=0)
        y = torch.cat(y_stack, axis=0)

        return {
            'img_hair': img_hair,
            'img_non_hair': img_non_hair,
            'ref_hair': ref_hair,
            'x': x,
            'y': y,
        }

    def __len__(self):
        return self.len

class myDataset_sv3d_simple_temporal3(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, train_data_dir, frame_num=6):
        train_data_dir2 = '/opt/liblibai-models/user-workspace/zyx/sky/3dhair/data/segement'
        self.img_path = os.path.join(train_data_dir, "hair")
        self.img_path2 = os.path.join(train_data_dir, "non-hair")
        self.pose_path = os.path.join(train_data_dir, "pose.npy")
        self.non_hair_path = os.path.join(train_data_dir, "non-hair")
        self.ref_path = os.path.join(train_data_dir, "multi_reference2")

        self.pose_path2 = os.path.join(train_data_dir2, "pose.npy")
        self.non_hair_path2 = os.path.join(train_data_dir2, "non-hair")
        self.ref_path2 = os.path.join(train_data_dir2, "reference")

        self.lists = os.listdir(self.img_path)
        self.len = len(self.lists) - 10
        self.pose = np.load(self.pose_path)
        self.frame_num = frame_num
        elevations_deg = [-0.05 / 2 * np.pi * 360] * 21
        azimuths_deg = np.linspace(0, 360, 21 + 1)[1:] % 360
        Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
        for i in Face_yaws:
            if i < 0:
                i = 2 * np.pi + i
            i = i / 2 * np.pi * 360
        face_yaws = [Face_yaws[0]]
        for i in range(20):
            face_yaws.append(Face_yaws[3 * i + 2])
        self.polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
        self.azimuths_rad = [np.deg2rad(a % 360) for a in azimuths_deg]
        self.azimuths_rad[:-1].sort()
        x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
        y = [-0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
        self.x = [x[0]]
        self.y = [y[0]]
        for i in range(20):
            self.x.append(x[i * 3 + 2])
            self.y.append(y[i * 3 + 2])

    def read_img(self, path):
        img = cv2.imread(path)
        img = cv2.resize(img, (512, 512))
        img = (img / 255.0) * 2 - 1
        img = torch.tensor(img).permute(2, 0, 1)
        return img

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        random_number = random.uniform(0, 1)
        if random_number < 0.5:
            non_hair_root = self.non_hair_path2
            img_path = self.img_path2
        else:
            non_hair_root = self.non_hair_path
            img_path = self.img_path
        random_number1 = random.randrange(0, 21 - 12)
        random_number2 = random.randrange(0, 21)

        name = self.lists[index]
        x_stack = []
        y_stack = []
        img_non_hair_stack = []
        img_hair_stack = []
        non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
        ref_folder = os.path.join(self.ref_path, name)

        files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
        ref_path = os.path.join(ref_folder, random.choice(files))
        ref_hair = self.read_img(ref_path)
        for i in range(12):
            non_hair_path = os.path.join(self.non_hair_path, name, str(random_number1 + i) + '.jpg')
            img_non_hair_stack.append(self.read_img(non_hair_path).unsqueeze(0))
            hair_path = os.path.join(self.img_path, name, str(random_number1 + i) + '.jpg')
            img_hair_stack.append(self.read_img(hair_path).unsqueeze(0))
            x_stack.append(torch.tensor(self.x[random_number1 + i]).unsqueeze(0))
            y_stack.append(torch.tensor(self.y[random_number1 + i]).unsqueeze(0))

        img_non_hair = torch.cat(img_non_hair_stack, axis=0)
        img_hair = torch.cat(img_hair_stack, axis=0)
        x = torch.cat(x_stack, axis=0)
        y = torch.cat(y_stack, axis=0)

        return {
            'img_hair': img_hair,
            'img_non_hair': img_non_hair,
            'ref_hair': ref_hair,
            'x': x,
            'y': y,
        }

    def __len__(self):
        return self.len

class myDataset_sv3d_simple_temporal_controlnet_without_pose(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, train_data_dir, frame_num=6):
        train_data_dir2 = '/opt/liblibai-models/user-workspace/zyx/sky/3dhair/data/segement'
        self.img_path = os.path.join(train_data_dir, "hair")
        self.img_path2 = os.path.join(train_data_dir, "non-hair")
        self.pose_path = os.path.join(train_data_dir, "pose.npy")
        self.non_hair_path = os.path.join(train_data_dir, "non-hair")
        self.ref_path = os.path.join(train_data_dir, "multi_reference2")

        self.pose_path2 = os.path.join(train_data_dir2, "pose.npy")
        self.non_hair_path2 = os.path.join(train_data_dir2, "non-hair")
        self.ref_path2 = os.path.join(train_data_dir2, "reference")

        self.lists = os.listdir(self.img_path)
        self.len = len(self.lists) - 10
        self.pose = np.load(self.pose_path)
        self.frame_num = frame_num
        elevations_deg = [-0.05 / 2 * np.pi * 360] * 21
        azimuths_deg = np.linspace(0, 360, 21 + 1)[1:] % 360
        Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
        for i in Face_yaws:
            if i < 0:
                i = 2 * np.pi + i
            i = i / 2 * np.pi * 360
        face_yaws = [Face_yaws[0]]
        for i in range(20):
            face_yaws.append(Face_yaws[3 * i + 2])
        self.polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
        self.azimuths_rad = [np.deg2rad(a % 360) for a in azimuths_deg]
        self.azimuths_rad[:-1].sort()
        x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
        y = [-0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
        self.x = [x[0]]
        self.y = [y[0]]
        for i in range(20):
            self.x.append(x[i * 3 + 2])
            self.y.append(y[i * 3 + 2])

    def read_img(self, path):
        img = cv2.imread(path)
        img = cv2.resize(img, (512, 512))
        img = (img / 255.0) * 2 - 1
        img = torch.tensor(img).permute(2, 0, 1)
        return img

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        random_number = random.uniform(0, 1)
        if random_number < 0.5:
            non_hair_root = self.non_hair_path2
            img_path = self.img_path2
        else:
            non_hair_root = self.non_hair_path
            img_path = self.img_path
        random_number1 = random.randrange(0, 21 - 12)
        random_number2 = random.randrange(0, 21)

        name = self.lists[index]
        x_stack = []
        y_stack = []
        img_non_hair_stack = []
        img_hair_stack = []
        non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
        ref_folder = os.path.join(self.ref_path, name)

        files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
        ref_path = os.path.join(ref_folder, random.choice(files))
        ref_hair = self.read_img(ref_path)
        for i in range(12):
            non_hair_path = os.path.join(self.non_hair_path, name, str(random_number1 + i) + '.jpg')
            img_non_hair_stack.append(self.read_img(non_hair_path).unsqueeze(0))
            hair_path = os.path.join(self.img_path, name, str(random_number1 + i) + '.jpg')
            img_hair_stack.append(self.read_img(hair_path).unsqueeze(0))
            x_stack.append(torch.tensor(self.x[random_number1 + i]).unsqueeze(0))
            y_stack.append(torch.tensor(self.y[random_number1 + i]).unsqueeze(0))

        img_non_hair = torch.cat(img_non_hair_stack, axis=0)
        img_hair = torch.cat(img_hair_stack, axis=0)
        x = torch.cat(x_stack, axis=0)
        y = torch.cat(y_stack, axis=0)

        return {
            'img_hair': img_hair,
            'img_non_hair': img_non_hair,
            'ref_hair': ref_hair,
            'x': x,
            'y': y,
        }

    def __len__(self):
        return self.len

class myDataset_sv3d_simple_temporal_controlnet(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, train_data_dir, frame_num=6):
        train_data_dir2 = '/opt/liblibai-models/user-workspace/zyx/sky/3dhair/data/segement'
        self.img_path = os.path.join(train_data_dir, "hair")
        self.img_path2 = os.path.join(train_data_dir, "non-hair")
        self.pose_path = os.path.join(train_data_dir, "pose.npy")
        self.non_hair_path = os.path.join(train_data_dir, "non-hair")
        self.ref_path = os.path.join(train_data_dir, "multi_reference2")

        self.pose_path2 = os.path.join(train_data_dir2, "pose.npy")
        self.non_hair_path2 = os.path.join(train_data_dir2, "non-hair")
        self.ref_path2 = os.path.join(train_data_dir2, "reference")

        self.lists = os.listdir(self.img_path)
        self.len = len(self.lists) - 10
        self.pose = np.load(self.pose_path)
        self.frame_num = frame_num
        elevations_deg = [-0.05 / 2 * np.pi * 360] * 21
        azimuths_deg = np.linspace(0, 360, 21 + 1)[1:] % 360
        Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
        for i in Face_yaws:
            if i < 0:
                i = 2 * np.pi + i
            i = i / 2 * np.pi * 360
        face_yaws = [Face_yaws[0]]
        for i in range(20):
            face_yaws.append(Face_yaws[3 * i + 2])
        self.polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
        self.azimuths_rad = [np.deg2rad(a % 360) for a in azimuths_deg]
        self.azimuths_rad[:-1].sort()
        x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
        y = [-0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
        self.x = [x[0]]
        self.y = [y[0]]
        for i in range(20):
            self.x.append(x[i * 3 + 2])
            self.y.append(y[i * 3 + 2])

    def read_img(self, path):
        img = cv2.imread(path)
        img = cv2.resize(img, (512, 512))
        img = (img / 255.0) * 2 - 1
        img = torch.tensor(img).permute(2, 0, 1)
        return img

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        random_number = random.uniform(0, 1)
        if random_number < 0.5:
            non_hair_root = self.non_hair_path2
            img_path = self.img_path2
        else:
            non_hair_root = self.non_hair_path
            img_path = self.img_path
        random_number1 = random.randrange(0, 21)
        random_number2 = random.randrange(0, 21)

        name = self.lists[index]
        x_stack = []
        y_stack = []
        img_non_hair_stack = []
        img_hair_stack = []
        non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
        ref_folder = os.path.join(self.ref_path, name)

        files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
        ref_path = os.path.join(ref_folder, random.choice(files))
        ref_hair = self.read_img(ref_path)
        non_hair_path = os.path.join(img_path, name, str(random_number2) + '.jpg')
        img_non_hair = self.read_img(non_hair_path)
        hair_path = os.path.join(img_path, name, str(random_number1) + '.jpg')
        img_hair = self.read_img(hair_path)
        x = self.x[random_number1]
        y = self.y[random_number1]

        return {
            'img_hair': img_hair,
            'img_non_hair': img_non_hair,
            'ref_hair': ref_hair,
            'x': x,
            'y': y,
        }

    def __len__(self):
        return self.len

class myDataset_sv3d_simple_temporal_pose(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, train_data_dir, frame_num=6):
        self.img_path = os.path.join(train_data_dir, "hair")
        self.pose_path = os.path.join(train_data_dir, "pose.npy")
        self.non_hair_path = os.path.join(train_data_dir, "non-hair")
        self.ref_path = os.path.join(train_data_dir, "reference")
        self.lists = os.listdir(self.img_path)
        self.len = len(self.lists) - 10
        self.pose = np.load(self.pose_path)
        self.frame_num = frame_num
        elevations_deg = [-0.05 / 2 * np.pi * 360] * 21
        azimuths_deg = np.linspace(0, 360, 21 + 1)[1:] % 360
        Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
        for i in Face_yaws:
            if i < 0:
                i = 2 * np.pi + i
            i = i / 2 * np.pi * 360
        face_yaws = [Face_yaws[0]]
        for i in range(20):
            face_yaws.append(Face_yaws[3 * i + 2])
        self.polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
        self.azimuths_rad = [np.deg2rad(a % 360) for a in azimuths_deg]
        self.azimuths_rad[:-1].sort()
        x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
        y = [-0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
        self.x = [x[0]]
        self.y = [y[0]]
        for i in range(20):
            self.x.append(x[i * 3 + 2])
            self.y.append(y[i * 3 + 2])

    def read_img(self, path):
        img = cv2.imread(path)
        img = cv2.resize(img, (512, 512))
        img = (img / 255.0) * 2 - 1
        img = torch.tensor(img).permute(2, 0, 1)
        return img

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        random_number1 = random.randrange(0, 21 - 12)
        random_number2 = random.randrange(0, 21)

        name = self.lists[index]
        x_stack = []
        y_stack = []
        img_non_hair_stack = []
        img_hair_stack = []
        random_number = random.randint(0, 1)
        if random_number == 0:
            img_path = self.img_path
        else:
            img_path = self.non_hair_path
        non_hair_path = os.path.join(img_path, name, str(random_number2) + '.jpg')
        ref_folder = os.path.join(self.ref_path, name)

        files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
        ref_path = os.path.join(ref_folder, files[0])
        ref_hair = self.read_img(ref_path)

        for i in range(12):
            img_non_hair_stack.append(self.read_img(non_hair_path).unsqueeze(0))
            hair_path = os.path.join(img_path, name, str(random_number1 + i) + '.jpg')
            img_hair_stack.append(self.read_img(hair_path).unsqueeze(0))
            x_stack.append(torch.tensor(self.x[random_number1 + i]).unsqueeze(0))
            y_stack.append(torch.tensor(self.y[random_number1 + i]).unsqueeze(0))

        img_non_hair = torch.cat(img_non_hair_stack, axis=0)
        img_hair = torch.cat(img_hair_stack, axis=0)
        x = torch.cat(x_stack, axis=0)
        y = torch.cat(y_stack, axis=0)

        return {
            'img_hair': img_hair,
            'img_non_hair': img_non_hair,
            'ref_hair': ref_hair,
            'x': x,
            'y': y,
        }

    def __len__(self):
        return self.len

class myDataset_sv3d_simple_temporal_random_reference(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, train_data_dir, frame_num=6):
        self.img_path = os.path.join(train_data_dir, "hair")
        self.pose_path = os.path.join(train_data_dir, "pose.npy")
        self.non_hair_path = os.path.join(train_data_dir, "non-hair")
        self.ref_path = os.path.join(train_data_dir, "multi_reference2")
        self.lists = os.listdir(self.img_path)
        self.len = len(self.lists) - 10
        self.pose = np.load(self.pose_path)
        self.frame_num = frame_num
        elevations_deg = [-0.05 / 2 * np.pi * 360] * 21
        azimuths_deg = np.linspace(0, 360, 21 + 1)[1:] % 360
        Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
        for i in Face_yaws:
            if i < 0:
                i = 2 * np.pi + i
            i = i / 2 * np.pi * 360
        face_yaws = [Face_yaws[0]]
        for i in range(20):
            face_yaws.append(Face_yaws[3 * i + 2])
        self.polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
        self.azimuths_rad = [np.deg2rad(a % 360) for a in azimuths_deg]
        self.azimuths_rad[:-1].sort()
        x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
        y = [-0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
        self.x = [x[0]]
        self.y = [y[0]]
        for i in range(20):
            self.x.append(x[i * 3 + 2])
            self.y.append(y[i * 3 + 2])

    def read_img(self, path):
        img = cv2.imread(path)
        img = cv2.resize(img, (512, 512))
        img = (img / 255.0) * 2 - 1
        img = torch.tensor(img).permute(2, 0, 1)
        return img

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        random_number1 = random.randrange(0, 21 - 12)
        random_number2 = random.randrange(0, 21)

        name = self.lists[index]
        x_stack = []
        y_stack = []
        img_non_hair_stack = []
        img_hair_stack = []
        non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
        ref_folder = os.path.join(self.ref_path, name)

        files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
        ref_path = os.path.join(ref_folder, random.choice(files))
        ref_hair = self.read_img(ref_path)
        for i in range(12):
            img_non_hair_stack.append(self.read_img(non_hair_path).unsqueeze(0))
            hair_path = os.path.join(self.img_path, name, str(random_number1 + i) + '.jpg')
            img_hair_stack.append(self.read_img(hair_path).unsqueeze(0))
            x_stack.append(torch.tensor(self.x[random_number1 + i]).unsqueeze(0))
            y_stack.append(torch.tensor(self.y[random_number1 + i]).unsqueeze(0))

        img_non_hair = torch.cat(img_non_hair_stack, axis=0)
        img_hair = torch.cat(img_hair_stack, axis=0)
        x = torch.cat(x_stack, axis=0)
        y = torch.cat(y_stack, axis=0)

        return {
            'img_hair': img_hair,
            'img_non_hair': img_non_hair,
            'ref_hair': ref_hair,
            'x': x,
            'y': y,
        }

    def __len__(self):
        return self.len

class myDataset_sv3d_simple_random_reference(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, train_data_dir, frame_num=6):
        self.img_path = os.path.join(train_data_dir, "hair")
        self.img_path2 = os.path.join(train_data_dir, "hair_good")
        self.pose_path = os.path.join(train_data_dir, "pose.npy")
        self.non_hair_path = os.path.join(train_data_dir, "non-hair")
        self.ref_path = os.path.join(train_data_dir, "multi_reference2")
        self.lists = os.listdir(self.img_path)
        self.len = len(self.lists) - 10
        self.pose = np.load(self.pose_path)
        self.frame_num = frame_num
        elevations_deg = [-0.05 / 2 * np.pi * 360] * 21
        azimuths_deg = np.linspace(0, 360, 21 + 1)[1:] % 360
        Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
        for i in Face_yaws:
            if i < 0:
                i = 2 * np.pi + i
            i = i / 2 * np.pi * 360
        face_yaws = [Face_yaws[0]]
        for i in range(20):
            face_yaws.append(Face_yaws[3 * i + 2])
        self.polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
        self.azimuths_rad = [np.deg2rad(a % 360) for a in azimuths_deg]
        self.azimuths_rad[:-1].sort()
        x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
        y = [-0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
        self.x = [x[0]]
        self.y = [y[0]]
        for i in range(20):
            self.x.append(x[i * 3 + 2])
            self.y.append(y[i * 3 + 2])

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        random_number1 = random.randrange(0, 21)
        random_number2 = random.randrange(0, 21)
        name = self.lists[index].split('.')[0]

        hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
        non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
        ref_folder = os.path.join(self.ref_path, name)

        files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')][:3]
        ref_path = os.path.join(ref_folder, random.choice(files))
        img_hair = cv2.imread(hair_path)
        img_non_hair = cv2.imread(non_hair_path)
        ref_hair = cv2.imread(ref_path)

        img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
        ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
        img_non_hair = cv2.resize(img_non_hair, (512, 512))
        ref_hair = cv2.resize(ref_hair, (512, 512))
        img_non_hair = (img_non_hair / 255.0) * 2 - 1
        ref_hair = (ref_hair / 255.0) * 2 - 1
        img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
        ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)

        pose = torch.tensor(self.pose[random_number1])

        img_hair = cv2.imread(hair_path)
        img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
        img_hair = cv2.resize(img_hair, (512, 512))
        img_hair = (img_hair / 255.0) * 2 - 1
        img_hair = torch.tensor(img_hair).permute(2, 0, 1)
        x = torch.tensor(self.x[random_number1])
        y = torch.tensor(self.y[random_number1])
        x2 = torch.tensor(self.x[random_number2])
        y2 = torch.tensor(self.y[random_number2])

        return {
            'img_hair': img_hair,
            'pose': pose,
            'img_non_hair': img_non_hair,
            'ref_hair': ref_hair,
            'x': x,
            'y': y,
        }

    def __len__(self):
        return self.len

class myDataset_sv3d_simple_random_reference_controlnet(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, train_data_dir, frame_num=6):
        self.img_path = os.path.join(train_data_dir, "hair")
        self.img_path2 = os.path.join(train_data_dir, "hair_good")
        self.pose_path = os.path.join(train_data_dir, "pose.npy")
        self.non_hair_path = os.path.join(train_data_dir, "non-hair")
        self.ref_path = os.path.join(train_data_dir, "multi_reference2")
        self.lists = os.listdir(self.img_path)
        self.len = len(self.lists) - 10
        self.pose = np.load(self.pose_path)
        self.frame_num = frame_num
        elevations_deg = [-0.05 / 2 * np.pi * 360] * 21
        azimuths_deg = np.linspace(0, 360, 21 + 1)[1:] % 360
        Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
        for i in Face_yaws:
            if i < 0:
                i = 2 * np.pi + i
            i = i / 2 * np.pi * 360
        face_yaws = [Face_yaws[0]]
        for i in range(20):
            face_yaws.append(Face_yaws[3 * i + 2])
        self.polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
        self.azimuths_rad = [np.deg2rad(a % 360) for a in azimuths_deg]
        self.azimuths_rad[:-1].sort()
        x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
        y = [-0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
        self.x = [x[0]]
        self.y = [y[0]]
        for i in range(20):
            self.x.append(x[i * 3 + 2])
            self.y.append(y[i * 3 + 2])

    def read_ref_img(self, path):
        img = cv2.imread(path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = hair_transform(image=img)
        img = (img / 255.0) * 2 - 1
        img = torch.tensor(img).permute(2, 0, 1)
        return img

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        random_number1 = random.randrange(0, 21)
        random_number2 = random.randrange(0, 21)
        name = self.lists[index].split('.')[0]

        hair_path = os.path.join(self.non_hair_path, name, str(random_number1) + '.jpg')
        non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
        ref_folder = os.path.join(self.ref_path, name)

        files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')][:3]
        ref_path = os.path.join(ref_folder, random.choice(files))
        img_hair = cv2.imread(hair_path)
        img_non_hair = cv2.imread(non_hair_path)
        ref_hair = cv2.imread(ref_path)

        img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
        ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)

        ref_hair = hair_transform(image=ref_hair)['image']

        img_non_hair = cv2.resize(img_non_hair, (512, 512))
        ref_hair = cv2.resize(ref_hair, (512, 512))
        img_non_hair = (img_non_hair / 255.0) * 2 - 1
        ref_hair = (ref_hair / 255.0) * 2 - 1
        img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
        ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)

        pose = torch.tensor(self.pose[random_number1])

        img_hair = cv2.imread(hair_path)
        img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
        img_hair = cv2.resize(img_hair, (512, 512))
        img_hair = (img_hair / 255.0) * 2 - 1
        img_hair = torch.tensor(img_hair).permute(2, 0, 1)
        x = torch.tensor(self.x[random_number1])
        y = torch.tensor(self.y[random_number1])
        x2 = torch.tensor(self.x[random_number2])
        y2 = torch.tensor(self.y[random_number2])

        return {
            'img_hair': img_hair,
            'pose': pose,
            'img_non_hair': img_non_hair,
            'ref_hair': ref_hair,
            'x': x,
            'y': y,
        }

    def __len__(self):
        return self.len

class myDataset_sv3d_simple_random_reference_stable_hair(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, train_data_dir, frame_num=6):
        self.img_path = os.path.join(train_data_dir, "hair")
        self.pose_path = os.path.join(train_data_dir, "pose.npy")
        self.non_hair_path = os.path.join(train_data_dir, "non-hair")
        self.ref_path = os.path.join(train_data_dir, "reference")
        self.lists = os.listdir(self.img_path)
        self.len = len(self.lists) - 10
        self.pose = np.load(self.pose_path)
        self.frame_num = frame_num
        elevations_deg = [-0.05 / 2 * np.pi * 360] * 21
        azimuths_deg = np.linspace(0, 360, 21 + 1)[1:] % 360
        Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
        for i in Face_yaws:
            if i < 0:
                i = 2 * np.pi + i
            i = i / 2 * np.pi * 360
        face_yaws = [Face_yaws[0]]
        for i in range(20):
            face_yaws.append(Face_yaws[3 * i + 2])
        self.polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
        self.azimuths_rad = [np.deg2rad(a % 360) for a in azimuths_deg]
        self.azimuths_rad[:-1].sort()
        x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
        y = [-0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
        self.x = [x[0]]
        self.y = [y[0]]
        for i in range(20):
            self.x.append(x[i * 3 + 2])
            self.y.append(y[i * 3 + 2])

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        random_number1 = random.randrange(0, 21)
        random_number2 = random.randrange(0, 21)
        random_number1 = random_number2

        name = self.lists[index]

        hair_path = os.path.join(self.img_path, name, str(random_number1) + '.jpg')
        non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
        ref_folder = os.path.join(self.ref_path, name)

        files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
        ref_path = os.path.join(ref_folder, random.choice(files))
        img_hair = cv2.imread(hair_path)
        img_non_hair = cv2.imread(non_hair_path)
        ref_hair = cv2.imread(ref_path)

        img_non_hair = cv2.cvtColor(img_non_hair, cv2.COLOR_BGR2RGB)
        ref_hair = cv2.cvtColor(ref_hair, cv2.COLOR_BGR2RGB)
        img_non_hair = cv2.resize(img_non_hair, (512, 512))
        ref_hair = cv2.resize(ref_hair, (512, 512))
        img_non_hair = (img_non_hair / 255.0) * 2 - 1
        ref_hair = (ref_hair / 255.0) * 2 - 1
        img_non_hair = torch.tensor(img_non_hair).permute(2, 0, 1)
        ref_hair = torch.tensor(ref_hair).permute(2, 0, 1)

        pose = torch.tensor(self.pose[random_number1])

        img_hair = cv2.imread(hair_path)
        img_hair = cv2.cvtColor(img_hair, cv2.COLOR_BGR2RGB)
        img_hair = cv2.resize(img_hair, (512, 512))
        img_hair = (img_hair / 255.0) * 2 - 1
        img_hair = torch.tensor(img_hair).permute(2, 0, 1)
        x = torch.tensor(self.x[random_number1])
        y = torch.tensor(self.y[random_number1])

        return {
            'img_hair': img_hair,
            'pose': pose,
            'img_non_hair': img_non_hair,
            'ref_hair': ref_hair,
            'x': x,
            'y': y,
        }

    def __len__(self):
        return self.len

class myDataset_sv3d_simple_temporal_small_squence(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, train_data_dir, frame_num=6):
        self.img_path = os.path.join(train_data_dir, "hair")
        self.pose_path = os.path.join(train_data_dir, "pose.npy")
        self.non_hair_path = os.path.join(train_data_dir, "non-hair")
        self.ref_path = os.path.join(train_data_dir, "reference")
        self.lists = os.listdir(self.img_path)
        self.len = len(self.lists) - 10
        self.pose = np.load(self.pose_path)
        self.frame_num = frame_num
        elevations_deg = [-0.05 / 2 * np.pi * 360] * 21
        azimuths_deg = np.linspace(0, 360, 21 + 1)[1:] % 360
        Face_yaws = [0.4 * np.sin(2 * 3.14 * i / 60) for i in range(60)]
        for i in Face_yaws:
            if i < 0:
                i = 2 * np.pi + i
            i = i / 2 * np.pi * 360
        face_yaws = [Face_yaws[0]]
        for i in range(20):
            face_yaws.append(Face_yaws[3 * i + 2])
        self.polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
        self.azimuths_rad = [np.deg2rad(a % 360) for a in azimuths_deg]
        self.azimuths_rad[:-1].sort()
        x = [0.4 * np.sin(2 * 3.14 * i / 120) for i in range(60)]
        y = [-0.05 + 0.3 * np.cos(2 * 3.14 * i / 120) for i in range(60)]
        self.x = [x[0]]
        self.y = [y[0]]
        for i in range(20):
            self.x.append(x[i * 3 + 2])
            self.y.append(y[i * 3 + 2])

    def read_img(self, path):
        img = cv2.imread(path)
        img = cv2.resize(img, (512, 512))
        img = (img / 255.0) * 2 - 1
        img = torch.tensor(img).permute(2, 0, 1)
        return img

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        random_number1 = random.randrange(0, 21 - 6)
        random_number2 = random.randrange(0, 21)

        name = self.lists[index]
        x_stack = []
        y_stack = []
        img_non_hair_stack = []
        img_hair_stack = []
        non_hair_path = os.path.join(self.non_hair_path, name, str(random_number2) + '.jpg')
        ref_folder = os.path.join(self.ref_path, name)

        files = [f for f in os.listdir(ref_folder) if f.endswith('.jpg')]
        ref_path = os.path.join(ref_folder, files[0])
        ref_hair = self.read_img(ref_path)
        for i in range(6):
            img_non_hair_stack.append(self.read_img(non_hair_path).unsqueeze(0))
            hair_path = os.path.join(self.img_path, name, str(random_number1 + i) + '.jpg')
            img_hair_stack.append(self.read_img(hair_path).unsqueeze(0))
            x_stack.append(torch.tensor(self.x[random_number1 + i]).unsqueeze(0))
            y_stack.append(torch.tensor(self.y[random_number1 + i]).unsqueeze(0))

        img_non_hair = torch.cat(img_non_hair_stack, axis=0)
        img_hair = torch.cat(img_hair_stack, axis=0)
        x = torch.cat(x_stack, axis=0)
        y = torch.cat(y_stack, axis=0)

        return {
            'img_hair': img_hair,
            'img_non_hair': img_non_hair,
            'ref_hair': ref_hair,
            'x': x,
            'y': y,
        }

    def __len__(self):
        return self.len

if __name__ == "__main__":

    train_dataset = myDataset("./data")
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        num_workers=1,
    )

    for epoch in range(0, len(train_dataset) + 1):
        for step, batch in enumerate(train_dataloader):
            print("batch[hair_pose]:", batch["hair_pose"])
            print("batch[img_hair]:", batch["img_hair"])
            print("batch[bald_pose]:", batch["bald_pose"])
            print("batch[img_non_hair]:", batch["img_non_hair"])
            print("batch[ref_hair]:", batch["ref_hair"])
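The smoke test above is written for `myDataset`; the dataset variants defined in this file return slightly different key sets (some include `pose`, some do not), so a key-agnostic check is more convenient when switching between them. A minimal sketch, with the class name freely swappable for any of the variants above and assuming the same `./data` layout (including `multi_reference2` for the random-reference variants):

```
import torch

# Key-agnostic smoke test: print the shape of whatever __getitem__ returns.
dataset = myDataset_sv3d_simple_temporal_random_reference("./data")
loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0)

batch = next(iter(loader))
for key, value in batch.items():
    info = tuple(value.shape) if torch.is_tensor(value) else type(value)
    print(f"{key}: {info}")
```
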
download.py
ADDED
@@ -0,0 +1,8 @@
from huggingface_hub import snapshot_download
snapshot_download(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    local_dir="stable-diffusion-v1-5/stable-diffusion-v1-5",
    local_dir_use_symlinks=False,
    resume_download=True,
    allow_patterns=["unet/*", "vae/*", "tokenizer/*", "scheduler/*", "model_index.json"]
)
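download.py only mirrors the Stable Diffusion v1.5 components that the inference script loads from disk. If the runtime environment has no Hub access, the CLIP image encoder passed via `--image_encoder` can be mirrored the same way; this is an optional sketch rather than part of the uploaded file, and the `local_dir` value is an assumption:

```
from huggingface_hub import snapshot_download

# Optional: also cache the CLIP image encoder used as --image_encoder.
snapshot_download(
    "openai/clip-vit-large-patch14",
    local_dir="openai/clip-vit-large-patch14",  # assumed local mirror path
    local_dir_use_symlinks=False,
    resume_download=True,
)
```
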
gradio_app.py
ADDED
@@ -0,0 +1,379 @@
import os
os.environ.setdefault("GRADIO_TEMP_DIR", "/data2/lzliu/tmp/gradio")
os.environ.setdefault("TMPDIR", "/data2/lzliu/tmp")
os.makedirs("/data2/lzliu/tmp/gradio", exist_ok=True)
os.makedirs("/data2/lzliu/tmp", exist_ok=True)


# Everything below stays unchanged.


import logging
import gradio as gr
import torch
import os
import uuid
from test_stablehairv2 import log_validation
from test_stablehairv2 import UNet3DConditionModel, ControlNetModel, CCProjection
from test_stablehairv2 import AutoTokenizer, CLIPVisionModelWithProjection, AutoencoderKL, UNet2DConditionModel
from omegaconf import OmegaConf
import numpy as np
import cv2
from test_stablehairv2 import _maybe_align_image
from HairMapper.hair_mapper_run import bald_head

import base64

with open("imgs/background.jpg", "rb") as f:
    b64_img = base64.b64encode(f.read()).decode()


def inference(id_image, hair_image):
    os.makedirs("gradio_inputs", exist_ok=True)
    os.makedirs("gradio_outputs", exist_ok=True)

    id_path = "gradio_inputs/id.png"
    hair_path = "gradio_inputs/hair.png"
    id_image.save(id_path)
    hair_image.save(hair_path)

    # ===== Image alignment =====
    aligned_id = _maybe_align_image(id_path, output_size=1024, prefer_cuda=True)
    aligned_hair = _maybe_align_image(hair_path, output_size=1024, prefer_cuda=True)

    # Save the aligned results (so Gradio can display them)
    aligned_id_path = "gradio_outputs/aligned_id.png"
    aligned_hair_path = "gradio_outputs/aligned_hair.png"
    cv2.imwrite(aligned_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
    cv2.imwrite(aligned_hair_path, cv2.cvtColor(aligned_hair, cv2.COLOR_RGB2BGR))

    # ===== Run HairMapper to produce the bald identity image =====
    bald_id_path = "gradio_outputs/bald_id.png"
    cv2.imwrite(bald_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
    bald_head(bald_id_path, bald_id_path)

    # ===== Original Args =====
    class Args:
        pretrained_model_name_or_path = "./stable-diffusion-v1-5/stable-diffusion-v1-5"
        model_path = "./trained_model"
        image_encoder = "openai/clip-vit-large-patch14"
        controlnet_model_name_or_path = None
        revision = None
        output_dir = "gradio_outputs"
        seed = 42
        num_validation_images = 1
        validation_ids = [aligned_id_path]  # use the aligned image
        validation_hairs = [aligned_hair_path]  # use the aligned image
        use_fp16 = False

    args = Args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize the logger
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger(__name__)

    # ===== Model loading (kept in sync with main()) =====
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer",
                                              revision=args.revision)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder, revision=args.revision).to(device)
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision).to(
        device, dtype=torch.float32)

    infer_config = OmegaConf.load('./configs/inference/inference_v2.yaml')

    unet2 = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, torch_dtype=torch.float32
    ).to(device)
    conv_in_8 = torch.nn.Conv2d(8, unet2.conv_in.out_channels, kernel_size=unet2.conv_in.kernel_size,
                                padding=unet2.conv_in.padding)
    conv_in_8.requires_grad_(False)
    unet2.conv_in.requires_grad_(False)
    torch.nn.init.zeros_(conv_in_8.weight)
    conv_in_8.weight[:, :4, :, :].copy_(unet2.conv_in.weight)
    conv_in_8.bias.copy_(unet2.conv_in.bias)
    unet2.conv_in = conv_in_8

    controlnet = ControlNetModel.from_unet(unet2).to(device)
    state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model.bin"), map_location="cpu")
    controlnet.load_state_dict(state_dict2, strict=False)

    prefix = "motion_module"
    ckpt_num = "4140000"
    save_path = os.path.join(args.model_path, f"{prefix}-{ckpt_num}.pth")
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        args.pretrained_model_name_or_path,
        save_path,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    ).to(device)

    cc_projection = CCProjection().to(device)
    state_dict3 = torch.load(os.path.join(args.model_path, "pytorch_model_1.bin"), map_location="cpu")
    cc_projection.load_state_dict(state_dict3, strict=False)

    from ref_encoder.reference_unet import ref_unet
    Hair_Encoder = ref_unet.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False,
        device_map=None, ignore_mismatched_sizes=True
    ).to(device)
    state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"), map_location="cpu")
    Hair_Encoder.load_state_dict(state_dict2, strict=False)

    # Inference
    log_validation(
        vae, tokenizer, image_encoder, denoising_unet,
        args, device, logger,
        cc_projection, controlnet, Hair_Encoder
    )

    output_video = os.path.join(args.output_dir, "validation", "generated_video_0.mp4")

    # Extract video frames for a scrubbable preview
    frames_dir = os.path.join(args.output_dir, "frames", uuid.uuid4().hex)
    os.makedirs(frames_dir, exist_ok=True)
    cap = cv2.VideoCapture(output_video)
    frames_list = []
    idx = 0
    while True:
        ret, frame = cap.read()
145 |
+
if not ret:
|
146 |
+
break
|
147 |
+
fp = os.path.join(frames_dir, f"{idx:03d}.png")
|
148 |
+
cv2.imwrite(fp, frame)
|
149 |
+
frames_list.append(fp)
|
150 |
+
idx += 1
|
151 |
+
cap.release()
|
152 |
+
|
153 |
+
max_frames = len(frames_list) if frames_list else 1
|
154 |
+
first_frame = frames_list[0] if frames_list else None
|
155 |
+
|
156 |
+
return aligned_id_path, aligned_hair_path, bald_id_path, output_video, frames_list, gr.update(minimum=1,
|
157 |
+
maximum=max_frames,
|
158 |
+
value=1,
|
159 |
+
step=1), first_frame
|
160 |
+
|
161 |
+
|
# Gradio front end
# Original gr.Interface version (kept for easy rollback)
# demo = gr.Interface(
#     fn=inference,
#     inputs=[
#         gr.Image(type="pil", label="上传身份图(ID Image)"),
#         gr.Image(type="pil", label="上传发型图(Hair Reference Image)")
#     ],
#     outputs=[
#         gr.Image(type="filepath", label="对齐后的身份图"),
#         gr.Image(type="filepath", label="对齐后的发型图"),
#         gr.Image(type="filepath", label="秃头化后的身份图"),
#         gr.Video(label="生成的视频")
#     ],
#     title="StableHairV2 多视角发型迁移",
#     description="上传身份图和发型参考图,查看对齐结果并生成多视角视频"
# )
# if __name__ == "__main__":
#     demo.launch(server_name="0.0.0.0", server_port=7860)

# Polished Blocks version
css = f"""
html, body {{
    height: 100%;
    margin: 0;
    padding: 0;
}}
.gradio-container {{
    width: 100% !important;
    height: 100% !important;
    margin: 0 !important;
    padding: 0 !important;
    background-image: url("data:image/jpeg;base64,{b64_img}");
    background-size: cover;
    background-position: center;
    background-attachment: fixed; /* keep the background fixed */
}}
#title-card {{
    background: rgba(255, 255, 255, 0.8);
    border-radius: 12px;
    padding: 16px 24px;
    box-shadow: 0 2px 8px rgba(0,0,0,0.15);
    margin-bottom: 20px;
}}
#title-card h2 {{
    text-align: center;
    margin: 4px 0 12px 0;
    font-size: 28px;
}}
#title-card p {{
    text-align: center;
    font-size: 16px;
    color: #374151;
}}
.out-card {{
    border:1px solid #e5e7eb; border-radius:10px; padding:10px;
    background: rgba(255,255,255,0.85);
}}
.two-col {{
    display:grid !important; grid-template-columns: 360px minmax(680px, 1fr); gap:16px
}}
.left-pane {{min-width: 360px}}
.right-pane {{min-width: 680px}}
/* Tab styling */
.tabs {{
    background: rgba(255,255,255,0.88);
    border-radius: 12px;
    box-shadow: 0 8px 24px rgba(0,0,0,0.08);
    padding: 8px;
    border: 1px solid #e5e7eb;
}}
.tab-nav {{
    display: flex; gap: 8px; margin-bottom: 8px;
    background: transparent;
    border-bottom: 1px solid #e5e7eb;
    padding-bottom: 6px;
}}
.tab-nav button {{
    background: rgba(255,255,255,0.7);
    border: 1px solid #e5e7eb;
    backdrop-filter: blur(6px);
    border-radius: 8px;
    padding: 6px 12px;
    color: #111827;
    transition: all .2s ease;
}}
.tab-nav button:hover {{
    transform: translateY(-1px);
    box-shadow: 0 4px 10px rgba(0,0,0,0.06);
}}
.tab-nav button[aria-selected="true"] {{
    background: #4f46e5;
    color: #fff;
    border-color: #4f46e5;
    box-shadow: 0 6px 14px rgba(79,70,229,0.25);
}}
.tabitem {{
    background: rgba(255,255,255,0.88);
    border-radius: 10px;
    padding: 8px;
}}
/* Hair-gallery scroll container: fixed 260px height, scrollable inside */
#hair_gallery_wrap {{
    height: 260px !important;
    overflow-y: scroll !important;
    overflow-x: auto !important;
}}
#hair_gallery_wrap .grid, #hair_gallery_wrap .wrap {{
    height: 100% !important;
    overflow-y: scroll !important;
}}
/* Make the gallery fill the container height so its scrollbar stays inside it */
#hair_gallery {{
    height: 100% !important;
}}
"""

with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
    css=css
) as demo:
    # ==== Top panel ====
    with gr.Group(elem_id="title-card"):
        gr.Markdown("""
<h2 id='title'>StableHairV2 多视角发型迁移</h2>
<p>上传身份图与发型参考图,系统将自动完成 <b>对齐 → 秃头化 → 视频生成</b>。</p>
""")

    with gr.Row(elem_classes=["two-col"]):
        with gr.Column(scale=5, min_width=260, elem_classes=["left-pane"]):
            id_input = gr.Image(type="pil", label="身份图", height=200)
            hair_input = gr.Image(type="pil", label="发型参考图", height=200)

            with gr.Row():
                run_btn = gr.Button("开始生成", variant="primary")
                clear_btn = gr.Button("清空")

            # ========= Hair gallery (clicking an item fills the hair reference input) =========
            def _list_imgs(dir_path: str):
                exts = (".png", ".jpg", ".jpeg", ".webp")
                # exts = (".jpg")
                try:
                    files = [os.path.join(dir_path, f) for f in sorted(os.listdir(dir_path))
                             if f.lower().endswith(exts)]
                    return files
                except Exception:
                    return []

            hair_list = _list_imgs("hair_resposity")

            with gr.Accordion("发型库(点击选择后自动填充)", open=True):
                with gr.Group(elem_id="hair_gallery_wrap"):
                    gallery = gr.Gallery(
                        value=hair_list,
                        columns=4, rows=2, allow_preview=True, label="发型库",
                        elem_id="hair_gallery"
                    )

            def _pick_hair(evt: gr.SelectData):  # type: ignore[name-defined]
                i = evt.index if hasattr(evt, 'index') else 0
                i = 0 if i is None else int(i)
                if 0 <= i < len(hair_list):
                    return gr.update(value=hair_list[i])
                return gr.update()

            gallery.select(_pick_hair, inputs=None, outputs=hair_input)

        with gr.Column(scale=7, min_width=520, elem_classes=["right-pane"]):
            with gr.Tabs():
                with gr.TabItem("生成视频"):
                    with gr.Group(elem_classes=["out-card"]):
                        video_out = gr.Video(label="生成的视频", height=340)
                        with gr.Row():
                            frame_slider = gr.Slider(1, 21, value=1, step=1, label="多视角预览(拖动查看帧)")
                            frame_preview = gr.Image(type="filepath", label="预览帧", height=260)
                        frames_state = gr.State([])

                with gr.TabItem("归一化对齐结果"):
                    with gr.Group(elem_classes=["out-card"]):
                        with gr.Row():
                            aligned_id_out = gr.Image(type="filepath", label="对齐后的身份图", height=240)
                            aligned_hair_out = gr.Image(type="filepath", label="对齐后的发型图", height=240)

                with gr.TabItem("秃头化结果"):
                    with gr.Group(elem_classes=["out-card"]):
                        bald_id_out = gr.Image(type="filepath", label="秃头化后的身份图", height=260)

    # The wiring logic is unchanged
    run_btn.click(fn=inference,
                  inputs=[id_input, hair_input],
                  outputs=[aligned_id_out, aligned_hair_out, bald_id_out,
                           video_out, frames_state, frame_slider, frame_preview])

    def _on_slide(frames, idx):
        if not frames:
            return gr.update()
        i = int(idx) - 1
        i = max(0, min(i, len(frames) - 1))
        return gr.update(value=frames[i])

    frame_slider.change(_on_slide, inputs=[frames_state, frame_slider], outputs=frame_preview)

    def _clear():
        return None, None, None, None, None

    clear_btn.click(_clear, None,
                    [id_input, hair_input, aligned_id_out, aligned_hair_out, bald_id_out])


if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)
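Because the Blocks UI above only wires `inference()` to the widgets, the same entry point can also be driven headlessly. The snippet below is a minimal sketch, not part of the uploaded files, and assumes the checkpoints referenced by the hard-coded `Args` (`./stable-diffusion-v1-5/stable-diffusion-v1-5`, `./trained_model`) plus the sample images are present in the working directory; the image paths are illustrative.

```python
# Hypothetical headless driver for gradio_app.inference(); paths are illustrative.
from PIL import Image
from gradio_app import inference

id_img = Image.open("test_imgs/bald.jpg").convert("RGB")      # identity image
hair_img = Image.open("test_imgs/ref1.jpg").convert("RGB")    # hair reference

(aligned_id, aligned_hair, bald_id,
 video_path, frame_paths, slider_update, first_frame) = inference(id_img, hair_img)

print("video:", video_path)              # gradio_outputs/validation/generated_video_0.mp4
print("preview frames:", len(frame_paths))
```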
requirements.txt
ADDED
@@ -0,0 +1,67 @@
accelerate
albucore==0.0.24
albumentations==2.0.8
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
av==14.4.0
Brotli
certifi
charset-normalizer
click==8.2.1
colorama
einops==0.8.1
filelock
fsspec
gitdb==4.0.12
GitPython==3.1.44
gmpy2
graphviz==0.20.3
hf-xet==1.1.5
huggingface-hub==0.30.0
idna
Jinja2
kornia
MarkupSafe
mkl-service==2.4.0
mkl_fft
mkl_random
mpmath
networkx
numpy
omegaconf==2.3.0
opencv-python==4.11.0.86
opencv-python-headless==4.11.0.86
packaging
peft==0.15.2
pillow
platformdirs==4.3.8
prodigyopt==1.1.2
protobuf==6.31.1
psutil
pydantic==2.11.7
pydantic_core==2.33.2
PySocks
PyYAML
regex==2024.11.6
requests
safetensors
scipy==1.15.3
sentencepiece==0.2.0
sentry-sdk==2.32.0
setproctitle==1.3.6
simsimd==6.4.9
smmap==5.0.2
stringzilla==3.12.5
sympy==1.13.1
tokenizers==0.21.1
torch==2.5.0
torchaudio==2.5.0
torchvision==0.20.0
torchviz==0.0.3
tqdm
transformers==4.52.3
triton==3.1.0
typing-inspection==0.4.1
typing_extensions
urllib3
wandb==0.20.1
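The pins above cover the core stack (torch 2.5.0, transformers 4.52.3, etc.), but `gradio` and `diffusers`, which gradio_app.py and test_stablehairv2.py import, are not listed here and are presumably provided by the Space runtime or installed separately. A small, illustrative environment check (not part of the uploaded files) before running either script:

```python
# Illustrative sanity check; the package names follow the imports used in the scripts above.
import importlib
import torch

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
for pkg in ("diffusers", "gradio", "transformers", "omegaconf", "cv2"):
    try:
        mod = importlib.import_module(pkg)
        print(pkg, getattr(mod, "__version__", "unknown"))
    except ImportError:
        print(pkg, "NOT INSTALLED")
```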
test_stablehairv2.py
ADDED
@@ -0,0 +1,320 @@
#!/usr/bin/env python3
import argparse
import logging
import sys
import os
import random
import numpy as np
import cv2
import torch
from PIL import Image
from transformers import AutoTokenizer, CLIPVisionModelWithProjection
from diffusers import AutoencoderKL, UniPCMultistepScheduler, UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from ref_encoder.reference_unet import CCProjection
from ref_encoder.latent_controlnet import ControlNetModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline as Hair3dPipeline
from src.utils.util import save_videos_grid
from omegaconf import OmegaConf
from HairMapper.hair_mapper_run import bald_head


# Face alignment helper
def _maybe_align_image(image_path: str, output_size: int, prefer_cuda: bool = True):
    """Align and crop a face image to FFHQ style using FFHQFaceAlignment if available.
    Falls back to a simple resize if alignment fails.
    Returns an RGB uint8 numpy array of shape (H, W, 3).
    """
    try:
        ffhq_dir = os.path.join(os.path.dirname(__file__), 'FFHQFaceAlignment')
        if ffhq_dir not in sys.path:
            sys.path.insert(0, ffhq_dir)
        # Lazy imports to avoid a hard dependency if the user doesn't enable alignment
        from lib.landmarks_pytorch import LandmarksEstimation
        from align import align_crop_image

        # Read the image as RGB uint8
        img_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if img_bgr is None:
            raise RuntimeError(f"Failed to read image: {image_path}")
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype('uint8')

        device = torch.device('cuda' if prefer_cuda and torch.cuda.is_available() else 'cpu')
        le = LandmarksEstimation(type='2D')

        img_tensor = torch.tensor(np.transpose(img, (2, 0, 1))).float().to(device)
        with torch.no_grad():
            landmarks, _ = le.detect_landmarks(img_tensor.unsqueeze(0), detected_faces=None)
        if len(landmarks) > 0:
            lm = np.asarray(landmarks[0].detach().cpu().numpy())
            aligned = align_crop_image(image=img, landmarks=lm, transform_size=output_size)
            if aligned is None or aligned.size == 0:
                return cv2.resize(img, (output_size, output_size))
            return aligned
        else:
            return cv2.resize(img, (output_size, output_size))
    except Exception:
        # Silent fallback to a simple resize on any failure
        img_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype('uint8') if img_bgr is not None else None
        if img is None:
            raise
        return cv2.resize(img, (output_size, output_size))


def log_validation(
    vae, tokenizer, image_encoder, denoising_unet,
    args, device, logger, cc_projection,
    controlnet, hair_encoder, feature_extractor=None
):
    """
    Run inference on validation pairs and save generated videos.
    """
    logger.info("Starting validation inference...")

    # Initialize the inference pipeline
    pipeline = Hair3dPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        image_encoder=image_encoder,
        feature_extractor=feature_extractor,
        controlnet=controlnet,
        vae=vae,
        tokenizer=tokenizer,
        denoising_unet=denoising_unet,
        safety_checker=None,
        revision=args.revision,
        torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
    ).to(device)
    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline.set_progress_bar_config(disable=True)

    # Create the output directory
    output_dir = os.path.join(args.output_dir, "validation")
    os.makedirs(output_dir, exist_ok=True)

    print(output_dir)

    # Generate the camera trajectory: 21 poses sampled from a sinusoidal orbit
    x_coords = [0.4 * np.sin(2 * np.pi * i / 120) for i in range(60)]
    y_coords = [-0.05 + 0.3 * np.cos(2 * np.pi * i / 120) for i in range(60)]
    X = [x_coords[0]]
    Y = [y_coords[0]]
    for i in range(20):
        X.append(x_coords[i * 3 + 2])
        Y.append(y_coords[i * 3 + 2])
    x_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(1).to(device)
    y_tensor = torch.tensor(Y, dtype=torch.float32).unsqueeze(1).to(device)

    # # Load reference images
    # id_image = cv2.cvtColor(cv2.imread(args.validation_ids[0]), cv2.COLOR_BGR2RGB)
    # id_image = cv2.resize(id_image, (512, 512))
    # Load reference images (optionally align)
    align_enabled = getattr(args, 'align_before_infer', True)
    align_size = getattr(args, 'align_size', 1024)
    prefer_cuda = True if device.type == 'cuda' else False
    if align_enabled:
        id_image = _maybe_align_image(args.validation_ids[0], output_size=align_size, prefer_cuda=prefer_cuda)
    else:
        id_image = cv2.cvtColor(cv2.imread(args.validation_ids[0]), cv2.COLOR_BGR2RGB)
    id_image = cv2.resize(id_image, (512, 512))

    # ===== Run HairMapper to baldify the identity image =====
    temp_bald_path = os.path.join(args.output_dir, "bald_id.png")
    cv2.imwrite(temp_bald_path, cv2.cvtColor(id_image, cv2.COLOR_RGB2BGR))  # save the identity image
    bald_head(temp_bald_path, temp_bald_path)  # baldify and overwrite in place
    # Reload the baldified image (RGB)
    id_image = cv2.cvtColor(cv2.imread(temp_bald_path), cv2.COLOR_BGR2RGB)
    id_image = cv2.resize(id_image, (512, 512))

    id_list = [id_image for _ in range(12)]
    if align_enabled:
        hair_image = _maybe_align_image(args.validation_hairs[0], output_size=align_size, prefer_cuda=prefer_cuda)
        prompt_img = _maybe_align_image(args.validation_ids[0], output_size=align_size, prefer_cuda=prefer_cuda)
    else:
        hair_image = cv2.cvtColor(cv2.imread(args.validation_hairs[0]), cv2.COLOR_BGR2RGB)
        hair_image = cv2.resize(hair_image, (512, 512))
        prompt_img = cv2.cvtColor(cv2.imread(args.validation_ids[0]), cv2.COLOR_BGR2RGB)
        prompt_img = cv2.resize(prompt_img, (512, 512))
    hair_image = cv2.resize(hair_image, (512, 512))
    prompt_img = cv2.resize(prompt_img, (512, 512))

    prompt_img = [prompt_img]

    # Perform inference and save the videos
    for idx in range(args.num_validation_images):
        result = pipeline(
            prompt="",
            negative_prompt="",
            num_inference_steps=30,
            guidance_scale=1.5,
            width=512,
            height=512,
            controlnet_condition=id_list,
            controlnet_conditioning_scale=1.0,
            generator=torch.Generator(device).manual_seed(args.seed),
            ref_image=hair_image,
            prompt_img=prompt_img,
            reference_encoder=hair_encoder,
            poses=None,
            x=x_tensor,
            y=y_tensor,
            video_length=21,
            context_frames=12,
        )
        video = torch.cat([result.videos, result.videos], dim=0)
        video_path = os.path.join(output_dir, f"generated_video_{idx}.mp4")
        save_videos_grid(video, video_path, n_rows=5, fps=24)
        logger.info(f"Saved generated video: {video_path}")

def parse_args():
    parser = argparse.ArgumentParser(
        description="Inference script for 3D hairstyle generation"
    )
    parser.add_argument(
        "--pretrained_model_name_or_path", type=str, required=True,
        help="Path or ID of the pretrained pipeline"
    )
    parser.add_argument(
        "--model_path", type=str, required=True,
        help="Directory containing the trained StableHair v2 checkpoints"
    )
    parser.add_argument(
        "--image_encoder", type=str, required=True,
        help="Path or ID of the CLIP vision encoder"
    )
    parser.add_argument(
        "--controlnet_model_name_or_path", type=str, default=None,
        help="Path or ID of the ControlNet model"
    )
    parser.add_argument(
        "--revision", type=str, default=None,
        help="Model revision or Git reference"
    )
    parser.add_argument(
        "--output_dir", type=str, default="inference_output",
        help="Directory to save inference results"
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed for reproducibility"
    )
    parser.add_argument(
        "--num_validation_images", type=int, default=3,
        help="Number of videos to generate per input pair"
    )
    parser.add_argument(
        "--validation_ids", type=str, nargs='+', required=True,
        help="Path(s) to identity conditioning images"
    )
    parser.add_argument(
        "--validation_hairs", type=str, nargs='+', required=True,
        help="Path(s) to hairstyle reference images"
    )
    parser.add_argument(
        "--use_fp16", action="store_true",
        help="Enable fp16 inference"
    )
    parser.add_argument(
        "--align_before_infer", action="store_true", default=True,
        help="Align and crop input images to FFHQ style before inference"
    )
    parser.add_argument(
        "--align_size", type=int, default=1024,
        help="Output size for aligned images when alignment is enabled"
    )
    return parser.parse_args()


def main():
    args = parse_args()
    # Set up the device and logger
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger(__name__)

    # Set the random seed
    torch.manual_seed(args.seed)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed)

    # Load models
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision
    )
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        args.image_encoder,
        revision=args.revision
    ).to(device)
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision
    ).to(device)

    infer_config = OmegaConf.load('./configs/inference/inference_v2.yaml')

    unet2 = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", use_safetensors=True, revision=args.revision,
        torch_dtype=torch.float16
    ).to(device)
    # Inflate conv_in from 4 to 8 input channels: keep the pretrained weights on
    # the first 4 channels and zero-initialize the new ones.
    conv_in_8 = torch.nn.Conv2d(8, unet2.conv_in.out_channels, kernel_size=unet2.conv_in.kernel_size,
                                padding=unet2.conv_in.padding)
    conv_in_8.requires_grad_(False)
    unet2.conv_in.requires_grad_(False)
    torch.nn.init.zeros_(conv_in_8.weight)
    conv_in_8.weight[:, :4, :, :].copy_(unet2.conv_in.weight)
    conv_in_8.bias.copy_(unet2.conv_in.bias)
    unet2.conv_in = conv_in_8

    # Load or initialize ControlNet
    controlnet = ControlNetModel.from_unet(unet2).to(device)
    # state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model.bin"), map_location=torch.device('cpu'))
    # state_dict2 = torch.load(args.model_path, map_location=torch.device('cpu'))
    state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model.bin"), map_location=torch.device('cpu'))
    controlnet.load_state_dict(state_dict2, strict=False)

    # Load 3D UNet motion module
    prefix = "motion_module"
    ckpt_num = "4140000"
    save_path = os.path.join(args.model_path, f"{prefix}-{ckpt_num}.pth")

    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        args.pretrained_model_name_or_path,
        save_path,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    ).to(device)

    # Load projection and hair encoder
    cc_projection = CCProjection().to(device)
    state_dict3 = torch.load(os.path.join(args.model_path, "pytorch_model_1.bin"), map_location=torch.device('cpu'))
    cc_projection.load_state_dict(state_dict3, strict=False)

    from ref_encoder.reference_unet import ref_unet
    Hair_Encoder = ref_unet.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False,
        device_map=None, ignore_mismatched_sizes=True
    ).to(device)

    state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"), map_location=torch.device('cpu'))
    # state_dict2 = torch.load(os.path.join('/home/jichao.zhang/code/3dhair/train_sv3d/checkpoint-30000/', "pytorch_model.bin"))
    Hair_Encoder.load_state_dict(state_dict2, strict=False)

    # Run validation inference
    log_validation(
        vae, tokenizer, image_encoder, denoising_unet,
        args, device, logger,
        cc_projection, controlnet, Hair_Encoder
    )


if __name__ == "__main__":
    main()
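Both gradio_app.py and test_stablehairv2.py inflate the Stable Diffusion UNet's 4-channel `conv_in` to 8 input channels before deriving the ControlNet from it: the pretrained weights are copied into the first 4 channels and the new channels are zero-initialized, so the extra latent condition contributes nothing until trained weights are loaded. The block below is a standalone sketch of just that step on a plain `torch.nn.Conv2d`; the 320-channel output stands in for SD 1.5's `conv_in` and the tensor shapes are illustrative.

```python
# Minimal sketch of the conv_in "inflation" used in both scripts.
import torch

old_conv = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)   # stands in for unet.conv_in
new_conv = torch.nn.Conv2d(8, old_conv.out_channels,
                           kernel_size=old_conv.kernel_size,
                           padding=old_conv.padding)

with torch.no_grad():
    # Zero-init everything, then copy the pretrained weights into channels 0..3,
    # leaving the 4 new conditioning channels as an initial no-op.
    torch.nn.init.zeros_(new_conv.weight)
    new_conv.weight[:, :4, :, :].copy_(old_conv.weight)
    new_conv.bias.copy_(old_conv.bias)

x = torch.randn(1, 8, 64, 64)
print(new_conv(x).shape)   # torch.Size([1, 320, 64, 64])
```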