## Training with Paired Data (pix2pix-turbo)
This guide shows how to train a pix2pix-turbo model on paired data.
As an example, we use the [Fill50k dataset](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md) from [ControlNet](https://github.com/lllyasviel/ControlNet).


### Step 1. Get the Dataset
- First download a modified Fill50k dataset from [here](https://www.cs.cmu.edu/~img2img-turbo/data/my_fill50k.zip) using the command below.
    ```
    bash scripts/download_fill50k.sh
    ```

- Our training scripts expect the dataset to be in the following format:
    ```
    data
    ├── dataset_name
    │   ├── train_A
    │   │   ├── 000000.png
    │   │   ├── 000001.png
    │   │   └── ...
    │   ├── train_B
    │   │   ├── 000000.png
    │   │   ├── 000001.png
    │   │   └── ...
    │   ├── train_prompts.json
    │   │
    │   ├── test_A
    │   │   ├── 000000.png
    │   │   ├── 000001.png
    │   │   └── ...
    │   ├── test_B
    │   │   ├── 000000.png
    │   │   ├── 000001.png
    │   │   └── ...
    │   └── test_prompts.json
    ```
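
Before launching a long training run, it can help to confirm that a dataset folder matches the layout above. The following is a minimal sketch, not part of the repository: `check_paired_split` is a hypothetical helper, and it assumes that each `*_prompts.json` file is a JSON object mapping image file names to prompt strings, which this guide does not state explicitly.

```python
import json
from pathlib import Path


def check_paired_split(root, split):
    """Return a list of problems found in one split ("train" or "test")."""
    root = Path(root)
    dir_a = root / f"{split}_A"
    dir_b = root / f"{split}_B"
    prompts_file = root / f"{split}_prompts.json"

    # Report missing folders or files first; nothing else can be checked.
    problems = [f"missing: {p.name}" for p in (dir_a, dir_b, prompts_file)
                if not p.exists()]
    if problems:
        return problems

    names_a = {p.name for p in dir_a.glob("*.png")}
    names_b = {p.name for p in dir_b.glob("*.png")}
    # Assumption: the JSON file is an object of {image file name: prompt}.
    prompts = json.loads(prompts_file.read_text())

    # Every image in A needs a partner with the same name in B, and vice versa.
    for name in sorted(names_a ^ names_b):
        problems.append(f"unpaired image: {name}")
    # Every complete pair needs a prompt.
    for name in sorted((names_a & names_b) - set(prompts)):
        problems.append(f"no prompt for: {name}")
    return problems
```

For example, `check_paired_split("data/my_fill50k", "train")` returns an empty list when the training split is complete under these assumptions.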


### Step 2. Train the Model
- Initialize the `accelerate` environment with the following command:
    ```
    accelerate config
    ```

- Run the following command to train the model:
    ```
    accelerate launch src/train_pix2pix_turbo.py \
        --pretrained_model_name_or_path="stabilityai/sd-turbo" \
        --output_dir="output/pix2pix_turbo/fill50k" \
        --dataset_folder="data/my_fill50k" \
        --resolution=512 \
        --train_batch_size=2 \
        --enable_xformers_memory_efficient_attention --viz_freq 25 \
        --track_val_fid \
        --report_to "wandb" --tracker_project_name "pix2pix_turbo_fill50k"
    ```

- Additional optional flags:
    - `--track_val_fid`: Track FID score on the validation set using the [Clean-FID](https://github.com/GaParmar/clean-fid) implementation.
    - `--enable_xformers_memory_efficient_attention`: Enable memory-efficient attention in the model.
    - `--viz_freq`: Frequency of visualizing the results during training.

### Step 3. Monitor the Training Progress
- You can monitor the training progress using the [Weights & Biases](https://wandb.ai/site) dashboard.

- The training script visualizes the training batch, the training losses, and the validation-set L2, LPIPS, and FID scores (FID only if `--track_val_fid` is set).
    <div>
        <p align="center">
        <img src='../assets/examples/training_evaluation.png' align="center" width=800px>
        </p>
    </div>


- The model checkpoints will be saved in the `<output_dir>/checkpoints` directory.

- Screenshots of the training progress are shown below:
    - Step 0:
    <div>
        <p align="center">
        <img src='../assets/examples/training_step_0.png' align="center" width=800px>
        </p>
    </div>
    
    - Step 500:
    <div>
        <p align="center">
        <img src='../assets/examples/training_step_500.png' align="center" width=800px>
        </p>
    </div>

    - Step 6000: 
    <div>
        <p align="center">
        <img src='../assets/examples/training_step_6000.png' align="center" width=800px>
        </p>
    </div>
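
To pick up the most recent checkpoint from `<output_dir>/checkpoints`, a small helper along these lines may be useful. It is a sketch under one assumption: that checkpoint files are named `model_<step>.pkl`, as suggested by the `model_6001.pkl` path used for inference in this guide. `latest_checkpoint` is a hypothetical name, not a function provided by the repository.

```python
import re
from pathlib import Path


def latest_checkpoint(output_dir):
    """Return the model_<step>.pkl path with the highest step, or None."""
    ckpt_dir = Path(output_dir) / "checkpoints"
    best_step, best_path = -1, None
    for path in ckpt_dir.glob("model_*.pkl"):
        # Compare steps numerically so that 6001 ranks above 501.
        match = re.fullmatch(r"model_(\d+)\.pkl", path.name)
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path
```

Sorting by the parsed step number rather than by file name avoids the lexicographic trap where `model_501.pkl` would sort after `model_1001.pkl`.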


### Step 4. Run Inference with the Trained Model

- Run inference with the trained model using the following command:
    ```
    python src/inference_paired.py --model_path "output/pix2pix_turbo/fill50k/checkpoints/model_6001.pkl" \
        --input_image "data/my_fill50k/test_A/40000.png" \
        --prompt "violet circle with orange background" \
        --output_dir "outputs"
    ```

- The above command should generate the following output:
    <table>
    <tr>
    <th>Model Input</th>
    <th>Model Output</th>
    </tr>
    <tr>
    <td><img src='../assets/examples/circles_inference_input.png' width="200px"></td>
    <td><img src='../assets/examples/circles_inference_output.png' width="200px"></td>
    </tr>
    </table>
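
To run the same command over several test images, one option is to generate one invocation per image from `test_prompts.json`. The sketch below only builds the argument lists and uses just the four flags shown in the command above. `build_inference_commands` is a hypothetical helper, and the code assumes that `test_prompts.json` is a JSON object mapping image file names in `test_A` to prompts.

```python
import json
import sys
from pathlib import Path


def build_inference_commands(dataset_dir, model_path, output_dir, limit=None):
    """Build one inference_paired.py argument list per test image."""
    dataset_dir = Path(dataset_dir)
    # Assumption: test_prompts.json maps image file names to prompts.
    prompts = json.loads((dataset_dir / "test_prompts.json").read_text())

    commands = []
    # limit=None keeps every image; an integer keeps the first `limit` names.
    for name in sorted(prompts)[:limit]:
        commands.append([
            sys.executable, "src/inference_paired.py",
            "--model_path", str(model_path),
            "--input_image", str(dataset_dir / "test_A" / name),
            "--prompt", prompts[name],
            "--output_dir", str(output_dir),
        ])
    return commands
```

Each returned list can be passed to `subprocess.run(cmd, check=True)` from the repository root. Because every invocation starts a new Python process, the model is reloaded for each image, so this approach suits spot checks better than large evaluations.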