# DataModule implementation

import torch
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule


# torch Dataset: matches each IDR sequence to its precomputed embedding
class IDRProtDataset(torch.utils.data.Dataset):
    def __init__(self, combined_embeddings, data, idr_property):
        super().__init__()
        self.combined_embeddings = combined_embeddings  # dict mapping sequence -> embedding array
        self.data = data                                # DataFrame with 'Sequence' and 'Value' columns
        self.idr_property = idr_property                # one of 'asph', 'scaled_re', 'scaled_rg', 'scaling_exp'

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # look up the sequence and its target value, then fetch the matching embedding
        prot, dimension = self.data.loc[index][['Sequence', 'Value']]
        prot_embedding = self.combined_embeddings[prot]
        return_dict = {
            "Protein": prot,
            "Dimension": torch.tensor(dimension),
            "Protein Input": torch.tensor(prot_embedding.astype(float), dtype=torch.float32),
        }
        return return_dict
class IDRDataModule(LightningDataModule):
    def __init__(self, train_df, val_df, test_df, combined_embeddings, idr_property, batch_size=64):
        super().__init__()
        self.batch_size = batch_size
        self.train_df = train_df
        self.val_df = val_df
        self.test_df = test_df
        self.combined_embeddings = combined_embeddings
        self.idr_property = idr_property

    def prepare_data(self):
        # splits are provided up front, so just stash the DataFrames
        self.train_dataset = self.train_df
        self.val_dataset = self.val_df
        self.test_dataset = self.test_df

    def setup(self, stage=None):
        # wrap each split in an IDRProtDataset (reset_index so .loc[index] runs 0..N-1)
        self.train_dataset = IDRProtDataset(self.combined_embeddings, self.train_dataset.reset_index(), self.idr_property)
        self.val_dataset = IDRProtDataset(self.combined_embeddings, self.val_dataset.reset_index(), self.idr_property)
        self.test_dataset = IDRProtDataset(self.combined_embeddings, self.test_dataset.reset_index(), self.idr_property)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=15)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=15)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, drop_last=True, num_workers=15)
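

# Example usage: a minimal sketch, not part of the original module. It assumes
# the split DataFrames carry 'Sequence' and 'Value' columns and that
# combined_embeddings maps each sequence string to a NumPy array; the toy data
# below is purely illustrative.
if __name__ == "__main__":  # guard needed because the dataloaders use worker processes
    import numpy as np
    import pandas as pd

    sequences = ["MKVD", "GSSA", "PLEQ"]                       # hypothetical IDR sequences
    embeddings = {s: np.random.rand(320) for s in sequences}   # hypothetical per-protein embeddings
    df = pd.DataFrame({"Sequence": sequences, "Value": [1.2, 0.8, 1.5]})

    dm = IDRDataModule(
        train_df=df, val_df=df, test_df=df,
        combined_embeddings=embeddings,
        idr_property="scaled_rg",
        batch_size=2,
    )
    dm.prepare_data()
    dm.setup()

    batch = next(iter(dm.train_dataloader()))
    # "Protein Input" is (batch_size, embedding_dim); "Dimension" is (batch_size,)
    print(batch["Protein Input"].shape, batch["Dimension"].shape)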