'''
This is a training script for finetuning ESM.
The model is frozen first, and then the last N encoder layers are unfrozen
(optionally restricted via the query/key/value flags) before training.
'''
import os
import fuson_plm.training.config as config

# Set the WANDB_API_KEY and CUDA_VISIBLE_DEVICES environment variables.
# This is done before `import torch` on purpose; keep this ordering.
os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES

import torch
import numpy as np
import pandas as pd
import tqdm
from datetime import datetime
import wandb
import pytz
import sys

from transformers import AdamW

from fuson_plm.utils.logging import print_configpy, get_local_time, open_logfile, open_errfile, log_update
from fuson_plm.training.model import FusOnpLM
from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders, get_mask_rate_scheduler
from fuson_plm.training.plot import make_train_val_test_bd_plot

# Fixed masking rate used for validation passes so that evaluation numbers stay
# comparable across epochs even when the training masking rate is scheduled.
EVAL_MASK_PERCENTAGE = 0.15


def prepare_model(model, n_unfrozen_layers, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
    """
    Freeze the whole model, unfreeze its last `n_unfrozen_layers` layers, and log parameter counts.

    Args:
        model: model exposing count_encoder_layers(), freeze_model(), unfreeze_last_n_layers(),
            and the attributes n_layers, lm_head and esm (as used below).
        n_unfrozen_layers (int): number of final layers to unfreeze.
        unfreeze_query (bool): forwarded to model.unfreeze_last_n_layers.
        unfreeze_key (bool): forwarded to model.unfreeze_last_n_layers.
        unfreeze_value (bool): forwarded to model.unfreeze_last_n_layers.
    """
    # Log the model's initial state
    n_layers = model.count_encoder_layers()
    total_params = sum(p.numel() for p in model.parameters())
    total_head_params = sum(p.numel() for p in model.lm_head.parameters())
    log_update(f'\nInitial state:\n\tTotal number of layers in the model: {n_layers}')
    log_update(f'\tTotal parameters in the AutoModelforMaskedLM model: {total_params}')
    log_update(f'\tTotal parameters in the MLM Head ONLY: {total_head_params}')

    # Freeze the model to start
    model.freeze_model()
    n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log_update(f'Froze all {model.n_layers} model layers')
    log_update(f'\tTrainable params: {n_trainable_params}')

    # Unfreeze the last n layers
    model.unfreeze_last_n_layers(n_unfrozen_layers, unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)
    n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    trainable_params = '\n\t\t'.join([name for name, param in model.named_parameters() if param.requires_grad])
    num_trainable_params_lm_head = sum(p.numel() for p in model.lm_head.parameters() if p.requires_grad)
    num_trainable_params_esm = sum(p.numel() for p in model.esm.parameters() if p.requires_grad)
    log_update(f'Unfroze final {n_unfrozen_layers} layers')
    log_update(f'\tTrainable params: {n_trainable_params}\n\t\t{trainable_params}')
    log_update(f"\tTrainable parameters in the lm_head: {num_trainable_params_lm_head}")
    log_update(f"\tTrainable params in the ESM part: {num_trainable_params_esm}")


def _loss_stats(split, total_loss, total_weighted_loss, total_masked_tokens, n_batches):
    """
    Build the per-split statistics dict that is logged to wandb and written to CSV.

    Averages fall back to NaN (instead of raising ZeroDivisionError) when the loader
    was empty or no token was masked, so a finished epoch is still recorded.

    Args:
        split (str): 'train', 'val' or 'test'; used to build the key names.
        total_loss (float): sum of the per-batch losses.
        total_weighted_loss (float): sum of (batch loss * masked tokens in that batch).
        total_masked_tokens (int): total masked tokens over all batches.
        n_batches (int): number of batches in the loader.

    Returns:
        dict: keys total_{split}_loss, weighted_{split}_loss, avg_{split}_loss,
            avg_weighted_{split}_loss and {split}_perplexity, in that order.
    """
    nan = float('nan')
    avg_loss = total_loss / n_batches if n_batches else nan  # avg per batch
    avg_weighted_loss = total_weighted_loss / total_masked_tokens if total_masked_tokens else nan  # avg per masked token
    return {
        f"total_{split}_loss": total_loss,
        f"weighted_{split}_loss": total_weighted_loss,
        f"avg_{split}_loss": avg_loss,
        f"avg_weighted_{split}_loss": avg_weighted_loss,
        f"{split}_perplexity": np.exp(avg_weighted_loss),
    }


def _format_loss_summary(split, stats, n_batches, total_masked_tokens):
    """Return the human-readable one-line summary of a `_loss_stats` dict for the log file."""
    total_loss = stats[f"total_{split}_loss"]
    avg_loss = stats[f"avg_{split}_loss"]
    avg_weighted_loss = stats[f"avg_weighted_{split}_loss"]
    perplexity = stats[f"{split}_perplexity"]
    return (f"Total batches = {n_batches}, Total masked tokens = {total_masked_tokens}, "
            f"Total Loss = {total_loss:.4f}, Avg Batch Loss = {avg_loss:.4f}, "
            f"Avg Masked Token-Weighted Loss = {avg_weighted_loss:.4f}, Perplexity = {perplexity:.4f}")


def _append_stats_row(csv_path, row):
    """Append `row` (a dict of scalars) to `csv_path`, writing the header only when the file is new."""
    already_exists = os.path.exists(csv_path)
    stats_df = pd.DataFrame(data={key: [value] for key, value in row.items()})
    stats_df.to_csv(csv_path, index=False, header=not already_exists, mode='a' if already_exists else 'w')


def _run_eval_pass(model, tokenizer, loader, mask_percentage, device, desc):
    """
    Run one gradient-free pass over `loader` and accumulate the loss totals.

    Shared by the validation step in train() and by test().

    Returns:
        tuple: (total_loss, total_weighted_loss, total_masked_tokens)
    """
    model.eval()
    total_loss = 0
    total_weighted_loss = 0
    total_masked_tokens = 0

    with torch.no_grad():  # No gradients needed
        with tqdm.tqdm(enumerate(loader), total=len(loader), desc=desc, leave=True, position=0) as bar:
            for _, (inputs, prob) in bar:
                # Move tensors
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)

                # Mask based on probability vectors
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=mask_percentage)

                # Forward pass
                batch_loss = model(**masked_inputs).loss.item()

                # Number of masked tokens = positions whose input id is the mask token id
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                # NOTE(review): weighting by masked tokens assumes the model loss is a
                # per-masked-token mean - confirm against the model/masking helper.
                total_loss += batch_loss
                total_weighted_loss += batch_loss * num_masked_tokens
                total_masked_tokens += num_masked_tokens

    return total_loss, total_weighted_loss, total_masked_tokens


