# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# from fengshen.models.zen1 import ZenModel
from dataclasses import dataclass
from fengshen.models.megatron_t5 import T5EncoderModel
from fengshen.models.roformer import RoFormerModel
from fengshen.models.longformer import LongformerModel
# from fengshen.models.cocolm.modeling_cocolm import COCOLMForSequenceClassification
import numpy as np
import os
from tqdm import tqdm
import json
import torch
import pytorch_lightning as pl
import argparse
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from torch.utils.data import Dataset, DataLoader
from torch.utils.data._utils.collate import default_collate
from transformers import (
    BertModel,
    BertConfig,
    MegatronBertModel,
    MegatronBertConfig,
    AutoModel,
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)

# os.environ["CUDA_VISIBLE_DEVICES"] = '6'


model_dict = {'huggingface-bert': BertModel,
              'fengshen-roformer': RoFormerModel,
              'huggingface-megatron_bert': MegatronBertModel,
              'fengshen-megatron_t5': T5EncoderModel,
              'fengshen-longformer': LongformerModel,
              # 'fengshen-zen1': ZenModel,
              'huggingface-auto': AutoModelForSequenceClassification,
              }
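
# Illustrative note (not part of the original script): TaskDataset below reads one
# JSON object per line. With the default --texta_name/--textb_name/--label_name/
# --id_name arguments, an input line could look like
#   {"id": 0, "text": "first sentence", "sentence2": "second sentence", "label": "positive"}
# (example values only). Missing keys fall back to '' for the texts, 0 for the id,
# and label index 0 in load_data.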


class TaskDataset(Dataset):
    def __init__(self, data_path, args, label2id):
        super().__init__()
        self.args = args
        self.label2id = label2id
        self.max_length = args.max_length
        self.data = self.load_data(data_path, args)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def load_data(self, data_path, args):
        with open(data_path, 'r', encoding='utf8') as f:
            lines = f.readlines()
        samples = []
        for line in tqdm(lines):
            data = json.loads(line)
            text_id = int(data[args.id_name]) if args.id_name in data.keys() else 0
            texta = data[args.texta_name] if args.texta_name in data.keys() else ''
            textb = data[args.textb_name] if args.textb_name in data.keys() else ''
            labels = self.label2id[data[args.label_name]] if args.label_name in data.keys() else 0
            samples.append({args.texta_name: texta,
                            args.textb_name: textb,
                            args.label_name: labels,
                            'id': text_id})
        return samples


@dataclass
class TaskCollator:
    args = None
    tokenizer = None

    def __call__(self, samples):
        sample_list = []
        for item in samples:
            if item[self.args.texta_name] != '' and item[self.args.textb_name] != '':
                if self.args.model_type != 'fengshen-roformer':
                    encode_dict = self.tokenizer.encode_plus(
                        [item[self.args.texta_name], item[self.args.textb_name]],
                        max_length=self.args.max_length,
                        padding='max_length',
                        truncation='longest_first')
                else:
                    encode_dict = self.tokenizer.encode_plus(
                        [item[self.args.texta_name] + self.tokenizer.eos_token + item[self.args.textb_name]],
                        max_length=self.args.max_length,
                        padding='max_length',
                        truncation='longest_first')
            else:
                encode_dict = self.tokenizer.encode_plus(
                    item[self.args.texta_name],
                    max_length=self.args.max_length,
                    padding='max_length',
                    truncation='longest_first')
            sample = {}
            for k, v in encode_dict.items():
                sample[k] = torch.tensor(v)
            sample['labels'] = torch.tensor(item[self.args.label_name]).long()
            sample['id'] = item['id']
            sample_list.append(sample)
        return default_collate(sample_list)


class TaskDataModel(pl.LightningDataModule):
    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('TASK NAME DataModel')
        parser.add_argument('--data_dir', default='./data', type=str)
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--train_data', default='train.json', type=str)
        parser.add_argument('--valid_data', default='dev.json', type=str)
        parser.add_argument('--test_data', default='test.json', type=str)
        parser.add_argument('--train_batchsize', default=16, type=int)
        parser.add_argument('--valid_batchsize', default=32, type=int)
        parser.add_argument('--max_length', default=128, type=int)
        parser.add_argument('--texta_name', default='text', type=str)
        parser.add_argument('--textb_name', default='sentence2', type=str)
        parser.add_argument('--label_name', default='label', type=str)
        parser.add_argument('--id_name', default='id', type=str)
        parser.add_argument('--dataset_name', default=None, type=str)
        return parent_args

    def __init__(self, args):
        super().__init__()
        self.train_batchsize = args.train_batchsize
        self.valid_batchsize = args.valid_batchsize
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_path)
        self.collator = TaskCollator()
        self.collator.args = args
        self.collator.tokenizer = self.tokenizer
        if args.dataset_name is None:
            self.label2id, self.id2label = self.load_schema(os.path.join(
                args.data_dir, args.train_data), args)
            self.train_data = TaskDataset(os.path.join(
                args.data_dir, args.train_data), args, self.label2id)
            self.valid_data = TaskDataset(os.path.join(
                args.data_dir, args.valid_data), args, self.label2id)
            self.test_data = TaskDataset(os.path.join(
                args.data_dir, args.test_data), args, self.label2id)
        else:
            import datasets
            ds = datasets.load_dataset(args.dataset_name)
            self.train_data = ds['train']
            self.valid_data = ds['validation']
            self.test_data = ds['test']
        self.save_hyperparameters(args)

    def train_dataloader(self):
        return DataLoader(self.train_data, shuffle=True,
                          batch_size=self.train_batchsize, pin_memory=False,
                          collate_fn=self.collator)

    def val_dataloader(self):
        return DataLoader(self.valid_data, shuffle=False,
                          batch_size=self.valid_batchsize, pin_memory=False,
                          collate_fn=self.collator)

    def predict_dataloader(self):
        return DataLoader(self.test_data, shuffle=False,
                          batch_size=self.valid_batchsize, pin_memory=False,
                          collate_fn=self.collator)

    def load_schema(self, data_path, args):
        with open(data_path, 'r', encoding='utf8') as f:
            lines = f.readlines()
        label_list = []
        for line in tqdm(lines):
            data = json.loads(line)
            labels = data[args.label_name] if args.label_name in data.keys() else 0
            if labels not in label_list:
                label_list.append(labels)

        label2id, id2label = {}, {}
        for i, k in enumerate(label_list):
            label2id[k] = i
            id2label[i] = k
        return label2id, id2label


class taskModel(torch.nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        print('args model type:', args.model_type)
        self.bert_encoder = model_dict[args.model_type].from_pretrained(
            args.pretrained_model_path)
        self.config = self.bert_encoder.config
        self.cls_layer = torch.nn.Linear(
            in_features=self.config.hidden_size, out_features=self.args.num_labels)
        self.loss_func = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        if self.args.model_type == 'fengshen-megatron_t5':
            bert_output = self.bert_encoder(
                input_ids=input_ids, attention_mask=attention_mask)  # (bsz, seq, dim)
            encode = bert_output.last_hidden_state[:, 0, :]
        else:
            bert_output = self.bert_encoder(
                input_ids=input_ids, attention_mask=attention_mask,
                token_type_ids=token_type_ids)  # (bsz, seq, dim)
            encode = bert_output[1]
        logits = self.cls_layer(encode)
        if labels is not None:
            loss = self.loss_func(logits, labels.view(-1,))
            return loss, logits
        else:
            return 0, logits


class LitModel(pl.LightningModule):
    @staticmethod
    def add_model_specific_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')
        parser.add_argument('--num_labels', default=2, type=int)
        return parent_args

    def __init__(self, args, num_data):
        super().__init__()
        self.args = args
        self.num_data = num_data
        self.model = model_dict[args.model_type].from_pretrained(
            args.pretrained_model_path)
        self.save_hyperparameters(args)

    def setup(self, stage) -> None:
        train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()

        # Calculate total steps
        if self.trainer.max_epochs > 0:
            world_size = self.trainer.world_size
            tb_size = self.hparams.train_batchsize * max(1, world_size)
            ab_size = self.trainer.accumulate_grad_batches
            self.total_steps = (len(train_loader.dataset) *
                                self.trainer.max_epochs // tb_size) // ab_size
        else:
            self.total_steps = self.trainer.max_steps // self.trainer.accumulate_grad_batches

        print('Total steps: {}'.format(self.total_steps))

    def training_step(self, batch, batch_idx):
        del batch['id']
        output = self.model(**batch)
        loss, logits = output[0], output[1]
        acc = self.comput_metrix(logits, batch['labels'])
        self.log('train_loss', loss)
        self.log('train_acc', acc)
        return loss

    def comput_metrix(self, logits, labels):
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(size=(-1,))
        y_true = labels.view(size=(-1,)).float()
        corr = torch.eq(y_pred, y_true)
        acc = torch.sum(corr.float()) / labels.size()[0]
        return acc

    def validation_step(self, batch, batch_idx):
        del batch['id']
        output = self.model(**batch)
        loss, logits = output[0], output[1]
        acc = self.comput_metrix(logits, batch['labels'])
        self.log('val_loss', loss)
        self.log('val_acc', acc, sync_dist=True)

    def predict_step(self, batch, batch_idx):
        ids = batch['id']
        del batch['id']
        output = self.model(**batch)
        # Return a tuple (not a set) so save_test can unpack ids and logits in order.
        return ids, output.logits

    def configure_optimizers(self):
        from fengshen.models.model_utils import configure_optimizers
        return configure_optimizers(self)


class TaskModelCheckpoint:
    @staticmethod
    def add_argparse_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')
        parser.add_argument('--monitor', default='train_loss', type=str)
        parser.add_argument('--mode', default='min', type=str)
        parser.add_argument('--dirpath', default='./log/', type=str)
        parser.add_argument(
            '--filename', default='model-{epoch:02d}-{train_loss:.4f}', type=str)
        parser.add_argument('--save_top_k', default=3, type=int)
        parser.add_argument('--every_n_train_steps', default=100, type=int)
        parser.add_argument('--save_weights_only', default=True, type=bool)
        return parent_args

    def __init__(self, args):
        self.callbacks = ModelCheckpoint(monitor=args.monitor,
                                         save_top_k=args.save_top_k,
                                         mode=args.mode,
                                         every_n_train_steps=args.every_n_train_steps,
                                         save_weights_only=args.save_weights_only,
                                         dirpath=args.dirpath,
                                         every_n_epochs=1,
                                         filename=args.filename)
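
# Note added for clarity: save_test below writes one JSON object per sample to
# f"{args.output_save_path}.{rank}" (one file per process rank), mapping the argmax
# of the logits back through data_model.id2label, e.g.
#   {"id": 0, "label": "positive"}
# (example values only).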


def save_test(data, args, data_model, rank):
    file_name = args.output_save_path + f'.{rank}'
    with open(file_name, 'w', encoding='utf-8') as f:
        idx = 0
        for i in range(len(data)):
            ids, batch = data[i]
            for id, sample in zip(ids, batch):
                tmp_result = dict()
                label_id = np.argmax(sample.cpu().numpy())
                tmp_result['id'] = id.item()
                tmp_result['label'] = data_model.id2label[label_id]
                json_data = json.dumps(tmp_result, ensure_ascii=False)
                f.write(json_data + '\n')
                idx += 1
    print('saved the result to ' + file_name)


def main():
    pl.seed_everything(42)

    total_parser = argparse.ArgumentParser("TASK NAME")
    total_parser.add_argument('--pretrained_model_path', default='', type=str)
    total_parser.add_argument('--output_save_path',
                              default='./predict.json', type=str)
    total_parser.add_argument('--model_type',
                              default='huggingface-bert', type=str)

    # * Args for data preprocessing
    total_parser = TaskDataModel.add_data_specific_args(total_parser)
    # * Args for training
    total_parser = pl.Trainer.add_argparse_args(total_parser)
    total_parser = TaskModelCheckpoint.add_argparse_args(total_parser)

    # * Args for base model
    from fengshen.models.model_utils import add_module_args
    total_parser = add_module_args(total_parser)
    total_parser = LitModel.add_model_specific_args(total_parser)

    args = total_parser.parse_args()
    print(args.pretrained_model_path)

    checkpoint_callback = TaskModelCheckpoint(args).callbacks
    early_stop_callback = EarlyStopping(
        monitor="val_acc", min_delta=0.00, patience=5, verbose=False, mode="max")
    lr_monitor = LearningRateMonitor(logging_interval='step')
    trainer = pl.Trainer.from_argparse_args(
        args, callbacks=[checkpoint_callback, lr_monitor, early_stop_callback])

    data_model = TaskDataModel(args)
    model = LitModel(args, len(data_model.train_dataloader()))
    trainer.fit(model, data_model)

    result = trainer.predict(
        model, data_model, ckpt_path=trainer.checkpoint_callback.best_model_path)
    save_test(result, args, data_model, trainer.global_rank)


if __name__ == "__main__":
    main()
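
# Example invocation (illustrative only; the script name and paths are placeholders,
# and pl.Trainer.add_argparse_args / add_module_args add many more flags than shown):
#
#   python finetune_classification.py \
#       --pretrained_model_path /path/to/pretrained_model \
#       --model_type huggingface-auto \
#       --data_dir ./data \
#       --train_data train.json --valid_data dev.json --test_data test.json \
#       --monitor val_acc --mode max \
#       --output_save_path ./predict.json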