import os
import pickle

import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

from fuson_plm.utils.logging import log_update


# Dataset class that loads embeddings and labels.
# TODO: support either using a cached location for embeddings, or making them on the spot.
def custom_collate_fn(batch):
    """
    Custom collate function to handle batches with strings and tensors.

    Args:
        batch (list): List of tuples returned by __getitem__.

    Returns:
        tuple: (sequences, embeddings, labels)
            - sequences: List of strings
            - embeddings: Tensor of shape (batch_size, embedding_dim)
            - labels: Tensor of shape (batch_size, sequence_length)
    """
    sequences, embeddings, labels = zip(*batch)  # Unzip the batch into separate tuples

    # Stack embeddings and labels into tensors
    embeddings = torch.stack(embeddings, dim=0)  # Shape: (batch_size, embedding_dim)
    labels = torch.stack(labels, dim=0)          # Shape: (batch_size, sequence_length)

    # Convert sequences from tuple to list
    sequences = list(sequences)

    return sequences, embeddings, labels


class DisorderDataset(Dataset):
    def __init__(self, csv_file_path, cached_embeddings_path=None, max_length=4405):
        super().__init__()
        # Read Sequence and Label as strings so that all-digit labels
        # (e.g. "010011") are not parsed as integers
        self.dataset = pd.read_csv(csv_file_path, dtype={'Sequence': str, 'Label': str})
        self.cached_embeddings_path = cached_embeddings_path
        self.max_length = max_length

        # Initialize embeddings from the cache
        self.embeddings = self._retrieve_embeddings()

    def __len__(self):
        return len(self.dataset)

    def _retrieve_embeddings(self):
        try:
            with open(self.cached_embeddings_path, "rb") as f:
                # Load all embeddings
                embeddings = pickle.load(f)
        except Exception as e:
            raise RuntimeError(
                f"Failed to load embeddings from {self.cached_embeddings_path}"
            ) from e

        # Keep only embeddings for the sequences in self.dataset
        seqs = set(self.dataset['Sequence'].tolist())
        embeddings = {k: v for k, v in embeddings.items() if k in seqs}
        return embeddings

    def __getitem__(self, idx):
        sequence = self.dataset.iloc[idx]['Sequence']
        embedding = self.embeddings[sequence]
        embedding = torch.tensor(embedding, dtype=torch.float32)

        # Labels are stored as a string of per-residue 0/1 characters,
        # e.g. "0110" for a 4-residue sequence
        label_str = self.dataset.iloc[idx]['Label']
        labels = list(map(int, label_str))
        labels = torch.tensor(labels, dtype=torch.float)
        assert len(labels) == len(sequence)

        return sequence, embedding, labels


def get_dataloader(data_path, cached_embeddings_path, max_length=4405, batch_size=1, shuffle=True):
    """
    Creates a DataLoader for the dataset.

    Args:
        data_path (str): Path to the CSV file (train, val, or test).
        cached_embeddings_path (str): Path to the pickled sequence-to-embedding cache.
        max_length (int): Maximum allowed sequence length.
        batch_size (int): Batch size.
        shuffle (bool): Whether to shuffle the data.

    Returns:
        DataLoader: DataLoader object.
    """
    dataset = DisorderDataset(data_path, cached_embeddings_path=cached_embeddings_path, max_length=max_length)
    return DataLoader(dataset, batch_size=batch_size, collate_fn=custom_collate_fn, shuffle=shuffle)
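
# Hypothetical helper (a sketch, not part of the original module) illustrating
# the input formats the dataset above assumes: a CSV with 'Sequence' and
# 'Label' columns, where each Label is a string of per-residue 0/1 characters
# exactly as long as its sequence, and a pickle mapping each sequence string
# to its embedding. The file names and the embedding dimension (1280) are
# arbitrary placeholders.
def make_toy_inputs(csv_path="toy_disorder.csv", pkl_path="toy_embeddings.pkl"):
    df = pd.DataFrame({
        "Sequence": ["MKVLAT", "AGGSWE"],
        "Label": ["010011", "110000"],  # one 0/1 character per residue
    })
    df.to_csv(csv_path, index=False)

    # Per-sequence embedding vectors, matching the (batch_size, embedding_dim)
    # shape documented in custom_collate_fn
    embeddings = {seq: torch.rand(1280).tolist() for seq in df["Sequence"]}
    with open(pkl_path, "wb") as f:
        pickle.dump(embeddings, f)

    return csv_path, pkl_path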
def check_dataloaders(train_loader, test_loader, max_length=512, checkpoint_dir=''):
    log_update('\nBuilt train and test dataloaders')
    log_update(f"\tNumber of sequences in the Training DataLoader: {len(train_loader.dataset)}")
    log_update(f"\tNumber of sequences in the Testing DataLoader: {len(test_loader.dataset)}")

    dataloader_overlaps = check_dataloader_overlap(train_loader, test_loader)
    if len(dataloader_overlaps) == 0:
        log_update("\tDataloaders are clean (no overlaps)")
    else:
        log_update(f"\tWARNING! sequence overlap found: {','.join(dataloader_overlaps)}")

    # Make a directory for batch diversity stats (e.g., length ranges) if it doesn't exist
    os.makedirs(f'{checkpoint_dir}/batch_diversity', exist_ok=True)

    max_length_violators = []
    for name, dataloader in {'train': train_loader, 'test': test_loader}.items():
        max_length_followed = check_max_length(dataloader, max_length)
        if not max_length_followed:
            max_length_violators.append(name)
    if len(max_length_violators) == 0:
        log_update(f"\tDataloaders follow the max length limit set by user: {max_length}")
    else:
        log_update(f"\tWARNING! these loaders have sequences longer than max length={max_length}: {','.join(max_length_violators)}")


def check_dataloader_overlap(train_loader, test_loader):
    train_seqs = set()
    test_seqs = set()
    # Collect every sequence in each batch (not just the first), so the
    # check is also correct for batch sizes greater than 1
    for sequences, _, _ in train_loader:
        train_seqs.update(sequences)
    for sequences, _, _ in test_loader:
        test_seqs.update(sequences)
    return train_seqs.intersection(test_seqs)


def check_max_length(dataloader, max_length):
    for sequences, _, _ in dataloader:
        # Check every sequence in the batch, not just the first
        if any(len(seq) > max_length for seq in sequences):
            return False
    return True
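
# Example usage (a minimal sketch under the assumptions of the hypothetical
# make_toy_inputs helper above): build toy inputs, construct train/test
# loaders, and run the sanity checks. Reusing the same CSV for both loaders
# deliberately triggers the overlap warning, demonstrating the check.
if __name__ == "__main__":
    csv_path, pkl_path = make_toy_inputs()
    train_loader = get_dataloader(csv_path, cached_embeddings_path=pkl_path, batch_size=1, shuffle=True)
    test_loader = get_dataloader(csv_path, cached_embeddings_path=pkl_path, batch_size=1, shuffle=False)
    check_dataloaders(train_loader, test_loader, max_length=4405, checkpoint_dir='.')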