|
|
|
import pytorch_lightning as pl |
|
from torch.utils.data import DataLoader |
|
import torch |
|
from torch.utils.data import DataLoader, random_split |
|
from pytorch_lightning import LightningDataModule |
|
|
|
|
|
|
|
class IDRProtDataset(torch.utils.data.Dataset):
    """Dataset of (protein sequence, embedding, scalar target) examples.

    Each row of ``data`` supplies a 'Sequence' string used to look up an
    embedding in ``combined_embeddings`` and a scalar 'Value' target.
    """

    def __init__(self, combined_embeddings, data, idr_property):
        """
        Args:
            combined_embeddings: mapping from sequence string to a numeric
                array-like embedding (presumably a numpy array — confirm
                against the embedding producer).
            data: pandas DataFrame with 'Sequence' and 'Value' columns and a
                0..len-1 integer index (callers pass ``df.reset_index()``).
            idr_property: name of the predicted IDR property; stored for
                bookkeeping only — never read when building items.
        """
        super().__init__()
        self.combined_embeddings = combined_embeddings
        self.data = data
        self.idr_property = idr_property

    def __len__(self):
        """Number of examples (rows of the backing DataFrame)."""
        return len(self.data)

    def __getitem__(self, index):
        """Return one example as a dict.

        Keys:
            "Protein": raw sequence string.
            "Dimension": scalar tensor built from the row's 'Value'.
            "Protein Input": float32 tensor of the sequence embedding.
        """
        # Single .loc call instead of the original chained
        # df.loc[index][['Sequence','Value']] — one lookup, no intermediate
        # row copy.
        prot, dimension = self.data.loc[index, ['Sequence', 'Value']]
        prot_embedding = self.combined_embeddings[prot]

        return {
            "Protein": prot,
            "Dimension": torch.tensor(dimension),
            # as_tensor with an explicit dtype converts straight to float32,
            # skipping the float64 intermediate that .astype(float) created.
            "Protein Input": torch.as_tensor(prot_embedding, dtype=torch.float32),
        }
|
|
|
|
|
class IDRDataModule(LightningDataModule):
    """LightningDataModule exposing train/val/test DataFrames as IDRProtDataset splits."""

    def __init__(self, train_df, val_df, test_df, combined_embeddings,
                 idr_property, batch_size=64, num_workers=15):
        """
        Args:
            train_df / val_df / test_df: DataFrames with 'Sequence' and
                'Value' columns (see IDRProtDataset).
            combined_embeddings: mapping from sequence string to embedding.
            idr_property: target-property name, forwarded to the datasets.
            batch_size: samples per batch for all three loaders.
            num_workers: DataLoader worker processes; the default preserves
                the previously hard-coded value of 15.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_df = train_df
        self.val_df = val_df
        self.test_df = test_df
        self.combined_embeddings = combined_embeddings
        self.idr_property = idr_property

    def prepare_data(self):
        # Intentionally a no-op. Lightning runs prepare_data() on a single
        # process only, so instance state assigned here (as the previous
        # version did) would be missing on other ranks. All state
        # assignment happens in setup(), which runs on every process.
        pass

    def setup(self, stage=None):
        # Build each split from the original DataFrame every time, so that
        # repeated setup() calls (Lightning invokes it once per stage) are
        # idempotent. The previous version rebuilt from self.*_dataset and
        # crashed on the second call by calling .reset_index() on a Dataset.
        self.train_dataset = IDRProtDataset(self.combined_embeddings, self.train_df.reset_index(), self.idr_property)
        self.val_dataset = IDRProtDataset(self.combined_embeddings, self.val_df.reset_index(), self.idr_property)
        self.test_dataset = IDRProtDataset(self.combined_embeddings, self.test_df.reset_index(), self.idr_property)

    def train_dataloader(self):
        """Shuffled training loader; drops the last partial batch."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Deterministic validation loader (keeps the partial final batch)."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        """Deterministic test loader.

        NOTE(review): drop_last=True silently discards up to batch_size-1
        test samples; preserved for backward compatibility — confirm this
        is intended for evaluation.
        """
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, drop_last=True, num_workers=self.num_workers)