File size: 2,357 Bytes
e048d40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#datamodule implementation
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import torch
from torch.utils.data import DataLoader, random_split
from pytorch_lightning import LightningDataModule

#torch dataset, matching embeddings to sequence

class IDRProtDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing a protein sequence with its precomputed
    embedding and a scalar target value.

    Args:
        combined_embeddings: mapping of sequence string -> numpy embedding
            array (anything with ``.astype``).
        data: DataFrame with at least 'Sequence' and 'Value' columns,
            positionally indexed 0..len-1 (callers pass a ``reset_index()``
            frame, see IDRDataModule.setup).
        idr_property: name of the property held in 'Value' — one of
            'asph', 'scaled_re', 'scaled_rg', 'scaling_exp'. Stored for
            reference only; ``__getitem__`` always reads the 'Value' column.
    """

    def __init__(self, combined_embeddings, data, idr_property):
        super().__init__()
        self.combined_embeddings = combined_embeddings
        self.data = data
        self.idr_property = idr_property

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return a dict with the raw sequence, its target value as a
        tensor, and the embedding as a float32 tensor."""
        # Single .loc call with (row, columns) instead of the original
        # chained indexing .loc[index][['Sequence','Value']].
        prot, dimension = self.data.loc[index, ['Sequence', 'Value']]
        prot_embedding = self.combined_embeddings[prot]

        return {
            "Protein": prot,
            "Dimension": torch.tensor(dimension),
            # astype(float) normalizes the embedding dtype before the
            # float32 tensor conversion.
            "Protein Input": torch.tensor(prot_embedding.astype(float), dtype=torch.float32),
        }


class IDRDataModule(LightningDataModule):
    """LightningDataModule serving train/val/test DataFrames as
    IDRProtDataset-backed DataLoaders.

    Args:
        train_df, val_df, test_df: DataFrames with 'Sequence' and 'Value'
            columns.
        combined_embeddings: mapping of sequence -> embedding array,
            forwarded to IDRProtDataset.
        idr_property: target property name, forwarded to IDRProtDataset.
        batch_size: loader batch size (default 64).
        num_workers: DataLoader worker count (default 15, the previously
            hard-coded value).
    """

    def __init__(self, train_df, val_df, test_df, combined_embeddings,
                 idr_property, batch_size=64, num_workers=15):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_df = train_df
        self.val_df = val_df
        self.test_df = test_df
        self.combined_embeddings = combined_embeddings
        self.idr_property = idr_property

    def prepare_data(self):
        # Nothing to download or write to disk. Deliberately a no-op:
        # Lightning runs prepare_data on a single process only, so state
        # assigned here (as the original did) is not replicated to other
        # ranks under distributed training.
        pass

    def setup(self, stage=None):
        # Build datasets from the stored DataFrames every call. Reading
        # self.*_df (rather than a previously-wrapped dataset attribute)
        # keeps setup() idempotent — the original re-wrapped
        # self.*_dataset, so a second setup() call would crash.
        self.train_dataset = IDRProtDataset(
            self.combined_embeddings, self.train_df.reset_index(), self.idr_property)
        self.val_dataset = IDRProtDataset(
            self.combined_embeddings, self.val_df.reset_index(), self.idr_property)
        self.test_dataset = IDRProtDataset(
            self.combined_embeddings, self.test_df.reset_index(), self.idr_property)

    def train_dataloader(self):
        # drop_last=True avoids a ragged final batch during training.
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, drop_last=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        # NOTE(review): drop_last=True here silently discards up to
        # batch_size-1 test samples from evaluation — confirm this is
        # intentional. Behavior preserved from the original.
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          shuffle=False, drop_last=True, num_workers=self.num_workers)