import os

import numpy as np
import pandas as pd
import torch
from abc import ABC, abstractmethod
from torch.nn.functional import softmax
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

from fuson_plm.utils.logging import log_update

#----------------------------------------------------------------------------------------------------------------------------------------------------
#### Masking Rate Scheduler base class and subclasses

# Abstract base class
class MaskingRateScheduler(ABC):
    def __init__(self, total_steps, min_masking_rate, max_masking_rate, last_step=-1):
        self.total_steps = total_steps
        self.min_masking_rate = min_masking_rate
        self.max_masking_rate = max_masking_rate
        self.current_step = last_step

    def step(self):
        self.current_step += 1

    def reset(self):
        """Reset the scheduler to its initial state."""
        self.current_step = -1

    def get_masking_rate(self):
        progress = self.current_step / self.total_steps
        return self.compute_masking_rate(progress)

    @abstractmethod
    def compute_masking_rate(self, progress):
        """To be implemented by subclasses for specific increase functions."""
        raise NotImplementedError("Subclasses must implement this method.")


class CosineIncreaseMaskingRateScheduler(MaskingRateScheduler):
    def compute_masking_rate(self, progress):
        # Use a cosine increase function
        cosine_increase = 0.5 * (1 - np.cos(np.pi * progress))
        return self.min_masking_rate + (self.max_masking_rate - self.min_masking_rate) * cosine_increase


class LogLinearIncreaseMaskingRateScheduler(MaskingRateScheduler):
    def compute_masking_rate(self, progress):
        # Avoid log(0) by clamping progress to a minimum of a small positive number
        progress = max(progress, 1e-10)
        log_linear_increase = np.log1p(progress) / np.log1p(1)  # normalize to keep the range in [0, 1]
        return self.min_masking_rate + (self.max_masking_rate - self.min_masking_rate) * log_linear_increase


class StepwiseIncreaseMaskingRateScheduler(MaskingRateScheduler):
    def __init__(self, total_batches, min_masking_rate, max_masking_rate, num_steps):
        super().__init__(total_steps=total_batches, min_masking_rate=min_masking_rate, max_masking_rate=max_masking_rate)
        self.num_steps = num_steps
        self.batch_interval = total_batches // num_steps                                # batches per step
        self.rate_increment = (max_masking_rate - min_masking_rate) / (num_steps - 1)   # include the max rate as the final step

    def compute_masking_rate(self, progress):
        # Determine the current step based on the number of completed batches
        current_step = int(self.current_step / self.batch_interval)
        # Cap the step number at `num_steps - 1` so the max rate is reached at the final step
        current_step = min(current_step, self.num_steps - 1)
        # Calculate the masking rate for the current step
        masking_rate = self.min_masking_rate + current_step * self.rate_increment
        return masking_rate


def get_mask_rate_scheduler(scheduler_type="cosine", min_masking_rate=0.15, max_masking_rate=0.40, total_batches=100, total_steps=20):
    """
    Initialize the mask rate scheduler and return it.
    """
    if scheduler_type == "cosine":
        return CosineIncreaseMaskingRateScheduler(total_steps=total_batches,
                                                  min_masking_rate=min_masking_rate,
                                                  max_masking_rate=max_masking_rate)
    elif scheduler_type == "loglinear":
        return LogLinearIncreaseMaskingRateScheduler(total_steps=total_batches,
                                                     min_masking_rate=min_masking_rate,
                                                     max_masking_rate=max_masking_rate)
    elif scheduler_type == "stepwise":
        return StepwiseIncreaseMaskingRateScheduler(total_batches=total_batches,
                                                    num_steps=total_steps,
                                                    min_masking_rate=min_masking_rate,
                                                    max_masking_rate=max_masking_rate)
    else:
        raise Exception("Must specify valid scheduler_type: cosine, loglinear, stepwise")
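
# Usage sketch (illustrative only; the batch count and rates below are placeholder values,
# not taken from any training configuration). The scheduler is stepped once per batch and
# queried for the masking rate to apply to that batch:
#
#   scheduler = get_mask_rate_scheduler(scheduler_type="cosine",
#                                       min_masking_rate=0.15,
#                                       max_masking_rate=0.40,
#                                       total_batches=1000)
#   for _ in range(1000):
#       scheduler.step()
#       current_mask_rate = scheduler.get_masking_rate()
#       # ... mask this batch at `current_mask_rate` ...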

# Adjusted Dataloader for the sequences and probability vectors
class ProteinDataset(Dataset):
    def __init__(self, data_path, tokenizer, probability_type, max_length=512):
        self.dataframe = pd.read_csv(data_path)
        self.tokenizer = tokenizer
        self.probability_type = probability_type
        self.max_length = max_length
        self.set_probabilities()

    def __len__(self):
        return len(self.dataframe)

    def set_probabilities(self):
        if self.probability_type == "snp":
            self.dataframe = self.dataframe.rename(columns={'snp_probabilities': 'probabilities'})
        if self.probability_type == "uniform":
            # Uniform masking: every residue gets the same unnormalized weight of 1
            self.dataframe['probabilities'] = self.dataframe['sequence'].apply(len).apply(lambda x: ('1,' * x)[0:-1])
        # Convert comma-separated probability strings into numeric arrays if they aren't already
        if type(self.dataframe['probabilities'][0]) == str:
            self.dataframe['probabilities'] = self.dataframe['probabilities'].apply(
                lambda x: np.array([float(i) for i in x.split(',')])
            )

    def get_padded_probabilities(self, idx):
        '''
        Pads probabilities to max_length if they're too short; truncates them if they're too long.
        '''
        no_mask_value = int(-1e9)  # used to make sure special tokens (CLS, EOS, PAD) are never masked
        # Add a no-mask slot for the CLS token at the start of the sequence
        prob = np.concatenate((
            np.array([no_mask_value]),
            self.dataframe.iloc[idx]['probabilities']
        ))
        # Pad with no_mask_value for everything after the probability vector ends
        if len(prob) < self.max_length:
            return np.pad(prob,
                          (0, self.max_length - len(prob)),
                          'constant',
                          constant_values=(0, no_mask_value))
        # If it's too long, truncate and set the last slot to no_mask_value so the EOS token can't be masked
        prob = prob[0:self.max_length - 1]
        prob = np.concatenate((
            prob,
            np.array([no_mask_value]),
        ))
        return prob

    def __getitem__(self, idx):
        sequence = self.dataframe.iloc[idx]['sequence']
        probability = self.get_padded_probabilities(idx)
        inputs = self.tokenizer(sequence,
                                return_tensors="pt",
                                padding='max_length',
                                truncation=True,
                                max_length=self.max_length)
        inputs = {key: tensor.squeeze(0) for key, tensor in inputs.items()}  # remove batch dimension
        return inputs, probability


def get_dataloader(data_path, tokenizer, probability_type='snp', max_length=512, batch_size=8, shuffle=True):
    """
    Creates a DataLoader for the dataset.

    Args:
        data_path (str): Path to the CSV file (train, val, or test).
        tokenizer (Tokenizer): tokenizer object for data tokenization.
        probability_type (str): 'snp' to use precomputed per-residue probabilities, 'uniform' for equal weights.
        max_length (int): Maximum tokenized sequence length.
        batch_size (int): Batch size.
        shuffle (bool): Whether to shuffle the data.

    Returns:
        DataLoader: DataLoader object.
    """
    dataset = ProteinDataset(data_path, tokenizer, probability_type, max_length=max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
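
# Usage sketch (illustrative only): building the three dataloaders. The tokenizer checkpoint
# and CSV paths below are placeholders, not values prescribed by this module.
#
#   tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
#   train_loader = get_dataloader("splits/train_df.csv", tokenizer, probability_type='snp',
#                                 max_length=512, batch_size=8, shuffle=True)
#   val_loader = get_dataloader("splits/val_df.csv", tokenizer, probability_type='snp',
#                               max_length=512, batch_size=8, shuffle=False)
#   test_loader = get_dataloader("splits/test_df.csv", tokenizer, probability_type='snp',
#                                max_length=512, batch_size=8, shuffle=False)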

def check_dataloaders(train_loader, val_loader, test_loader, max_length=512, checkpoint_dir=''):
    log_update('\nBuilt train, validation, and test dataloaders')
    log_update(f"\tNumber of sequences in the Training DataLoader: {len(train_loader.dataset)}")
    log_update(f"\tNumber of sequences in the Validation DataLoader: {len(val_loader.dataset)}")
    log_update(f"\tNumber of sequences in the Test DataLoader: {len(test_loader.dataset)}")

    dataloader_overlaps = check_dataloader_overlap(train_loader, val_loader, test_loader)
    if len(dataloader_overlaps) == 0:
        log_update("\tDataloaders are clean (no overlaps)")
    else:
        log_update(f"\tWARNING! sequence overlap found: {','.join(dataloader_overlaps)}")

    # Write the per-batch sequence length ranges to a text file
    if not os.path.exists(f'{checkpoint_dir}/batch_diversity'):
        os.mkdir(f'{checkpoint_dir}/batch_diversity')
    max_length_violators = []
    for name, dataloader in {'train': train_loader, 'val': val_loader, 'test': test_loader}.items():
        max_length_followed, length_ranges = check_max_length_and_length_diversity(dataloader, max_length)
        if not max_length_followed:
            max_length_violators.append(name)
        with open(f'{checkpoint_dir}/batch_diversity/{name}_batch_length_ranges.txt', 'w') as f:
            for tup in length_ranges:
                f.write(f'{tup[0]}\t{tup[1]}\n')
    if len(max_length_violators) == 0:
        log_update(f"\tDataloaders follow the max length limit set by user: {max_length}")
    else:
        log_update(f"\tWARNING! these loaders have sequences longer than max length={max_length}: {','.join(max_length_violators)}")


def check_dataloader_overlap(train_loader, val_loader, test_loader):
    """
    Check the data that's about to go into the model and make sure there is no sequence overlap
    between the train, validation, and test sets.

    Returns:
        list: Descriptions of any overlaps found (empty if the splits are clean).
    """
    train_protein_seqs = set(train_loader.dataset.dataframe['sequence'].unique())
    val_protein_seqs = set(val_loader.dataset.dataframe['sequence'].unique())
    test_protein_seqs = set(test_loader.dataset.dataframe['sequence'].unique())

    tr_va = len(train_protein_seqs.intersection(val_protein_seqs))
    tr_te = len(train_protein_seqs.intersection(test_protein_seqs))
    va_te = len(val_protein_seqs.intersection(test_protein_seqs))

    overlaps = []
    if tr_va == tr_te == va_te == 0:
        return overlaps  # data is clean
    else:
        if tr_va > 0: overlaps.append(f"Train-Val Overlap={tr_va}")
        if tr_te > 0: overlaps.append(f"Train-Test Overlap={tr_te}")
        if va_te > 0: overlaps.append(f"Val-Test Overlap={va_te}")
        return overlaps


def check_max_length_and_length_diversity(dataloader, max_length):
    """
    Check if all sequences in the DataLoader conform to the specified max_length,
    and return the sequence length ranges within each batch.

    Args:
        dataloader (DataLoader): The DataLoader object to check.
        max_length (int): The maximum allowed sequence length.

    Returns:
        bool: True if all sequences are within the max_length, False otherwise.
        list: A list of tuples giving the min and max sequence lengths in each batch.
    """
    length_ranges = []
    all_within_max_length = True

    for inputs, _ in dataloader:
        input_ids = inputs['input_ids']
        # Calculate the actual (non-padded) lengths of sequences in this batch
        actual_lengths = (input_ids != dataloader.dataset.tokenizer.pad_token_id).sum(dim=1)
        min_length = actual_lengths.min().item()
        max_length_in_batch = actual_lengths.max().item()

        # Check for max length violation
        if max_length_in_batch > max_length:
            all_within_max_length = False

        # Store the length range for this batch
        length_ranges.append((min_length, max_length_in_batch))

    return all_within_max_length, length_ranges


def check_max_length_in_dataloader(dataloader, max_length):
    """
    Check if all sequences in the DataLoader conform to the specified max_length.

    Args:
        dataloader (DataLoader): The DataLoader object to check.
        max_length (int): The maximum allowed sequence length.

    Returns:
        bool: True if all sequences are within the max_length, False otherwise.
    """
    for inputs, _ in dataloader:
        input_ids = inputs['input_ids']
        # Check if any sequence length exceeds max_length
        if input_ids.size(1) > max_length:
            return False
    return True
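
# Usage sketch (illustrative only): running the split sanity checks once the dataloaders are
# built. `checkpoint_dir` is a placeholder path; per-batch length ranges are written under
# `<checkpoint_dir>/batch_diversity/`.
#
#   check_dataloaders(train_loader, val_loader, test_loader,
#                     max_length=512, checkpoint_dir="checkpoints/run_0")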
""" for batch_idx, (inputs, _) in enumerate(dataloader): input_ids = inputs['input_ids'] # Check if any sequence length exceeds max_length if input_ids.size(1) > max_length: return False return True def batch_sample_mask_tokens_with_probabilities(inputs, probabilities, tokenizer: AutoTokenizer, mask_percentage=0.15): """ """ #print('the batch sample method was called!') labels = inputs["input_ids"].detach().clone() labels[labels != tokenizer.mask_token_id] = -100 # Set labels for unmasked tokens to -100 # Iterate over each sequence and its corresponding probabilities in the batch for idx in range(inputs["input_ids"].size(0)): # Assuming the first dimension is batch size input_ids = inputs["input_ids"][idx] prob = probabilities[idx] cls_token_index = (input_ids == 0).nonzero(as_tuple=False)[0].item() eos_token_index = (input_ids == 2).nonzero(as_tuple=False)[0].item() seq_length = eos_token_index - (cls_token_index+1) assert prob.shape[0] == input_ids.shape[0] # Normalize probabilities using softmax prob = softmax(prob, dim=0).cpu().numpy() # move to CPU for numpy assert 1 - sum(prob) < 1e-6 # Calculate the number of tokens to mask num_tokens_to_mask = int(mask_percentage * seq_length) # Choose indices to mask based on the probability distribution mask_indices = np.random.choice(input_ids.shape[0], size=num_tokens_to_mask, replace=False, p=prob) attention_mask_1_indices = np.arange(0, eos_token_index+1, 1) # Mask the selected indices and set the corresponding labels labels[idx, mask_indices] = input_ids[mask_indices].detach().clone() input_ids[mask_indices] = tokenizer.mask_token_id inputs["attention_mask"][idx] = torch.zeros_like(input_ids) inputs["attention_mask"][idx][attention_mask_1_indices] = 1 # just added this to try and update the attention mask.... # Update the input_ids in the inputs dictionary inputs["input_ids"][idx] = input_ids inputs["labels"] = labels return inputs