## Training with Paired Data (pix2pix-turbo)
Here, we show how to train a pix2pix-turbo model using paired data.
We will use the [Fill50k dataset](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md) used by [ControlNet](https://github.com/lllyasviel/ControlNet) as an example dataset.
### Step 1. Get the Dataset
- First, download a modified Fill50k dataset from [here](https://www.cs.cmu.edu/~img2img-turbo/data/my_fill50k.zip) using the command below.
```
bash scripts/download_fill50k.sh
```
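If you prefer not to run the shell script, the sketch below is a rough Python equivalent. The extraction target `data/` is an assumption based on the `--dataset_folder="data/my_fill50k"` flag used in Step 2, not something this document confirms about the archive layout.
```
# Hedged sketch: download and unpack the modified Fill50k dataset.
# The "data" extraction target is an assumption (see Step 2's
# --dataset_folder flag); verify the unpacked layout afterwards.
import urllib.request
import zipfile

url = "https://www.cs.cmu.edu/~img2img-turbo/data/my_fill50k.zip"
zip_path, _ = urllib.request.urlretrieve(url, "my_fill50k.zip")
with zipfile.ZipFile(zip_path) as zf:
    zf.extractall("data")
```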
- Our training scripts expect the dataset to be in the following format:
```
data
├── dataset_name
│   ├── train_A
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   ├── train_B
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   ├── train_prompts.json
│   │
│   ├── test_A
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   ├── test_B
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   └── test_prompts.json
```
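Before training, it can be worth sanity-checking that the two splits are actually paired. The sketch below assumes `train_prompts.json` and `test_prompts.json` map each image filename to its caption string; that format is an assumption, not something this document specifies.
```
# Hedged sketch: verify that input/output images are paired and captioned.
import json
import os

root = "data/my_fill50k"
for split in ("train", "test"):
    a = set(os.listdir(os.path.join(root, f"{split}_A")))
    b = set(os.listdir(os.path.join(root, f"{split}_B")))
    assert a == b, f"{split}: input/output filenames do not match"
    with open(os.path.join(root, f"{split}_prompts.json")) as f:
        prompts = json.load(f)  # assumed format: {filename: caption}
    missing = a - set(prompts)
    assert not missing, f"{split}: {len(missing)} images have no prompt"
    print(f"{split}: {len(a)} paired images, all captioned")
```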
### Step 2. Train the Model
- Initialize the `accelerate` environment with the following command:
```
accelerate config
```
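If you want to skip the interactive prompts, `accelerate` also provides a helper that writes a default single-machine config; the `mixed_precision` value below is a suggestion, not a requirement of this repo.
```
# Hedged sketch: write a default accelerate config non-interactively.
from accelerate.utils import write_basic_config

write_basic_config(mixed_precision="fp16")
```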
- Run the following command to train the model:
```
accelerate launch src/train_pix2pix_turbo.py \
--pretrained_model_name_or_path="stabilityai/sd-turbo" \
--output_dir="output/pix2pix_turbo/fill50k" \
--dataset_folder="data/my_fill50k" \
--resolution=512 \
--train_batch_size=2 \
--enable_xformers_memory_efficient_attention --viz_freq 25 \
--track_val_fid \
--report_to "wandb" --tracker_project_name "pix2pix_turbo_fill50k"
```
- Optional flags used in the command above:
    - `--track_val_fid`: Track the FID score on the validation set using the [Clean-FID](https://github.com/GaParmar/clean-fid) implementation (a sketch for computing FID offline follows this list).
- `--enable_xformers_memory_efficient_attention`: Enable memory-efficient attention in the model.
- `--viz_freq`: Frequency of visualizing the results during training.
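The same [Clean-FID](https://github.com/GaParmar/clean-fid) library can also score a folder of generated images after training. A minimal sketch, assuming a placeholder `outputs` directory of generated test images and using the ground-truth `test_B` folder as the reference:
```
# Hedged sketch: offline FID with clean-fid (pip install clean-fid).
# Directory names are placeholders, not outputs this repo guarantees.
from cleanfid import fid

score = fid.compute_fid("outputs", "data/my_fill50k/test_B")
print(f"FID: {score:.2f}")
```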
### Step 3. Monitor the Training Progress
- You can monitor the training progress using the [Weights & Biases](https://wandb.ai/site) dashboard.
- The training script visualizes the training batch and logs the training losses, along with the validation-set L2, LPIPS, and FID scores (the latter only if `--track_val_fid` is specified). A sketch for reproducing the L2 and LPIPS metrics offline appears at the end of this step.
<div>
<p align="center">
<img src='../assets/examples/training_evaluation.png' align="center" width=800px>
</p>
</div>
- The model checkpoints will be saved in the `<output_dir>/checkpoints` directory.
- Screenshots of the training progress are shown below:
- Step 0:
<div>
<p align="center">
<img src='../assets/examples/training_step_0.png' align="center" width=800px>
</p>
</div>
- Step 500:
<div>
<p align="center">
<img src='../assets/examples/training_step_500.png' align="center" width=800px>
</p>
</div>
- Step 6000:
<div>
<p align="center">
<img src='../assets/examples/training_step_6000.png' align="center" width=800px>
</p>
</div>
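The validation L2 and LPIPS numbers can be reproduced offline with the [`lpips`](https://github.com/richzhang/PerceptualSimilarity) package. A minimal sketch, assuming a placeholder `outputs` folder of generated images whose filenames match the ground-truth `test_B` images (FID is covered by the Clean-FID sketch in Step 2):
```
# Hedged sketch: per-image L2 and LPIPS against ground truth.
# "outputs" is a placeholder directory of generated test images.
import os

import lpips
import numpy as np
import torch
from PIL import Image

def to_tensor(path):
    # Load an RGB image and scale it to [-1, 1], as LPIPS expects.
    arr = np.asarray(Image.open(path).convert("RGB"), dtype=np.float32)
    return torch.from_numpy(arr / 127.5 - 1.0).permute(2, 0, 1).unsqueeze(0)

loss_fn = lpips.LPIPS(net="vgg")  # 'alex' is a faster alternative
l2_scores, lpips_scores = [], []
for name in sorted(os.listdir("outputs")):
    pred = to_tensor(os.path.join("outputs", name))
    gt = to_tensor(os.path.join("data/my_fill50k/test_B", name))
    l2_scores.append(torch.mean((pred - gt) ** 2).item())
    with torch.no_grad():
        lpips_scores.append(loss_fn(pred, gt).item())
print(f"L2: {np.mean(l2_scores):.4f}  LPIPS: {np.mean(lpips_scores):.4f}")
```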
### Step 4. Run Inference with the Trained Model
- You can run inference with the trained model using the following command (a sketch for batching over the full test split follows the example output below):
```
python src/inference_paired.py --model_path "output/pix2pix_turbo/fill50k/checkpoints/model_6001.pkl" \
--input_image "data/my_fill50k/test_A/40000.png" \
--prompt "violet circle with orange background" \
--output_dir "outputs"
```
- The above command should generate the following output:
<table>
<tr>
<th>Model Input</th>
<th>Model Output</th>
</tr>
<tr>
<td><img src='../assets/examples/circles_inference_input.png' width="200px"></td>
<td><img src='../assets/examples/circles_inference_output.png' width="200px"></td>
</tr>
</table>
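To process the entire test split rather than a single image, the documented CLI can be driven from a short loop. This sketch again assumes `test_prompts.json` maps filenames to captions (see Step 1) and reuses only the flags shown above:
```
# Hedged sketch: run the single-image inference CLI over the test split.
import json
import os
import subprocess

root = "data/my_fill50k"
with open(os.path.join(root, "test_prompts.json")) as f:
    prompts = json.load(f)  # assumed format: {filename: caption}

for name, prompt in sorted(prompts.items()):
    subprocess.run([
        "python", "src/inference_paired.py",
        "--model_path", "output/pix2pix_turbo/fill50k/checkpoints/model_6001.pkl",
        "--input_image", os.path.join(root, "test_A", name),
        "--prompt", prompt,
        "--output_dir", "outputs",
    ], check=True)
```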