"""
This file runs Masked Language Model. You provide a training file. Each line is interpreted as a sentence / paragraph.
Optionally, you can also provide a dev file.
The fine-tuned model is stored in the output/model_name folder.
Usage:
python train_mlm.py model_name data/train_sentences.txt [data/dev_sentences.txt]
"""
from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers import DataCollatorForLanguageModeling, DataCollatorForWholeWordMask
from transformers import Trainer, TrainingArguments
import sys
import gzip
from datetime import datetime
if len(sys.argv) < 3:
    print("Usage: python train_mlm.py model_name data/train_sentences.txt [data/dev_sentences.txt]")
    exit()

model_name = sys.argv[1]
per_device_train_batch_size = 64
save_steps = 1000 #Save model every 1k steps
num_train_epochs = 3 #Number of epochs
use_fp16 = False #Set to True, if your GPU supports FP16 operations
max_length = 100 #Max length for a text input
do_whole_word_mask = True #If set to true, whole words are masked
mlm_prob = 0.15 #Probability that a token / word is selected for masking
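# Note: with the Hugging Face data collators used below, mlm_prob is the fraction of tokens
# (or whole words) selected for prediction; of those, by default ~80% are replaced by [MASK],
# ~10% by a random token, and ~10% are left unchanged (the standard BERT masking scheme).
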
# Load the model
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
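# AutoModelForMaskedLM loads the pretrained encoder together with its MLM head, so the
# training below continues the original masked-language-modelling objective of the base model.
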
output_dir = "output/{}-{}".format(model_name.replace("/", "_"), datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
print("Save checkpoints to:", output_dir)
##### Load our training datasets
train_sentences = []
train_path = sys.argv[2]
with gzip.open(train_path, 'rt', encoding='utf8') if train_path.endswith('.gz') else open(train_path, 'r', encoding='utf8') as fIn:
    for line in fIn:
        line = line.strip()
        if len(line) >= 10:
            train_sentences.append(line)
print("Train sentences:", len(train_sentences))
dev_sentences = []
if len(sys.argv) >= 4:
    dev_path = sys.argv[3]
    with gzip.open(dev_path, 'rt', encoding='utf8') if dev_path.endswith('.gz') else open(dev_path, 'r', encoding='utf8') as fIn:
        for line in fIn:
            line = line.strip()
            if len(line) >= 10:
                dev_sentences.append(line)

print("Dev sentences:", len(dev_sentences))
# A dataset wrapper that tokenizes our data on-the-fly
class TokenizedSentencesDataset:
    def __init__(self, sentences, tokenizer, max_length, cache_tokenization=False):
        self.tokenizer = tokenizer
        self.sentences = sentences
        self.max_length = max_length
        self.cache_tokenization = cache_tokenization

    def __getitem__(self, item):
        if not self.cache_tokenization:
            return self.tokenizer(self.sentences[item], add_special_tokens=True, truncation=True, max_length=self.max_length, return_special_tokens_mask=True)

        if isinstance(self.sentences[item], str):
            self.sentences[item] = self.tokenizer(self.sentences[item], add_special_tokens=True, truncation=True, max_length=self.max_length, return_special_tokens_mask=True)

        return self.sentences[item]

    def __len__(self):
        return len(self.sentences)
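# Each item is the tokenizer output (input_ids, attention_mask, special_tokens_mask, ...);
# padding and the actual masking are applied per batch by the data collator below.
# cache_tokenization=True replaces the raw string with its tokenized form on first access,
# used for the (typically smaller) dev set so it is not re-tokenized on every evaluation pass.
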
train_dataset = TokenizedSentencesDataset(train_sentences, tokenizer, max_length)
dev_dataset = TokenizedSentencesDataset(dev_sentences, tokenizer, max_length, cache_tokenization=True) if len(dev_sentences) > 0 else None
##### Training arguments
if do_whole_word_mask:
    data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob)
else:
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob)
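# With whole-word masking, all sub-word pieces of a selected word are masked together,
# which usually gives a harder and more informative pretraining signal than masking
# individual word pieces independently.
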
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    num_train_epochs=num_train_epochs,
    evaluation_strategy="steps" if dev_dataset is not None else "no",
    per_device_train_batch_size=per_device_train_batch_size,
    eval_steps=save_steps,
    save_steps=save_steps,
    logging_steps=save_steps,
    save_total_limit=1,
    prediction_loss_only=True,
    fp16=use_fp16
)
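# Checkpoints are written to output_dir every save_steps steps (only the most recent one is
# kept, since save_total_limit=1); if a dev set was provided, it is evaluated on the same schedule.
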
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset
)
print("Save tokenizer to:", output_dir)
tokenizer.save_pretrained(output_dir)
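# Saving the tokenizer up front (before training starts) is presumably meant to keep output_dir
# self-contained, so the checkpoints written during training can be used with it even if the
# run is interrupted.
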
trainer.train()
print("Save model to:", output_dir)
model.save_pretrained(output_dir)
print("Training done")