import json
import os
from collections import defaultdict

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset

import pytorch_lightning as pl
from sklearn import metrics
from transformers import AutoConfig, AutoModel, get_cosine_schedule_with_warmup

from loss import FocalLoss, LabelSmoothingCorrectionCrossEntropy


class CustomDataset(Dataset):
    """Query-title pair dataset read from a JSON-lines file under ./dataset."""

    def __init__(self, file, tokenizer, max_len, mode='no_test'):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.mode = mode

        self.ex_list = []
        with open('./dataset/' + file, "r", encoding='utf-8') as f:
            for line in f:
                sample = json.loads(line)
                query = sample["query"]
                title = sample["title"]
                id = int(sample["id"])
                if self.mode == 'no_test':
                    # Labelled data: keep the relevance label alongside the pair.
                    relevant = int(sample["label"])
                    self.ex_list.append((query, title, relevant, id))
                else:
                    # Unlabelled test data: only the pair and its id.
                    self.ex_list.append((query, title, id))

    def __len__(self):
        return len(self.ex_list)

    def __getitem__(self, index):
        if self.mode == 'no_test':
            query, title, relevant, id = self.ex_list[index]
        else:
            query, title, id = self.ex_list[index]

        # Encode the (query, title) pair as one sequence with segment ids.
        inputs = self.tokenizer.encode_plus(
            query, title,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True
        )
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs["token_type_ids"]
        if self.mode == 'no_test':
            return {
                'ids': torch.tensor(ids, dtype=torch.long),
                'mask': torch.tensor(mask, dtype=torch.long),
                'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
                'targets': torch.tensor(relevant, dtype=torch.float),
                'id': torch.tensor(id, dtype=torch.long)
            }
        else:
            return {
                'ids': torch.tensor(ids, dtype=torch.long),
                'mask': torch.tensor(mask, dtype=torch.long),
                'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
                'id': torch.tensor(id, dtype=torch.long)
            }


class CustomDataModule(pl.LightningDataModule):
    def __init__(self, args, tokenizer):
        super().__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.max_len = self.args.max_seq_length
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage):
        data_path = "./dataset"
        assert os.path.exists(os.path.join(data_path, 'train.json'))
        assert os.path.exists(os.path.join(data_path, 'dev.json'))
        assert os.path.exists(os.path.join(data_path, 'test_public.json'))
        if stage == 'fit':
            self.train_dataset = CustomDataset('train.json', self.tokenizer, self.max_len)
            self.val_dataset = CustomDataset('dev.json', self.tokenizer, self.max_len)
            self.test_dataset = CustomDataset('test_public.json', self.tokenizer, self.max_len)
        elif stage == 'test':
            self.test_dataset = CustomDataset('test_public.json', self.tokenizer, self.max_len)

    def train_dataloader(self):
        # Train on train.json and dev.json together; test_public.json is used
        # for validation/test monitoring below.
        full_dataset = ConcatDataset([self.train_dataset, self.val_dataset])
        train_dataloader = DataLoader(
            full_dataset,
            batch_size=self.args.batch_size,
            num_workers=4,
            shuffle=True,
            pin_memory=True,
            drop_last=True)
        return train_dataloader

    def val_dataloader(self):
        val_dataloader = DataLoader(
            self.test_dataset,
            batch_size=self.args.val_batch_size,
            num_workers=4,
            shuffle=False,
            pin_memory=True,
            drop_last=False)
        return val_dataloader

    def test_dataloader(self):
        test_dataloader = DataLoader(
            self.test_dataset,
            batch_size=self.args.val_batch_size,
            num_workers=4,
            shuffle=False,
            pin_memory=True,
            drop_last=False)
        return test_dataloader


class CustomModel(pl.LightningModule):
    """Transformer text-pair classifier with selectable pooling, loss, and schedule."""

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.model = self.args.model_name
        self.cache_dir = self.args.model_path
        self.scheduler = self.args.scheduler
        self.step_scheduler_after = "batch"
        self.optimizer = self.args.optimizer
        self.pooler = self.args.use_original_pooler
        self.category = self.args.cate_performance
        self.loss_func = self.args.loss_function

        hidden_dropout_prob: float = 0.1
        layer_norm_eps: float = 1e-7

        config = AutoConfig.from_pretrained(self.model, cache_dir=self.cache_dir)
        config.update(
            {
                "output_hidden_states": False,
                "hidden_dropout_prob": hidden_dropout_prob,
                "layer_norm_eps": layer_norm_eps,
            }
        )
        self.transformer = AutoModel.from_pretrained(self.model, config=config, cache_dir=self.cache_dir)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.linear = nn.Linear(config.hidden_size, self.args.num_labels, bias=True)

    def configure_optimizers(self):
        """Prepare optimizer and learning-rate schedule."""
        # Optimize the whole module (backbone and classification head); bias and
        # LayerNorm weights are excluded from weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": 0.01,
            },
            {
                "params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]

        if self.optimizer == 'Adam':
            optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=self.args.learning_rate)
        else:  # 'AdamW'
            optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate)

        if self.scheduler == 'StepLR':
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=self.args.warmup_step, gamma=self.args.warmup_proportion)
        elif self.scheduler == 'CosineWarmup':
            scheduler = get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=int(self.args.warmup_proportion * self.total_steps),
                num_training_steps=self.total_steps,
            )
        else:  # 'CosineAnnealingLR'
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=2e-06)

        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]

    def setup(self, stage=None):
        if stage != "fit":
            return

        # Estimate the total number of optimizer steps for the warmup schedule:
        # batches per epoch (per effective batch) times the number of epochs.
        train_dataloader = self.trainer.datamodule.train_dataloader()
        gpus = 0 if self.trainer.gpus is None else self.trainer.gpus
        tb_size = self.args.batch_size * max(1, gpus)
        steps_per_epoch = len(train_dataloader.dataset) // tb_size // self.trainer.accumulate_grad_batches
        self.total_steps = steps_per_epoch * self.trainer.max_epochs

    def loss(self, outputs, targets):
        # Select the configured loss: plain cross-entropy, focal loss, or
        # label-smoothing cross-entropy with correction.
        lossf_index = ['CE', 'Focal', 'LSCE_correction'].index(self.loss_func)
        loss_fct = [nn.CrossEntropyLoss(), FocalLoss(), LabelSmoothingCorrectionCrossEntropy()][lossf_index]
        loss = loss_fct(outputs, targets)
        return loss

    def category_performance_measure(self, labels_right, labels_pred, num_label=3):
        text_labels = [i for i in range(num_label)]

        TP = dict.fromkeys(text_labels, 0)     # correctly predicted per class
        TP_FP = dict.fromkeys(text_labels, 0)  # predicted count per class (TP + FP)
        TP_FN = dict.fromkeys(text_labels, 0)  # ground-truth count per class (TP + FN)

        label_dict = defaultdict(list)
        for num in range(num_label):
            label_dict[num].append(str(num))

        for i in range(0, len(labels_right)):
            TP_FP[labels_pred[i]] += 1
            TP_FN[labels_right[i]] += 1
            if labels_right[i] == labels_pred[i]:
                TP[labels_right[i]] += 1

        # Per-class F1 from precision and recall.
        results = []
        for key in TP_FP:
            P = float(TP[key]) / float(TP_FP[key] + 1e-9)
            R = float(TP[key]) / float(TP_FN[key] + 1e-9)
            F1 = P * R * 2 / (P + R) if (P + R) != 0 else 0
            results.append(F1)
        return results

    def monitor_metrics(self, outputs, targets):
        pred = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
        targets = targets.int().cpu().numpy().tolist()
        if self.category:
            # Per-class F1 scores (one value per label).
            category_results = self.category_performance_measure(
                labels_right=targets,
                labels_pred=pred,
                num_label=self.args.num_labels
            )
            return {"f1": category_results}
        else:
            # Single macro-averaged F1 across all labels.
            f1_score = metrics.f1_score(targets, pred, average="macro")
            return {"f1": f1_score}

    def forward(self, ids, mask, token_type_ids, labels):
        transformer_out = self.transformer(input_ids=ids, attention_mask=mask, token_type_ids=token_type_ids)

        if self.pooler:
            # Use the backbone's own pooled [CLS] representation.
            pooler_output = transformer_out.pooler_output
        else:
            # Mean-pool the last hidden states over the sequence dimension.
            sequence_output = transformer_out.last_hidden_state
            pooler_output = torch.mean(sequence_output, dim=1)
        logits = self.linear(self.dropout(pooler_output))

        labels_hat = torch.argmax(logits, dim=1)
        correct_count = torch.sum(labels == labels_hat)
        return logits, correct_count

    def predict(self, ids, mask, token_type_ids):
        transformer_out = self.transformer(input_ids=ids, attention_mask=mask, token_type_ids=token_type_ids)
        # Mirror the pooling used in forward() so inference matches training.
        if self.pooler:
            pooler_output = transformer_out.pooler_output
        else:
            pooler_output = torch.mean(transformer_out.last_hidden_state, dim=1)
        logits = self.linear(self.dropout(pooler_output))
        logits = torch.argmax(logits, dim=1)
        return logits

    def training_step(self, batch, batch_idx):
        ids, mask, token_type_ids, labels = batch['ids'], batch['mask'], batch['token_type_ids'], batch['targets']
        logits, correct_count = self.forward(ids, mask, token_type_ids, labels)
        loss = self.loss(logits, labels.long())
        f1 = self.monitor_metrics(logits, labels)["f1"]
        self.log("train_loss", loss, logger=True, prog_bar=True)
        self.log('train_acc', correct_count.float() / len(labels), logger=True, prog_bar=True)
        if self.category:
            self.log("train_f1_key0", f1[0], logger=True, prog_bar=True)
            self.log("train_f1_key1", f1[1], logger=True, prog_bar=True)
            self.log("train_f1_key2", f1[2], logger=True, prog_bar=True)
        else:
            self.log("train_f1", f1, logger=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        ids, mask, token_type_ids, labels = batch['ids'], batch['mask'], batch['token_type_ids'], batch['targets']
        logits, correct_count = self.forward(ids, mask, token_type_ids, labels)
        loss = self.loss(logits, labels.long())
        f1 = self.monitor_metrics(logits, labels)["f1"]
        self.log("val_loss", loss, logger=True, prog_bar=True)
        self.log("val_acc", correct_count.float() / len(labels), logger=True, prog_bar=True)
        if self.category:
            self.log("val_f1_key0", f1[0], logger=True, prog_bar=True)
            self.log("val_f1_key1", f1[1], logger=True, prog_bar=True)
            self.log("val_f1_key2", f1[2], logger=True, prog_bar=True)
        else:
            self.log("val_f1", f1, logger=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        ids, mask, token_type_ids, labels = batch['ids'], batch['mask'], batch['token_type_ids'], batch['targets']
        logits, correct_count = self.forward(ids, mask, token_type_ids, labels)
        loss = self.loss(logits, labels.long())
        f1 = self.monitor_metrics(logits, labels)["f1"]
        self.log("test_loss", loss, logger=True, prog_bar=True)
        self.log("test_acc", correct_count.float() / len(labels), logger=True, prog_bar=True)
        if self.category:
            self.log("test_f1_key0", f1[0], logger=True, prog_bar=True)
            self.log("test_f1_key1", f1[1], logger=True, prog_bar=True)
            self.log("test_f1_key2", f1[2], logger=True, prog_bar=True)
        else:
            self.log("test_f1", f1, logger=True, prog_bar=True)
        return {"test_loss": loss, "logits": logits, "labels": labels}

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        ids, mask, token_type_ids, id = batch['ids'], batch['mask'], batch['token_type_ids'], batch['id']
        logits = self.predict(ids, mask, token_type_ids)
        return {'id': id.cpu().numpy().tolist(), 'logits': logits.cpu().numpy().tolist()}
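

# A minimal usage sketch, not part of the original script: it assumes an `args`
# namespace exposing the fields referenced above (model_name, model_path,
# max_seq_length, batch_size, val_batch_size, num_labels, learning_rate,
# optimizer, scheduler, warmup_step, warmup_proportion, use_original_pooler,
# cate_performance, loss_function). The model name and hyperparameter values
# below are placeholders, and Trainer(gpus=...) assumes the older
# pytorch_lightning API that the rest of this file (self.trainer.gpus) targets.
if __name__ == '__main__':
    from argparse import Namespace
    from transformers import AutoTokenizer

    args = Namespace(
        model_name='bert-base-chinese', model_path=None, max_seq_length=64,
        batch_size=32, val_batch_size=64, num_labels=3, learning_rate=2e-5,
        optimizer='AdamW', scheduler='CosineWarmup', warmup_step=1000,
        warmup_proportion=0.1, use_original_pooler=True, cate_performance=False,
        loss_function='CE')

    tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.model_path)
    datamodule = CustomDataModule(args, tokenizer)
    model = CustomModel(args)
    trainer = pl.Trainer(max_epochs=3, gpus=1)
    trainer.fit(model, datamodule=datamodule)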