# DataModule implementation: pairs precomputed sequence embeddings with IDR property values.
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningDataModule


# Torch dataset matching each sequence to its embedding and target value.
class IDRProtDataset(Dataset):
    def __init__(self, combined_embeddings, data, idr_property):
        super().__init__()
        self.combined_embeddings = combined_embeddings  # dict: sequence string -> embedding array
        self.data = data                                # DataFrame with 'Sequence' and 'Value' columns
        self.idr_property = idr_property                # one of 'asph', 'scaled_re', 'scaled_rg', 'scaling_exp'

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        prot, dimension = self.data.loc[index, ['Sequence', 'Value']]
        prot_embedding = self.combined_embeddings[prot]
        return {
            "Protein": prot,
            "Dimension": torch.tensor(dimension, dtype=torch.float32),
            "Protein Input": torch.tensor(prot_embedding.astype(float), dtype=torch.float32),
        }


class IDRDataModule(LightningDataModule):
    def __init__(self, train_df, val_df, test_df, combined_embeddings, idr_property, batch_size=64):
        super().__init__()
        self.batch_size = batch_size
        self.train_df = train_df
        self.val_df = val_df
        self.test_df = test_df
        self.combined_embeddings = combined_embeddings
        self.idr_property = idr_property

    def setup(self, stage=None):
        # Build one dataset per split; reset_index so positional .loc lookups in __getitem__ work.
        self.train_dataset = IDRProtDataset(self.combined_embeddings, self.train_df.reset_index(drop=True), self.idr_property)
        self.val_dataset = IDRProtDataset(self.combined_embeddings, self.val_df.reset_index(drop=True), self.idr_property)
        self.test_dataset = IDRProtDataset(self.combined_embeddings, self.test_df.reset_index(drop=True), self.idr_property)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=15)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=15)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, drop_last=True, num_workers=15)
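

# --- Usage sketch ---
# A minimal, self-contained smoke test of the DataModule. The toy sequences,
# embedding size (32), and property name below are illustrative assumptions,
# not values from the original project.
if __name__ == "__main__":
    import numpy as np
    import pandas as pd

    # Hypothetical data: three sequences, each with a 32-dimensional embedding.
    seqs = ["MKTAYIAK", "GSSGSSGS", "PLDNEYFQ"]
    combined_embeddings = {s: np.random.rand(32) for s in seqs}
    df = pd.DataFrame({"Sequence": seqs * 4, "Value": np.random.rand(12)})

    dm = IDRDataModule(train_df=df, val_df=df, test_df=df,
                       combined_embeddings=combined_embeddings,
                       idr_property="scaled_rg", batch_size=4)
    dm.setup()

    # Pull one training batch and check the tensor shapes.
    batch = next(iter(dm.train_dataloader()))
    print(batch["Protein Input"].shape, batch["Dimension"].shape)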