<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Train with a script[[train-with-a-script]]

Along with the 🤗 Transformers notebooks, there are also example scripts demonstrating how to train a model for a specific task with [PyTorch](https://github.com/huggingface/transformers/tree/main/examples/pytorch), [TensorFlow](https://github.com/huggingface/transformers/tree/main/examples/tensorflow), or [JAX/Flax](https://github.com/huggingface/transformers/tree/main/examples/flax).

You will also find scripts, mostly contributed by the community, in our [research projects](https://github.com/huggingface/transformers/tree/main/examples/research_projects) and [legacy examples](https://github.com/huggingface/transformers/tree/main/examples/legacy).
These scripts are not actively maintained and require a specific version of 🤗 Transformers that will most likely be incompatible with the latest version of the library.

The example scripts are not expected to work out of the box on every problem, and you may need to adapt a script to the problem you are trying to solve.
To help with this, most of the scripts fully expose how the data is preprocessed, so you can edit it as needed.

If there is a feature you would like to implement in an example script, please discuss it on the [forum](https://discuss.huggingface.co/) or in an [issue](https://github.com/huggingface/transformers/issues) before submitting a pull request.
Bug fixes are welcome, but a pull request that adds more functionality at the cost of readability is unlikely to be merged.

This guide shows how to run an example summarization training script in [PyTorch](https://github.com/huggingface/transformers/tree/main/examples/pytorch/summarization) and [TensorFlow](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/summarization).
All examples are expected to work with both frameworks unless otherwise specified.

## Setup[[setup]]

To successfully run the latest version of the example scripts, you have to **install 🤗 Transformers from source** in a new virtual environment:

```bash
git clone https://github.com/huggingface/transformers
cd transformers
pip install .
```

For older versions of the example scripts, click on the toggle below:

<details>
<summary>Examples for older versions of 🤗 Transformers</summary>
<ul>
<li><a href="https://github.com/huggingface/transformers/tree/v4.5.1/examples">v4.5.1</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.4.2/examples">v4.4.2</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.3.3/examples">v4.3.3</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.2.2/examples">v4.2.2</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.1.1/examples">v4.1.1</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.0.1/examples">v4.0.1</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v3.5.1/examples">v3.5.1</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v3.4.0/examples">v3.4.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v3.3.1/examples">v3.3.1</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v3.2.0/examples">v3.2.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v3.1.0/examples">v3.1.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v3.0.2/examples">v3.0.2</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v2.11.0/examples">v2.11.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v2.10.0/examples">v2.10.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v2.9.1/examples">v2.9.1</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v2.8.0/examples">v2.8.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v2.7.0/examples">v2.7.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v2.6.0/examples">v2.6.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v2.5.1/examples">v2.5.1</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v2.4.0/examples">v2.4.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v2.3.0/examples">v2.3.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v2.2.0/examples">v2.2.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v2.1.0/examples">v2.1.1</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v2.0.0/examples">v2.0.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v1.2.0/examples">v1.2.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v1.1.0/examples">v1.1.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v1.0.0/examples">v1.0.0</a></li>
</ul>
</details>

Then switch your current clone of 🤗 Transformers to a specific version, for example v3.5.1:

```bash
git checkout tags/v3.5.1
```

After you have set up the correct library version, navigate to the example folder of your choice and install the example-specific requirements:

```bash
pip install -r requirements.txt
```

## Run a script[[run-a-script]]

<frameworkcontent>
<pt>
The example script downloads and preprocesses a dataset from the 🤗 [Datasets](https://huggingface.co/docs/datasets/) library.
The script then fine-tunes a model on the dataset with the [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer), using an architecture that supports summarization.
The following example fine-tunes [T5-small](https://huggingface.co/t5-small) on the [CNN/DailyMail](https://huggingface.co/datasets/cnn_dailymail) dataset.
Because of how T5 was trained, it requires an additional `source_prefix` argument; this prompt tells T5 that the task is summarization.

```bash
python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
```
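
To make the role of `source_prefix` more concrete, the sketch below mimics in plain shell what the prefix does to each input: it is prepended to the article text before tokenization. The `article` value is made up for illustration, and the real script applies the prefix inside its Python preprocessing code rather than in the shell.

```bash
# Same prefix as in the command above.
source_prefix="summarize: "

# Hypothetical article text, for illustration only.
article="The city council approved the new budget on Tuesday."

# The prefix is prepended to every input before it is tokenized.
model_input="${source_prefix}${article}"
echo "$model_input"
```

Note the trailing space inside the quotes: without it, the prefix would run directly into the first word of the article.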

</pt>
<tf>
The example script downloads and preprocesses a dataset from the 🤗 [Datasets](https://huggingface.co/docs/datasets/) library.
The script then fine-tunes a model on the dataset using Keras, with an architecture that supports summarization.
The following example fine-tunes [T5-small](https://huggingface.co/t5-small) on the [CNN/DailyMail](https://huggingface.co/datasets/cnn_dailymail) dataset.
Because of how T5 was trained, it requires an additional `source_prefix` argument; this prompt tells T5 that the task is summarization.

```bash
python examples/tensorflow/summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 16 \
    --num_train_epochs 3 \
    --do_train \
    --do_eval
```
</tf>
</frameworkcontent>

## Distributed training with mixed precision[[distributed-training-and-mixed-precision]]

The [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) class supports distributed training and mixed precision, so you can also use both in a script.
To enable both features, set the following two options:

- Add the `fp16` argument to enable mixed precision.
- Set the number of GPUs to use with the `nproc_per_node` argument.

```bash
python -m torch.distributed.launch \
    --nproc_per_node 8 pytorch/summarization/run_summarization.py \
    --fp16 \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
```
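
Each launched process works on its own batch, so the number of samples consumed per optimization step grows with the number of GPUs. As a rough worked example using the values from the command above, and assuming gradient accumulation is left at a single step:

```bash
# Values taken from the command above.
nproc_per_node=8
per_device_train_batch_size=4

# Samples consumed per optimizer step across all processes
# (assumes no gradient accumulation).
effective_batch_size=$(( nproc_per_node * per_device_train_batch_size ))
echo "effective train batch size: ${effective_batch_size}"
```

Keeping this product in mind helps when comparing runs on different numbers of GPUs, since the learning rate that works for one effective batch size may not suit another.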

TensorFlow scripts use a [`MirroredStrategy`](https://www.tensorflow.org/guide/distributed_training#mirroredstrategy) for distributed training, so you do not need to add any arguments to the training script.
On a multi-GPU machine, the TensorFlow script uses multiple GPUs by default.

## Run a script on a TPU[[run-a-script-on-a-tpu]]

<frameworkcontent>
<pt>
Tensor Processing Units (TPUs) are hardware accelerators designed specifically to speed up machine learning workloads.
PyTorch supports TPUs through the [XLA](https://www.tensorflow.org/xla) deep learning compiler (see [here](https://github.com/pytorch/xla/blob/master/README.md) for more details).
To use a TPU, launch the `xla_spawn.py` script and set the number of TPU cores you want to use with the `num_cores` argument.

```bash
python xla_spawn.py --num_cores 8 \
    summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
```

</pt>
<tf>
Tensor Processing Units (TPUs) are hardware accelerators designed specifically to speed up machine learning workloads.
TensorFlow scripts use a [`TPUStrategy`](https://www.tensorflow.org/guide/distributed_training#tpustrategy) to train on TPUs.
To use a TPU, pass the name of the TPU resource to the `tpu` argument.

```bash
python run_summarization.py \
    --tpu name_of_tpu_resource \
    --model_name_or_path t5-small \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 16 \
    --num_train_epochs 3 \
    --do_train \
    --do_eval
```
</tf>
</frameworkcontent>

## Run a script with 🤗 Accelerate[[run-a-script-with-accelerate]]

🤗 [Accelerate](https://huggingface.co/docs/accelerate) is a PyTorch-only library that offers a unified method for training a model on several types of setups (CPU-only, multiple GPUs, TPUs) while maintaining complete visibility into the PyTorch training loop.
Make sure you have 🤗 Accelerate installed:

> Note: As Accelerate is developing rapidly, the git version of accelerate must be installed to run the scripts.

```bash
pip install git+https://github.com/huggingface/accelerate
```

Instead of the `run_summarization.py` script, you need to use the `run_summarization_no_trainer.py` script.
Scripts that support 🤗 Accelerate have a `task_no_trainer.py` file in their folder.
Begin by running the following command to create and save a configuration file:

```bash
accelerate config
```

Test your setup to make sure it is configured correctly:

```bash
accelerate test
```

Now you are ready to launch the training:

```bash
accelerate launch run_summarization_no_trainer.py \
    --model_name_or_path t5-small \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir ~/tmp/tst-summarization
```

## Use a custom dataset[[use-a-custom-dataset]]

The summarization script supports custom datasets as long as they are CSV or JSON Lines files.
When you use your own dataset, you need to specify several additional arguments:

- `train_file` and `validation_file` specify the paths to your training and validation files.
- `text_column` is the name of the column containing the input text to summarize.
- `summary_column` is the name of the column containing the target text to output.

A summarization command using a custom dataset looks like this:

```bash
python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --train_file path_to_csv_or_jsonlines_file \
    --validation_file path_to_csv_or_jsonlines_file \
    --text_column text_column_name \
    --summary_column summary_column_name \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --overwrite_output_dir \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --predict_with_generate
```
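
As an illustration of the expected file layout, the sketch below writes a two-record JSON Lines file and reuses it as a validation file. The column names `text` and `summary`, the file names, and the sample sentences are all made up for this example. The `.json` extension reflects an assumption about how the script detects the file format, so check the script's help output if your files are named differently.

```bash
# Work in a scratch directory; file and column names are hypothetical.
data_dir=$(mktemp -d)

# JSON Lines: one JSON object per line, each with the same two keys.
cat > "${data_dir}/train.json" <<'EOF'
{"text": "The city council met on Tuesday and approved the new budget after a long debate.", "summary": "Council approves budget."}
{"text": "Heavy rain flooded several streets downtown, delaying buses for most of the morning.", "summary": "Rain floods downtown streets."}
EOF

# Reuse the same tiny file for validation in this sketch.
cp "${data_dir}/train.json" "${data_dir}/validation.json"

# Count the records: one per line.
wc -l < "${data_dir}/train.json"
```

With files like these, the command above would take `--train_file "${data_dir}/train.json"`, `--validation_file "${data_dir}/validation.json"`, `--text_column text`, and `--summary_column summary`.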

## Test a script[[test-a-script]]

Training on a full dataset can take a long time, so it is a good idea to first run the script on a small number of examples to check that everything works as expected.

Use the following arguments to truncate the dataset to a maximum number of samples:

- `max_train_samples`
- `max_eval_samples`
- `max_predict_samples`

```bash
python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --max_train_samples 50 \
    --max_eval_samples 50 \
    --max_predict_samples 50 \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
```

Not all example scripts support the `max_predict_samples` argument.
If you are not sure whether your script supports this argument, add the `-h` argument to check:

```bash
python examples/pytorch/summarization/run_summarization.py -h
```
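
If you check several scripts, the help lookup can be wrapped in a small function. The sketch below is one possible way to do it; `supports_arg` and `fake_script` are hypothetical names introduced here, and `fake_script` only stands in for a real example script so that the snippet runs anywhere.

```bash
# Succeeds if the help text of the given command mentions --<flag>.
# usage: supports_arg <flag> <command...>
supports_arg() {
  flag=$1
  shift
  "$@" -h 2>&1 | grep -q -e "--${flag}"
}

# Stand-in for a real script (hypothetical): prints a fixed usage line.
fake_script() {
  echo "usage: fake_script [-h] [--max_train_samples N] [--max_eval_samples N]"
}

if supports_arg max_predict_samples fake_script; then
  echo "max_predict_samples: supported"
else
  echo "max_predict_samples: not supported"
fi
```

Against a real script, the call would look like `supports_arg max_predict_samples python examples/pytorch/summarization/run_summarization.py`.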

## Resume training from a checkpoint[[resume-training-from-checkpoint]]

Another helpful option is resuming training from a previous checkpoint.
If training is interrupted, you can pick up where you left off instead of starting over.
There are two ways to resume training from a checkpoint.

The first method uses the `output_dir previous_output_dir` argument to resume training from the latest checkpoint stored in `output_dir`.
In this case, you should remove `overwrite_output_dir`:

```bash
python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --output_dir previous_output_dir \
    --predict_with_generate
```

The second method uses the `resume_from_checkpoint path_to_specific_checkpoint` argument to resume training from a specific checkpoint folder.

```bash
python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --resume_from_checkpoint path_to_specific_checkpoint \
    --predict_with_generate
```
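
To find a path to pass to `resume_from_checkpoint`, it helps to know that the Trainer saves checkpoints inside `output_dir` as folders named `checkpoint-<step>`. The helper below is a minimal sketch that picks the folder with the highest step number; `latest_checkpoint` is a hypothetical name, and the step numbers in the demo are invented.

```bash
# Print the checkpoint-<step> directory with the highest step number.
latest_checkpoint() {
  best_step=-1
  best_dir=""
  for dir in "$1"/checkpoint-*; do
    [ -d "$dir" ] || continue      # skip non-directories (also covers an unmatched glob)
    step=${dir##*-}                # text after the last hyphen
    case $step in
      ''|*[!0-9]*) continue ;;     # ignore non-numeric suffixes
    esac
    if [ "$step" -gt "$best_step" ]; then
      best_step=$step
      best_dir=$dir
    fi
  done
  [ -n "$best_dir" ] && printf '%s\n' "$best_dir"
}

# Demo on a scratch directory with made-up step numbers.
output_dir=$(mktemp -d)
mkdir "${output_dir}/checkpoint-500" "${output_dir}/checkpoint-1000" "${output_dir}/checkpoint-1500"
latest=$(latest_checkpoint "$output_dir")
echo "$latest"
```

The comparison is numeric on purpose: a plain alphabetical listing would place `checkpoint-500` after `checkpoint-1500`.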

## Share your model[[share-your-model]]

All scripts can upload your final model to the [Model Hub](https://huggingface.co/models).
Make sure you are logged in to Hugging Face before you begin:

```bash
huggingface-cli login
```

Then add the `push_to_hub` argument to the script.
This argument creates a repository named after your Hugging Face username and the folder name specified in `output_dir`.

To give your repository a specific name, add the `push_to_hub_model_id` argument.
The repository will be listed automatically under your namespace.

The following example shows how to upload a model with a specific repository name:

```bash
python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --push_to_hub \
    --push_to_hub_model_id finetuned-t5-cnn_dailymail \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
```
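
The naming rule described in this section can be written out as a small sketch. `my-username` is a placeholder for your own Hugging Face account name, and the snippet only builds the repository names as strings; it does not contact the Hub.

```bash
hf_username="my-username"            # hypothetical account name
output_dir=/tmp/tst-summarization

# Without push_to_hub_model_id, the folder name of output_dir is used.
default_repo="${hf_username}/$(basename "$output_dir")"

# With push_to_hub_model_id, the given name is used instead.
custom_repo="${hf_username}/finetuned-t5-cnn_dailymail"

echo "$default_repo"
echo "$custom_repo"
```

Choosing a descriptive `output_dir` folder name therefore matters whenever you rely on the default repository name.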