""" PyTorch Della model. """ |
|
|
|
import os |
|
import torch |
|
import numpy as np |
|
from fengshen.models.deepVAE.deep_vae import DeepVAE |
|
from pytorch_lightning.core.lightning import LightningModule |
|
from transformers.models.gpt2.configuration_gpt2 import GPT2Config |
|
from transformers.models.bert.tokenization_bert import BertTokenizer |
|
from fengshen.models.deepVAE.latent_connector import GPT2ForDecoderLatentConnector, GPT2ForEncoderLatentConnector |
|
from transformers.optimization import AdamW, get_linear_schedule_with_warmup |
|
|
|
|
|
class DeepVAEModule(LightningModule): |
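    """Lightning wrapper around the Della deep VAE.

    A GPT2-based encoder and decoder (latent connectors) are trained jointly with
    layer-wise latent variables; the KL term is weighted by a cyclic linear
    annealing schedule built in ``__init__`` and indexed by batch index in
    ``training_step``.
    """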
|
    @classmethod
    def add_module_specific_args(cls, parser):
        group = parser.add_argument_group('vae', 'configurations')
        group.add_argument("--checkpoint_path", type=str, default=None)
        group.add_argument("--gpt2_model_path", type=str)
        group.add_argument("--beta_kl_constraints_start", default=1, type=float,
                           help="min beta for all the latent z posterior vs prior kl loss")
        group.add_argument("--beta_kl_constraints_stop", default=1, type=float,
                           help="max beta for all the latent z posterior vs prior kl loss")
        group.add_argument("--beta_n_cycles", default=30, type=int,
                           help="number of cycles for kl loss ratio within an epoch")
        group.add_argument("--freebit_kl_constraints", default=.1, type=float,
                           help="free bit for all the latent z kl loss")
        group.add_argument("--latent_dim", default=256, type=int,
                           help="latent dimension of deepVAE Z")
        group.add_argument("--learning_rate", default=5e-5, type=float,
                           help="The initial learning rate for Adam.")
        group.add_argument("--weight_decay", default=0.0, type=float,
                           help="Weight decay if we apply some.")
        group.add_argument("--adam_epsilon", default=1e-8, type=float,
                           help="Epsilon for Adam optimizer.")
        group.add_argument("--max_grad_norm", default=1.0, type=float,
                           help="Max gradient norm.")
        group.add_argument("--warmup_steps", default=0, type=int,
                           help="Linear warmup over warmup_steps.")
        group.add_argument("--CVAE", action='store_true',
                           help="specify this argument if finetuning CVAE, otherwise ignore this argument")
        return parser
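    # Example invocation of the flags defined above (sketch; the script name is a
    # placeholder, and `encoder_model_path` / `decoder_model_path`, which __init__ reads
    # when no checkpoint is given, must be supplied by another argument group):
    #   python train.py --gpt2_model_path <path> --latent_dim 256 \
    #       --beta_kl_constraints_start 1e-5 --beta_kl_constraints_stop 1.0 \
    #       --beta_n_cycles 30 --learning_rate 5e-5 --warmup_steps 1000 --CVAE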
|
|
|
    @classmethod
    def load_model(cls, args, labels_dict=None):
        checkpoint = torch.load(os.path.join(args.checkpoint_path, 'mp_rank_00_model_states.pt'))

        latent_dim = checkpoint['latent_dim'] if 'latent_dim' in checkpoint else args.latent_dim
        labels_dict = checkpoint['label_dict'] if 'label_dict' in checkpoint else labels_dict

        enc_config = GPT2Config.from_pretrained(args.gpt2_model_path)
        tokenizer = BertTokenizer.from_pretrained(args.gpt2_model_path)
        special_tokens_dict = {'bos_token': '<BOS>', 'eos_token': '<EOS>'}
        tokenizer.add_special_tokens(special_tokens_dict)
        encoder_model = GPT2ForEncoderLatentConnector(config=enc_config)
        encoder_model.resize_token_embeddings(len(tokenizer))

        dec_config = GPT2Config.from_pretrained(args.gpt2_model_path)
        decoder_model = GPT2ForDecoderLatentConnector(config=dec_config, latent_dim=latent_dim)
        decoder_model.resize_token_embeddings(len(tokenizer))

        vae_model = DeepVAE(encoder_model, decoder_model, latent_dim=latent_dim,
                            hidden_dim=enc_config.hidden_size, layer_num=enc_config.num_hidden_layers,
                            pad_token_id=tokenizer.pad_token_id, unk_token_id=tokenizer.unk_token_id,
                            bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id,
                            CVAE=args.CVAE)

        # strip the distributed-training prefix from the checkpoint state dict keys
        anchor = 'module.model.'
        start = len(anchor)
        vae_dict = {key[start:]: val for key, val in checkpoint['module'].items() if key.startswith(anchor)}
        missing_keys, unexpected_keys = vae_model.load_state_dict(vae_dict, strict=False)
        print(f"Vae model loading process: missing keys {missing_keys}, unexpected keys {unexpected_keys}")

        # also return the latent_dim and labels_dict recovered from the checkpoint
        return vae_model, tokenizer, latent_dim, labels_dict
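    # Sketch of standalone use for generation/analysis (assumes `args.checkpoint_path`
    # points at a DeepSpeed-style directory containing 'mp_rank_00_model_states.pt'):
    #   vae_model, tokenizer, latent_dim, labels_dict = DeepVAEModule.load_model(args)
    #   vae_model.eval()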
|
|
|
    def __init__(self, args, train_steps=0, labels_dict=None):
        super().__init__()
        self.args = args

        if args.checkpoint_path is not None:
            # restore the model, tokenizer and the metadata saved by on_save_checkpoint
            self.model, self.encoder_tokenizer, self.latent_dim, self.labels_dict = \
                DeepVAEModule.load_model(self.args, labels_dict=labels_dict)
            # the encoder and decoder share one tokenizer when loading from a checkpoint
            self.decoder_tokenizer = self.encoder_tokenizer
        else:
            self.encoder_tokenizer = BertTokenizer.from_pretrained(self.args.encoder_model_path)
            encoder_config = GPT2Config.from_pretrained(self.args.encoder_model_path)
            special_tokens_dict = {'bos_token': '<BOS>', 'eos_token': '<EOS>',
                                   'additional_special_tokens': ['<ENT>', '<ENS>']}
            self.encoder_tokenizer.add_special_tokens(special_tokens_dict)
            self.latent_dim = self.args.latent_dim
            encoder = GPT2ForEncoderLatentConnector.from_pretrained(self.args.encoder_model_path, config=encoder_config)
            encoder.resize_token_embeddings(len(self.encoder_tokenizer))

            self.decoder_tokenizer = BertTokenizer.from_pretrained(self.args.decoder_model_path)
            self.decoder_tokenizer.add_special_tokens(special_tokens_dict)
            decoder_config = GPT2Config.from_pretrained(self.args.decoder_model_path)
            self.labels_dict = labels_dict
            decoder = GPT2ForDecoderLatentConnector.from_pretrained(self.args.decoder_model_path, config=decoder_config,
                                                                    latent_dim=self.latent_dim)
            decoder.resize_token_embeddings(len(self.decoder_tokenizer))

            self.model = DeepVAE(encoder, decoder, latent_dim=self.args.latent_dim,
                                 hidden_dim=encoder_config.hidden_size, layer_num=encoder_config.num_hidden_layers,
                                 pad_token_id=self.decoder_tokenizer.pad_token_id,
                                 unk_token_id=self.decoder_tokenizer.unk_token_id,
                                 bos_token_id=self.decoder_tokenizer.bos_token_id,
                                 eos_token_id=self.decoder_tokenizer.eos_token_id,
                                 CVAE=args.CVAE)

        self.train_steps = train_steps
        # cyclic linear annealing schedule for the KL weight, one value per training step
        self.beta_kl_constraints_list = self.get_cyclic_linear_beta_list(
            self.train_steps, start=args.beta_kl_constraints_start,
            stop=args.beta_kl_constraints_stop, n_cycle=args.beta_n_cycles)
        # constant zero mlm probability schedule (only logged in training_step)
        self.mlm_probability_list = self.get_constant_ratio(self.train_steps, 0.)
|
    def get_constant_ratio(self, n_steps, ratio):
        L = np.ones(n_steps)
        L *= ratio
        return L
|
    def get_decoder_beta_list(self, n_steps, start=0., stop=1.0, n_cycle=4):
        L = np.ones(n_steps)
        t_range = int(n_steps / n_cycle)
        for t_cur in range(n_steps):
            if t_cur > t_range:
                L[t_cur] = 0.
            else:
                ratio = t_cur / t_range
                value = stop - ratio * (stop - start)
                L[t_cur] = value
        return L
|
    def get_cyclic_linear_beta_list(self, n_steps, start=0.5, stop=1.0, n_cycle=4):
        L = np.ones(n_steps)
        t_range = int(n_steps / n_cycle)
        for t_cur in range(n_steps):
            loc = t_cur % t_range
            split_range = int(t_range * 0.25)
            if loc <= 2 * split_range:
                value = start
            elif loc <= 3 * split_range:
                ratio = (loc % split_range) / split_range
                value = ratio * (stop - start)
            else:
                value = stop
            L[t_cur] = value
        return L
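    # The schedule above produces one KL weight per training step: it is held at `start`
    # for the first half of each cycle, ramped over the next quarter, and held at `stop`
    # for the final quarter. A quick way to inspect it outside of training (sketch):
    #   betas = self.get_cyclic_linear_beta_list(n_steps=10000, start=0., stop=1., n_cycle=30)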
    def on_save_checkpoint(self, checkpoint) -> None:
        # persist the label dictionary and latent dimension so load_model can restore them
        checkpoint['label_dict'] = self.labels_dict
        checkpoint['latent_dim'] = self.latent_dim
|
    def training_step(self, batch, batch_idx):
        if batch is None:
            loss = torch.Tensor([0.]).to(next(self.model.parameters()).device)
            loss.requires_grad = True
            return loss
        # CVAE batches arrive as (inputs, cond_inputs) tuples
        inputs, cond_inputs = batch, None
        if self.args.CVAE:
            inputs, cond_inputs = batch

        total_loss, rec_loss, total_kl_loss, layer_kl_loss = \
            self.model(inputs, self.beta_kl_constraints_list[batch_idx], cond_inputs)

        for idx, pg in enumerate(self.optimizers().param_groups):
            self.log(f"learning_rate_{idx}", pg['lr'])
        unscaled_kl_constraint_loss = 0. if self.beta_kl_constraints_list[batch_idx] == 0. \
            else total_kl_loss / self.beta_kl_constraints_list[batch_idx]
        self.log("total_loss", total_loss)
        self.log("total_kl_constraint_loss", total_kl_loss)
        self.log("unscaled_kl_constraint_loss", unscaled_kl_constraint_loss)
        self.log("beta_kl_constraints", self.beta_kl_constraints_list[batch_idx])
        self.log("beta_mlm_probability", self.mlm_probability_list[batch_idx])
        self.log("rec_loss", rec_loss)
        for idx, kl_loss in enumerate(layer_kl_loss):
            self.log(f"layer_{idx}_kl_loss", kl_loss.mean())

        return total_loss
|
    def training_step_end(self, batch_parts):
        pass

    def training_epoch_end(self, outputs):
        pass
|
    def validation_step(self, batch, batch_idx):
        if batch is None:
            loss = torch.Tensor([0.]).to(next(self.model.parameters()).device)
            loss.requires_grad = True
            return loss
        inputs, cond_inputs = batch, None
        if self.args.CVAE:
            inputs, cond_inputs = batch

        # evaluate with the full KL weight (beta = 1.)
        total_loss, rec_loss, total_kl_loss, layer_kl_loss = self.model(inputs, 1., cond_inputs)

        self.log("val_total_loss", total_loss)
        self.log("val_kl_constraint_loss", total_kl_loss)
        self.log("val_recon_loss", rec_loss)
        for idx, kl_loss in enumerate(layer_kl_loss):
            self.log(f"layer_{idx}_kl_loss", kl_loss.mean())
        return total_loss
|
    def validation_epoch_end(self, outputs):
        pass
|
    def test_step(self, batch, batch_idx):
        if batch is None:
            loss = torch.Tensor([0.]).to(next(self.model.parameters()).device)
            loss.requires_grad = True
            return loss
        inputs, cond_inputs = batch, None
        if self.args.CVAE:
            inputs, cond_inputs = batch
        total_loss, rec_loss, total_kl_loss, layer_kl_loss = self.model(inputs, 1., cond_inputs)
        self.log("test_total_loss", total_loss)
        self.log("test_recon_loss", rec_loss)
        self.log("test_kl_constraint_loss", total_kl_loss)
        for idx, kl_loss in enumerate(layer_kl_loss):
            self.log(f"layer_{idx}_kl_loss", kl_loss.mean())
        return total_loss
|
    def configure_optimizers(self):
        # exclude biases and LayerNorm weights from weight decay
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': self.args.weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]

        optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.args.warmup_steps,
                                                    num_training_steps=self.train_steps)

        return {'optimizer': optimizer,
                'lr_scheduler': {
                    'scheduler': scheduler,
                    'interval': 'step',
                    'frequency': 1
                }}
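
# Typical wiring with a Lightning Trainer (sketch; `train_loader`, `val_loader` and the
# argparse `args`, including `max_epochs`, come from the training script and are not
# defined in this file; `max_grad_norm` is assumed to be forwarded as gradient_clip_val):
#   module = DeepVAEModule(args, train_steps=len(train_loader) * args.max_epochs)
#   trainer = pytorch_lightning.Trainer(max_epochs=args.max_epochs, gradient_clip_val=args.max_grad_norm)
#   trainer.fit(module, train_loader, val_loader)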
|
|