def train(model, tokenizer, optimizer, train_loader, val_loader, n_epochs=10, start_epoch=1, mask_percentage=0.15, mask_rate_scheduler=None, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Train the model, validating and checkpointing after every epoch.

    Args:
        model: model to train; called as model(**masked_inputs) and must expose save_model().
        tokenizer: tokenizer providing mask_token_id.
        optimizer: optimizer stepping the trainable parameters.
        train_loader: iterable of (inputs, prob) training batches.
        val_loader: iterable of (inputs, prob) validation batches.
        n_epochs (int): number of epochs to run.
        start_epoch (int): number of the first epoch (greater than 1 when resuming).
        mask_percentage (float): training masking rate used when no scheduler is given.
        mask_rate_scheduler: optional scheduler with reset(), step() and get_masking_rate().
        device: device the batches are moved to.
        checkpoint_dir (str): directory receiving checkpoints, train_curve.csv and val_curve.csv.
    """
    # Make sure the CSV writes below cannot fail when train() is used outside main()
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Loop over epochs
    log_update("\n")
    for epoch in range(start_epoch, start_epoch+n_epochs):
        if mask_rate_scheduler is not None:
            mask_rate_scheduler.reset()  # resetting because we want to ramp it up again every epoch

        model.train()
        total_train_loss = 0
        total_weighted_train_loss = 0
        total_train_masked_tokens = 0

        log_update(f"Epoch {epoch}")
        # Loop over train data with progress bar
        with tqdm.tqdm(enumerate(train_loader), total=len(train_loader), desc='Training Batch', leave=True, position=0) as pbar:
            for batch_idx, (inputs, prob) in pbar:
                # Take a step with the mask rate scheduler, if there is one.
                masking_rate = mask_percentage
                if mask_rate_scheduler is not None:
                    mask_rate_scheduler.step()
                    masking_rate = mask_rate_scheduler.get_masking_rate()
                    log_update(f"\tBatch index: {batch_idx}\tMasking rate: {masking_rate:.5f}")

                # Move tensors
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)

                # Mask based on probability vectors
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=masking_rate)

                # Forward pass and update
                optimizer.zero_grad()
                outputs = model(**masked_inputs)
                loss = outputs.loss
                loss.backward()
                optimizer.step()

                # Number of masked tokens
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                # Loss calculations and wandb log
                batch_loss = loss.item()
                total_train_loss += batch_loss
                total_weighted_train_loss += batch_loss * num_masked_tokens  # Multiply loss by number of masked tokens
                total_train_masked_tokens += num_masked_tokens
                wandb.log({"batch_loss": batch_loss})

        # Save a checkpoint at the end of each epoch
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}')
        model.save_model(checkpoint_path, optimizer=optimizer)
        log_update(f'\nSaved checkpoint to {checkpoint_path}')

        # Calculate and log average training loss on wandb
        n_train_batches = len(train_loader)
        train_stats = _loss_stats('train', total_train_loss, total_weighted_train_loss, total_train_masked_tokens, n_train_batches)
        train_row = {"epoch": epoch, **train_stats}
        wandb.log(train_row)
        # Track curve stats for easy re-plotting of training curves later
        _append_stats_row(f"{checkpoint_dir}/train_curve.csv", train_row)

        # Validation loop: always masked at the fixed evaluation rate
        total_val_loss, total_weighted_val_loss, total_val_masked_tokens = _run_eval_pass(
            model, tokenizer, val_loader, EVAL_MASK_PERCENTAGE, device, 'Validation Batch')

        # Calculate and log avg. loss and perplexity (wandb and locally)
        n_val_batches = len(val_loader)
        val_stats = _loss_stats('val', total_val_loss, total_weighted_val_loss, total_val_masked_tokens, n_val_batches)
        val_row = {"epoch": epoch, **val_stats}
        wandb.log(val_row)
        # Track curve stats for easy re-plotting of training curves later
        _append_stats_row(f"{checkpoint_dir}/val_curve.csv", val_row)

        log_update(f"Epoch: {epoch}")
        log_update("\tTrain set: " + _format_loss_summary('train', train_stats, n_train_batches, total_train_masked_tokens))
        log_update("\tValidation set: " + _format_loss_summary('val', val_stats, n_val_batches, total_val_masked_tokens))


def test(model, tokenizer, test_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Evaluate the model once on the test set and write the results to test_results.csv.

    Args:
        model: model to evaluate; moved to `device` here.
        tokenizer: tokenizer providing mask_token_id.
        test_loader: iterable of (inputs, prob) test batches.
        mask_percentage (float): masking rate for the test pass (default 0.15).
        device: device the model and batches are moved to.
        checkpoint_dir (str): directory receiving test_results.csv.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    model.to(device)

    # Loop over test data with progress bar
    total_test_loss, total_weighted_test_loss, total_test_masked_tokens = _run_eval_pass(
        model, tokenizer, test_loader, mask_percentage, device, 'Test Batch')

    # Compute and log avg. loss and perplexity
    n_test_batches = len(test_loader)
    test_stats = _loss_stats('test', total_test_loss, total_weighted_test_loss, total_test_masked_tokens, n_test_batches)
    log_update("\nTest results:\n" + _format_loss_summary('test', test_stats, n_test_batches, total_test_masked_tokens))

    # Save to dataframe for plotting
    test_stats_df = pd.DataFrame(data={key: [value] for key, value in test_stats.items()})
    test_stats_df.to_csv(f"{checkpoint_dir}/test_results.csv", index=False)  # overwrite old file no matter what; should only be one test eval


def check_env_variables():
    """Log the environment/GPU setup. The W&B API key is a secret, so only its presence is logged."""
    log_update("\nChecking on environment variables...")
    api_key_status = '<set, value hidden>' if os.environ.get('WANDB_API_KEY') else '<not set>'
    log_update(f"\tWANDB_API_KEY: {api_key_status}")
    log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
    log_update(f"\ttorch.cuda.device_count(): {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}")


def intialize_model_and_optimizer(finetune_from_scratch, device, path_to_starting_ckpt=None, learning_rate=1e-4, n_unfrozen_layers=0, unfreeze_query=False, unfreeze_key=False, unfreeze_value=False):
    """
    Initialize the model, either with FusOnpLM's defaults if finetuning from scratch, or from a
    prior checkpoint if not. Also prepares the model (freeze/unfreeze) and builds the optimizer
    over the parameters that remain trainable.

    Args:
        finetune_from_scratch (bool): True if finetuning from scratch. False if finetuning from a previous ckpt
        device: device the model is moved to (also used as map_location for the optimizer state)
        path_to_starting_ckpt (str): path to starting ckpt for finetuning (required when not from scratch)
        learning_rate (float): learning rate for a fresh optimizer. When resuming, the optimizer is
            built with defaults and its state is then loaded from the checkpoint's optimizer.pt.
        n_unfrozen_layers (int): forwarded to prepare_model
        unfreeze_query (bool): forwarded to prepare_model
        unfreeze_key (bool): forwarded to prepare_model
        unfreeze_value (bool): forwarded to prepare_model

    Returns:
        tuple: (model, optimizer)

    Raises:
        FileNotFoundError: if resuming and path_to_starting_ckpt is missing or does not exist.
    """
    # Check for None explicitly: os.path.exists(None) would raise an unhelpful TypeError
    if not finetune_from_scratch and (path_to_starting_ckpt is None or not os.path.exists(path_to_starting_ckpt)):
        raise FileNotFoundError(f"Error: could not find {path_to_starting_ckpt}. When finetuning from a prior checkpoint, you must provide a valid path to that checkpoint.")

    if finetune_from_scratch:
        log_update(f"\nInitializing FusOn-pLM model to be finetuned from scratch")
        model = FusOnpLM()  # because of __getattr__, we can use FusOnpLM() to get the model. It also contains the tokenizer.
    else:
        log_update(f"\nInitializing FusOn-pLM model to be finetuned from previous checkpoint: {path_to_starting_ckpt}")
        model = FusOnpLM(ckpt_path = path_to_starting_ckpt, mlm_head=True)

    model.to(device)
    prepare_model(model, n_unfrozen_layers, unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)

    # Only optimize what prepare_model left trainable
    trainable_parameters = [p for p in model.parameters() if p.requires_grad]
    if finetune_from_scratch:
        optimizer = AdamW(trainable_parameters, lr=learning_rate)
    else:
        log_update(f"Loading optimizer state_dict from previous checkpoint")
        optimizer = AdamW(trainable_parameters)
        # SECURITY NOTE: torch.load unpickles the file; only resume from checkpoints you trust.
        optimizer.load_state_dict(torch.load(os.path.join(path_to_starting_ckpt, "optimizer.pt"), map_location=device))

    return model, optimizer


def _settings_tag_of_run_dir(checkpoint_dir, expected_tag):
    """
    Recover the settings tag from a run directory named '<settings tag>-<timestamp>'.

    The tag itself may contain '-' (e.g. a learning rate rendered as '5e-05'), so it cannot
    be found by cutting at the first dash. If the directory name starts with
    '<expected_tag>-' the tags match and expected_tag is returned; otherwise the text before
    the first dash is returned as a best-effort value for the mismatch message.
    """
    run_dir_name = os.path.basename(checkpoint_dir)
    if run_dir_name.startswith(f"{expected_tag}-"):
        return expected_tag
    return run_dir_name.split('-', 1)[0]


def main():
    """Build the run name from the config, set up logging and wandb, then train and test the model."""
    # Set probability type to uniform; only option
    config.PROBABILITY_TYPE = "uniform"

    # Set run name (WANDB_NAME)
    kqv_tag = f"{'Q' if config.UNFREEZE_QUERY else ''}" + f"{'K' if config.UNFREEZE_KEY else ''}" + f"{'V' if config.UNFREEZE_VALUE else ''}"
    timestamp = get_local_time()
    # make a mask tag _mask{config.MASK_PERCENTAGE}
    mask_tag = f"mask{config.MASK_PERCENTAGE}"
    if config.VAR_MASK_RATE:  # if variable masking rate, change the tag to reflect this
        mask_tag=f"maskvar_{config.MASK_SCHEDULER}_low{config.MASK_LOW}_high{config.MASK_HIGH}"

    # Define the train settings string and wandb name from this
    TRAIN_SETTINGS_STRING = f"{config.PROBABILITY_TYPE}_{config.MAX_LENGTH}_ft_{config.N_UNFROZEN_LAYERS}layers_{kqv_tag}_b{config.BATCH_SIZE}_lr{config.LEARNING_RATE}_{mask_tag}"
    WANDB_NAME = f'{TRAIN_SETTINGS_STRING}-{timestamp}'

    # Create directory for model checkpoints
    checkpoint_dir = f'checkpoints/{WANDB_NAME}'
    start_epoch = 1

    # Determine if we're adding to an old log file or opening a new one
    logmode='w'

    # If we're finetuning from a checkpoint, save to the same folder instead, and keep track of which epoch to start on
    # Also, load the optimizer from here
    if not(config.FINETUNE_FROM_SCRATCH):
        logmode='a'
        path_to_starting_ckpt = config.PATH_TO_STARTING_CKPT
        checkpoint_dir = path_to_starting_ckpt[0:path_to_starting_ckpt.rindex('/')]
        START_MODEL_TRAIN_SETTINGS_STRING = _settings_tag_of_run_dir(checkpoint_dir, TRAIN_SETTINGS_STRING)
        start_epoch = int(path_to_starting_ckpt.split('/checkpoint_epoch_')[1])+1

    os.makedirs(f'checkpoints', exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Open log file
    LOG_PATH = f'{checkpoint_dir}/training_log.txt'
    ERR_PATH = f'{checkpoint_dir}/training_errors.txt'

    with open_logfile(LOG_PATH,mode=logmode), open_errfile(ERR_PATH,mode=logmode):
        if not(config.FINETUNE_FROM_SCRATCH):
            same_settings = START_MODEL_TRAIN_SETTINGS_STRING==TRAIN_SETTINGS_STRING
            log_update(f"\n{'-'*200}\nResuming finetuning from checkpoint {start_epoch-1} (first new checkpoint: {start_epoch})\n")
            log_update(f"Settings tag for original model (starting point for finetuning) = {START_MODEL_TRAIN_SETTINGS_STRING}\nSettings tag for new model based on configs = {TRAIN_SETTINGS_STRING}\nSame: {same_settings}\n")
            # ONLY proceed with training if we're using the same settings, otherwise we are not finetuning the model we think we are!
            # Raised explicitly (not asserted) so the check survives `python -O`.
            if not same_settings:
                raise ValueError(f"Checkpoint settings do not match the current config: run directory {checkpoint_dir} does not start with settings tag {TRAIN_SETTINGS_STRING}")

        # Print configurations
        print_configpy(config)

        # Verify that the environment variables are set correctly
        check_env_variables()

        # Check CUDA availability and set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log_update(f"\nUsing device: {device}")

        # Init wandb
        wandb.init(project=config.WANDB_PROJECT, entity=config.WANDB_ENTITY, name=WANDB_NAME , config={
            "batch_size": config.BATCH_SIZE,
            "epochs": config.EPOCHS,
            "learning_rate": config.LEARNING_RATE,
        })

        # Initialize model and prepare it (freeze/unfreeze proper layers). Initialize optimizer as well. Details depend on whether we are finetuning from scratch.
        model, optimizer = intialize_model_and_optimizer(config.FINETUNE_FROM_SCRATCH, device,
                                                         path_to_starting_ckpt=config.PATH_TO_STARTING_CKPT,
                                                         learning_rate=config.LEARNING_RATE,
                                                         n_unfrozen_layers=config.N_UNFROZEN_LAYERS,
                                                         unfreeze_query=config.UNFREEZE_QUERY,
                                                         unfreeze_key=config.UNFREEZE_KEY,
                                                         unfreeze_value=config.UNFREEZE_VALUE)

        # Initialize the tokenizer (independent of starting model for finetuning)
        tokenizer = model.tokenizer

        # Create DataLoader instances and perform sanity checks on them
        train_loader = get_dataloader(config.TRAIN_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=True)  ## FOR DEBUGGING ONLY, change shuffle to False. Otherwise, True!!
        val_loader = get_dataloader(config.VAL_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)
        test_loader = get_dataloader(config.TEST_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)

        # NOTE(review): the original comment here said the old batch diversity plot is stored
        # before being overwritten when resuming; nothing in this file does that - confirm
        # whether check_dataloaders handles it.
        check_dataloaders(train_loader, val_loader, test_loader, max_length=config.MAX_LENGTH, checkpoint_dir=checkpoint_dir)

        # Set up a masking rate scheduler, if one is needed
        mask_rate_scheduler = None
        if config.VAR_MASK_RATE:
            mask_rate_scheduler = get_mask_rate_scheduler(scheduler_type=config.MASK_SCHEDULER,
                                                          min_masking_rate=config.MASK_LOW,
                                                          max_masking_rate=config.MASK_HIGH,
                                                          total_batches=len(train_loader),
                                                          total_steps=config.MASK_STEPS)

        # Train the model
        train(model, tokenizer, optimizer, train_loader, val_loader,
              n_epochs=config.EPOCHS,
              start_epoch = start_epoch,
              device=device,
              mask_rate_scheduler=mask_rate_scheduler,
              mask_percentage=config.MASK_PERCENTAGE,
              checkpoint_dir=checkpoint_dir)

        # Test the model
        test(model, tokenizer, test_loader, mask_percentage=0.15, device=device, checkpoint_dir=checkpoint_dir)


if __name__ == "__main__":
    main()