|
import pytorch_lightning as pl |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import DataLoader |
|
import pandas as pd |
|
from tqdm import tqdm |
|
import pickle |
|
import torch |
|
import esm |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import random |
|
import io |
|
|
|
from transformers import EsmModel, EsmTokenizer, EsmConfig, AutoTokenizer |
|
from sklearn.metrics import roc_auc_score |
|
|
|
|
|
class ProteinMLPOneHot(pl.LightningModule):
    """Small MLP regressor trained with mean-squared-error loss.

    With the default arguments the network is
    ``Linear(20, 8) -> ReLU -> LayerNorm -> Dropout(0.2) ->
    Linear(8, 4) -> ReLU -> LayerNorm -> Dropout(0.2) -> Linear(4, 1)``,
    which is exactly the original hard-coded architecture (the layer
    indices inside ``self.network`` are unchanged, so existing
    ``state_dict`` keys still match).

    Each batch is expected to be a mapping with the keys
    ``'Protein Input'`` (features) and ``'Dimension'`` (regression target).
    NOTE(review): the 20-wide input presumably corresponds to a one-hot /
    composition encoding over the 20 amino acids -- confirm against the
    dataset code.

    Args:
        input_dim: Size of the last dimension of the input features.
        hidden_dims: Width of each hidden layer, in order.
        dropout: Dropout probability applied after every hidden layer.
        lr: Learning rate passed to Adam.
    """

    def __init__(self, input_dim=20, hidden_dims=(8, 4), dropout=0.2, lr=0.0003):
        super().__init__()
        self.lr = lr

        # Build Linear -> ReLU -> LayerNorm -> Dropout for every hidden width,
        # then a final single-output projection.
        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers.extend(
                [
                    nn.Linear(in_features, width),
                    nn.ReLU(),
                    nn.LayerNorm(width),
                    nn.Dropout(dropout),
                ]
            )
            in_features = width
        layers.append(nn.Linear(in_features, 1))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for ``x`` (last dimension of size 1)."""
        return self.network(x)

    def _shared_step(self, batch, metric_name):
        """Compute the MSE loss for ``batch`` and log it under ``metric_name``."""
        features = batch['Protein Input']
        target = batch['Dimension'].float()
        # Drop the trailing size-1 output dimension so it lines up with target.
        prediction = self(features).squeeze(-1)
        loss = F.mse_loss(prediction, target)
        self.log(metric_name, loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch; logged as ``train_loss``."""
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch; logged as ``val_loss``."""
        return self._shared_step(batch, 'val_loss')

    def test_step(self, batch, batch_idx):
        """Return the test loss for one batch; logged as ``test_loss``."""
        return self._shared_step(batch, 'test_loss')

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ProteinMLPESM(pl.LightningModule):
    """MLP regressor over 1280-dimensional inputs, trained with MSE loss.

    With the default arguments the network is
    ``Linear(1280, 640) -> ReLU -> LayerNorm -> Dropout(0.2) ->
    Linear(640, 320) -> ReLU -> LayerNorm -> Dropout(0.2) -> Linear(320, 1)``,
    which is exactly the original hard-coded architecture (the layer
    indices inside ``self.network`` are unchanged, so existing
    ``state_dict`` keys still match).

    Each batch is expected to be a mapping with the keys
    ``'Protein Input'`` (features) and ``'Dimension'`` (regression target).
    NOTE(review): the 1280-wide input presumably is a pooled ESM embedding
    -- confirm against the dataset code.

    Args:
        input_dim: Size of the last dimension of the input features.
        hidden_dims: Width of each hidden layer, in order.
        dropout: Dropout probability applied after every hidden layer.
        lr: Learning rate passed to Adam.
    """

    def __init__(self, input_dim=1280, hidden_dims=(640, 320), dropout=0.2, lr=0.0003):
        super().__init__()
        self.lr = lr

        # Build Linear -> ReLU -> LayerNorm -> Dropout for every hidden width,
        # then a final single-output projection.
        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers.extend(
                [
                    nn.Linear(in_features, width),
                    nn.ReLU(),
                    nn.LayerNorm(width),
                    nn.Dropout(dropout),
                ]
            )
            in_features = width
        layers.append(nn.Linear(in_features, 1))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for ``x`` (last dimension of size 1)."""
        return self.network(x)

    def _shared_step(self, batch, metric_name):
        """Compute the MSE loss for ``batch`` and log it under ``metric_name``."""
        features = batch['Protein Input']
        target = batch['Dimension'].float()
        # Drop the trailing size-1 output dimension so it lines up with target.
        prediction = self(features).squeeze(-1)
        loss = F.mse_loss(prediction, target)
        self.log(metric_name, loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch; logged as ``train_loss``."""
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch; logged as ``val_loss``."""
        return self._shared_step(batch, 'val_loss')

    def test_step(self, batch, batch_idx):
        """Return the test loss for one batch; logged as ``test_loss``."""
        return self._shared_step(batch, 'test_loss')

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LossTrackerCallback(pl.Callback):
    """Record the ``train_loss`` and ``val_loss`` metrics once per epoch.

    Values are read from ``trainer.callback_metrics`` and appended, as
    Python scalars, to ``train_losses`` and ``val_losses`` respectively.

    NOTE(review): Lightning may also invoke the validation hook during its
    pre-training sanity check, in which case ``val_losses`` can hold one
    more entry than ``train_losses`` -- confirm whether consumers of these
    lists account for that.
    """

    def __init__(self):
        super().__init__()
        self.train_losses = []  # one entry per finished training epoch
        self.val_losses = []  # one entry per finished validation epoch

    @staticmethod
    def _record(trainer, metric_name, history):
        """Append the current value of ``metric_name`` to ``history`` if logged."""
        value = trainer.callback_metrics.get(metric_name)
        # Compare against None explicitly: a metric tensor equal to 0.0 is
        # falsy, so a plain truthiness check would silently drop a zero loss.
        if value is not None:
            history.append(value.item())

    def on_train_epoch_end(self, trainer, pl_module):
        """Store the ``train_loss`` metric available at the end of the epoch."""
        self._record(trainer, 'train_loss', self.train_losses)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Store the ``val_loss`` metric available at the end of validation."""
        self._record(trainer, 'val_loss', self.val_losses)
|
|