|
<!--Copyright 2023 The HuggingFace Team. All rights reserved. |
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with |
|
the License. You may obtain a copy of the License at |
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on |
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the |
|
specific language governing permissions and limitations under the License. |
|
|
|
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
|
|
|
--> |
|
|
|
# Automatic speech recognition[[automatic-speech-recognition]]
|
|
|
[[open-in-colab]] |
|
|
|
<Youtube id="TksaY_FDgnk"/> |
|
|
|
Automatic speech recognition (ASR) converts a speech signal to text, mapping a sequence of audio inputs to text outputs.

Virtual assistants like Siri and Alexa use ASR models to help users every day, and there are many other useful user-facing applications like live captioning and note-taking during meetings.
|
|
|
This guide will show you how to:
|
|
|
1. Finetune [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base) on the [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) dataset to transcribe audio to text.
2. Use your finetuned model for inference.
|
|
|
<Tip> |
|
The task illustrated in this tutorial is supported by the following model architectures:
|
|
|
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!--> |
|
|
|
[Data2VecAudio](../model_doc/data2vec-audio), [Hubert](../model_doc/hubert), [M-CTC-T](../model_doc/mctct), [SEW](../model_doc/sew), [SEW-D](../model_doc/sew-d), [UniSpeech](../model_doc/unispeech), [UniSpeechSat](../model_doc/unispeech-sat), [Wav2Vec2](../model_doc/wav2vec2), [Wav2Vec2-Conformer](../model_doc/wav2vec2-conformer), [WavLM](../model_doc/wavlm) |
|
|
|
<!--End of the generated tip--> |
|
|
|
</Tip> |
|
|
|
Before you begin, make sure you have all the necessary libraries installed:
|
|
|
```bash |
|
pip install transformers datasets evaluate jiwer |
|
``` |
|
|
|
We encourage you to log in to your Hugging Face account so you can upload and share your model with the community. When prompted, enter your token to log in:
|
|
|
```py |
|
>>> from huggingface_hub import notebook_login |
|
|
|
>>> notebook_login() |
|
``` |
|
|
|
## Load MInDS-14 dataset[[load-minds-14-dataset]]
|
|
|
Start by loading a smaller subset of the [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) dataset from the 🤗 Datasets library. This gives you a chance to experiment and make sure everything works before spending more time training on the full dataset.
|
|
|
```py |
|
>>> from datasets import load_dataset, Audio |
|
|
|
>>> minds = load_dataset("PolyAI/minds14", name="en-US", split="train[:100]") |
|
``` |
|
|
|
Split the dataset's `train` split into a train and test set with the [`~Dataset.train_test_split`] method:
|
|
|
```py |
|
>>> minds = minds.train_test_split(test_size=0.2) |
|
``` |
|
|
|
Then take a look at the dataset:
|
|
|
```py |
|
>>> minds |
|
DatasetDict({
    train: Dataset({
        features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],
        num_rows: 16
    })
    test: Dataset({
        features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],
        num_rows: 4
    })
})
|
``` |
|
|
|
While the dataset contains a lot of useful information, like `lang_id` and `english_transcription`, this guide focuses on the `audio` and `transcription` columns. Remove the other columns with the [`~datasets.Dataset.remove_columns`] method:
|
|
|
```py |
|
>>> minds = minds.remove_columns(["english_transcription", "intent_class", "lang_id"]) |
|
``` |
|
|
|
Take a look at the example again:
|
|
|
```py |
|
>>> minds["train"][0] |
|
{'audio': {'array': array([-0.00024414,  0.        ,  0.        , ...,  0.00024414,
        0.00024414,  0.00024414], dtype=float32),
  'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',
  'sampling_rate': 8000},
 'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',
 'transcription': "hi I'm trying to use the banking app on my phone and currently my checking and savings account balance is not refreshing"}
|
``` |
|
|
|
There are two fields:
|
|
|
- `audio`: a 1-dimensional `array` of the speech signal that must be called to load and resample the audio file.
|
- `transcription`: the target text.
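
As a quick, optional sanity check, you can index into the `audio` column to decode one file on the fly and inspect its sampling rate and number of samples:

```py
>>> sample = minds["train"][0]["audio"]  # the file is decoded on access
>>> sample["sampling_rate"], len(sample["array"])
```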
|
|
|
## Preprocess[[preprocess]]
|
|
|
The next step is to load a Wav2Vec2 processor to process the audio signal:
|
|
|
```py |
|
>>> from transformers import AutoProcessor |
|
|
|
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base") |
|
``` |
|
|
|
The MInDS-14 dataset has a sampling rate of 8000 Hz (you can find this information in its [dataset card](https://huggingface.co/datasets/PolyAI/minds14)), which means you'll need to resample it to 16000 Hz to use the pretrained Wav2Vec2 model:
|
|
|
```py |
|
>>> minds = minds.cast_column("audio", Audio(sampling_rate=16_000)) |
|
>>> minds["train"][0] |
|
{'audio': {'array': array([-2.38064706e-04, -1.58618059e-04, -5.43987835e-06, ...,
         2.78103951e-04,  2.38446111e-04,  1.18740834e-04], dtype=float32),
  'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',
  'sampling_rate': 16000},
 'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',
 'transcription': "hi I'm trying to use the banking app on my phone and currently my checking and savings account balance is not refreshing"}
|
``` |
|
|
|
As you can see in the `transcription` above, the text contains a mix of uppercase and lowercase characters. The Wav2Vec2 tokenizer is only trained on uppercase characters, so you'll need to make sure the text matches the tokenizer's vocabulary:
|
|
|
```py |
|
>>> def uppercase(example):
...     return {"transcription": example["transcription"].upper()}
|
|
|
|
|
>>> minds = minds.map(uppercase) |
|
``` |
|
|
|
Now create a preprocessing function that:
|
|
|
1. Calls the `audio` column to load and resample the audio file.
|
2. Extracts the `input_values` from the audio file and tokenizes the `transcription` column with the processor.
|
|
|
```py |
|
>>> def prepare_dataset(batch):
...     audio = batch["audio"]
...     batch = processor(audio["array"], sampling_rate=audio["sampling_rate"], text=batch["transcription"])
...     batch["input_length"] = len(batch["input_values"][0])
...     return batch
|
``` |
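
Before mapping over the whole dataset, you can optionally sanity-check `prepare_dataset` on a single example (an illustrative check, not a required step):

```py
>>> processed = prepare_dataset(minds["train"][0])
>>> sorted(processed.keys())  # expect ['input_length', 'input_values', 'labels']
```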
|
|
|
To apply the preprocessing function over the entire dataset, use the 🤗 Datasets [`~datasets.Dataset.map`] function. You can speed up `map` by increasing the number of processes with the `num_proc` parameter. Remove the columns you don't need with the [`~datasets.Dataset.remove_columns`] method:
|
|
|
```py |
|
>>> encoded_minds = minds.map(prepare_dataset, remove_columns=minds.column_names["train"], num_proc=4) |
|
``` |
|
|
|
🤗 Transformers doesn't have a data collator for automatic speech recognition, so you'll need to adapt [`DataCollatorWithPadding`] to create a batch of examples. It will dynamically pad your text and labels to the length of the longest element in its batch (instead of the entire dataset) so they are a uniform length. While it is possible to pad your text in the `tokenizer` function by setting `padding=True`, dynamic padding is more efficient.
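
To see the effect, here is a minimal sketch with made-up clip lengths (`toy_features` and `toy_batch` are hypothetical names): `processor.pad` pads both clips to the longest one in the batch rather than to a fixed global length:

```py
>>> toy_features = [{"input_values": [0.1] * 5}, {"input_values": [0.2] * 8}]  # hypothetical lengths
>>> toy_batch = processor.pad(toy_features, padding="longest", return_tensors="pt")
>>> toy_batch["input_values"].shape  # both clips are padded to length 8
torch.Size([2, 8])
```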
|
|
|
Unlike other data collators, this specific data collator needs to apply a different padding method to `input_values` and `labels`:
|
|
|
```py |
|
>>> import torch

>>> from dataclasses import dataclass, field
>>> from typing import Any, Dict, List, Optional, Union


>>> @dataclass
... class DataCollatorCTCWithPadding:
...     processor: AutoProcessor
...     padding: Union[bool, str] = "longest"

...     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
...         # split inputs and labels since they have to be of different lengths
...         # and need different padding methods
...         input_features = [{"input_values": feature["input_values"][0]} for feature in features]
...         label_features = [{"input_ids": feature["labels"]} for feature in features]

...         batch = self.processor.pad(input_features, padding=self.padding, return_tensors="pt")

...         labels_batch = self.processor.pad(labels=label_features, padding=self.padding, return_tensors="pt")

...         # replace padding with -100 so these tokens are ignored by the loss
...         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

...         batch["labels"] = labels

...         return batch
|
``` |
|
|
|
Now instantiate your `DataCollatorCTCWithPadding`:
|
|
|
```py |
|
>>> data_collator = DataCollatorCTCWithPadding(processor=processor, padding="longest") |
|
``` |
|
|
|
## Evaluate[[evaluate]]
|
|
|
Including a metric during training is often helpful for evaluating your model's performance. You can quickly load an evaluation method with the 🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) library. For this task, load the [word error rate](https://huggingface.co/spaces/evaluate-metric/wer) (WER) metric (see the 🤗 Evaluate [quick tour](https://huggingface.co/docs/evaluate/a_quick_tour) to learn more about how to load and compute a metric):
|
|
|
```py |
|
>>> import evaluate |
|
|
|
>>> wer = evaluate.load("wer") |
|
``` |
|
|
|
Then create a function that passes your predictions and labels to [`~evaluate.EvaluationModule.compute`] to calculate the WER:
|
|
|
```py |
|
>>> import numpy as np |
|
|
|
|
|
>>> def compute_metrics(pred):
...     pred_logits = pred.predictions
...     pred_ids = np.argmax(pred_logits, axis=-1)

...     pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

...     pred_str = processor.batch_decode(pred_ids)
...     label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

...     # use a new name so the global `wer` metric isn't shadowed inside the function
...     wer_score = wer.compute(predictions=pred_str, references=label_str)

...     return {"wer": wer_score}
|
``` |
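
For intuition about the metric itself, you can call `compute` directly on a pair of toy strings (illustrative values): one deleted word against a three-word reference yields a WER of 1/3:

```py
>>> wer.compute(predictions=["HELLO WORLD"], references=["HELLO THERE WORLD"])
0.3333333333333333
```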
|
|
|
Your `compute_metrics` function is ready to go now, and you'll return to it when you set up your training.
|
|
|
## Train[[train]]
|
|
|
<frameworkcontent> |
|
<pt> |
|
<Tip> |
|
|
|
If you aren't familiar with finetuning a model with the [`Trainer`], take a look at the basic tutorial [here](../training#train-with-pytorch-trainer)!
|
|
|
</Tip> |
|
|
|
You're now ready to start training your model! Load Wav2Vec2 with [`AutoModelForCTC`]. Specify the reduction to apply to the CTC loss with the `ctc_loss_reduction` parameter; it is often better to use the average instead of the default summation:
|
|
|
```py |
|
>>> from transformers import AutoModelForCTC, TrainingArguments, Trainer |
|
|
|
>>> model = AutoModelForCTC.from_pretrained(
...     "facebook/wav2vec2-base",
...     ctc_loss_reduction="mean",
...     pad_token_id=processor.tokenizer.pad_token_id,
... )
|
``` |
|
|
|
At this point, only three steps remain:
|
|
|
1. Define your training hyperparameters in [`TrainingArguments`]. The only required parameter is `output_dir`, which specifies where to save your model. You'll push the model to the Hub by setting `push_to_hub=True` (you need to be signed in to Hugging Face to upload your model). With the step-based schedule below, the [`Trainer`] evaluates the WER and saves a training checkpoint every `eval_steps`/`save_steps` steps.
2. Pass the training arguments to [`Trainer`] along with the model, dataset, tokenizer, data collator, and `compute_metrics` function.
3. Call [`~Trainer.train`] to finetune your model.
|
|
|
```py |
|
>>> training_args = TrainingArguments(
...     output_dir="my_awesome_asr_mind_model",
...     per_device_train_batch_size=8,
...     gradient_accumulation_steps=2,
...     learning_rate=1e-5,
...     warmup_steps=500,
...     max_steps=2000,
...     gradient_checkpointing=True,
...     fp16=True,
...     group_by_length=True,
...     evaluation_strategy="steps",
...     per_device_eval_batch_size=8,
...     save_steps=1000,
...     eval_steps=1000,
...     logging_steps=25,
...     load_best_model_at_end=True,
...     metric_for_best_model="wer",
...     greater_is_better=False,
...     push_to_hub=True,
... )

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=encoded_minds["train"],
...     eval_dataset=encoded_minds["test"],
...     tokenizer=processor.feature_extractor,
...     data_collator=data_collator,
...     compute_metrics=compute_metrics,
... )

>>> trainer.train()
|
``` |
|
|
|
Once training is completed, share your model to the Hub with the [`~transformers.Trainer.push_to_hub`] method so everyone can use your model:
|
|
|
```py |
|
>>> trainer.push_to_hub() |
|
``` |
|
</pt> |
|
</frameworkcontent> |
|
|
|
<Tip> |
|
|
|
For a more in-depth example of how to finetune a model for automatic speech recognition, take a look at this blog [post](https://huggingface.co/blog/fine-tune-wav2vec2-english) for English ASR and this [post](https://huggingface.co/blog/fine-tune-xlsr-wav2vec2) for multilingual ASR.
|
|
|
</Tip> |
|
|
|
## Inference[[inference]]
|
|
|
Great, now that you've finetuned a model, you can use it for inference!
|
|
|
Load an audio file you'd like to run inference on. Remember to resample the audio file to match the sampling rate of the model if you need to!
|
|
|
```py |
|
>>> from datasets import load_dataset, Audio |
|
|
|
>>> dataset = load_dataset("PolyAI/minds14", "en-US", split="train") |
|
>>> dataset = dataset.cast_column("audio", Audio(sampling_rate=16000)) |
|
>>> sampling_rate = dataset.features["audio"].sampling_rate |
|
>>> audio_file = dataset[0]["audio"]["path"] |
|
``` |
|
|
|
The simplest way to try out your finetuned model for inference is to use it in a [`pipeline`]. Instantiate a `pipeline` for automatic speech recognition with your model, and pass your audio file to it:
|
|
|
```py |
|
>>> from transformers import pipeline |
|
|
|
>>> transcriber = pipeline("automatic-speech-recognition", model="stevhliu/my_awesome_asr_mind_model")
|
>>> transcriber(audio_file) |
|
{'text': 'I WOUD LIKE O SET UP JOINT ACOUNT WTH Y PARTNER'} |
|
``` |
|
|
|
<Tip> |
|
|
|
The transcription is decent, but it could be better! Try finetuning your model on more examples to get even better results!
|
|
|
</Tip> |
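
The `pipeline` can also consume in-memory audio instead of a file path. As a short sketch reusing `dataset` from above, pass the decoded array and its sampling rate as a dict:

```py
>>> sample = dataset[0]["audio"]
>>> transcriber({"array": sample["array"], "sampling_rate": sample["sampling_rate"]})
```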
|
|
|
You can also manually replicate the results of the `pipeline` if you'd like:
|
|
|
<frameworkcontent> |
|
<pt> |
|
Load a processor to preprocess the audio file and transcription and return the `input` as PyTorch tensors:
|
|
|
```py |
|
>>> from transformers import AutoProcessor |
|
|
|
>>> processor = AutoProcessor.from_pretrained("stevhliu/my_awesome_asr_mind_model") |
|
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") |
|
``` |
|
|
|
Pass your inputs to the model and return the logits:
|
|
|
```py |
|
>>> from transformers import AutoModelForCTC |
|
|
|
>>> model = AutoModelForCTC.from_pretrained("stevhliu/my_awesome_asr_mind_model") |
|
>>> with torch.no_grad():
...     logits = model(**inputs).logits
|
``` |
|
|
|
Get the predicted `input_ids` with the highest probability, and use the processor to decode the predicted `input_ids` back into text:
|
|
|
```py |
|
>>> import torch |
|
|
|
>>> predicted_ids = torch.argmax(logits, dim=-1) |
|
>>> transcription = processor.batch_decode(predicted_ids) |
|
>>> transcription |
|
['I WOUL LIKE O SET UP JOINT ACOUNT WTH Y PARTNER'] |
|
``` |
|
</pt> |
|
</frameworkcontent> |