# PlantGFM-Gene-generation/plantgfm/modeling_segmentgfm.py
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from transformers import PreTrainedModel, PretrainedConfig
from .modeling_plantgfm import PlantGFMForCausalLM


class SegmentGFMConfig(PretrainedConfig):
    model_type = "segmentglm"

    def __init__(
        self,
        pre_trained_path=None,
        unet_embd_dim=[1024, 1536, 2560, 4096],
        unet_kernel_size=3,
        unet_dilation=[6, 12, 24],
        unet_padding=[6, 12, 24],
        unet_layer_dropout=0.25,
        out_embd_dim=256,
        out_k=1,
        **kwargs,
    ):
self.pre_trained_path = pre_trained_path
self.unet_embd_dim = unet_embd_dim
self.unet_kernel_size = unet_kernel_size
self.unet_dilation = unet_dilation
self.unet_padding = unet_padding
self.unet_layer_dropout = unet_layer_dropout
self.out_embd_dim = out_embd_dim
self.out_k = out_k
super().__init__(**kwargs)

    @classmethod
    def from_original_config(cls, config_path, **kwargs):
        with open(config_path, "r") as f:
            config = json.load(f)
        keys = [
            "pre_trained_path",
            "unet_embd_dim",
            "unet_kernel_size",
            "unet_dilation",
            "unet_padding",
            "unet_layer_dropout",
            "out_embd_dim",
            "out_k",
        ]
        return cls(**{key: config[key] for key in keys}, **kwargs)
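
# Example (sketch): building the config by hand. "path/to/plantgfm" is a
# placeholder for a local or Hub checkpoint of the pretrained PlantGFM backbone,
# not a path shipped with this file.
#
#   config = SegmentGFMConfig(pre_trained_path="path/to/plantgfm")
#   config.save_pretrained("./segmentgfm_config")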


class PlantGFMEmbd(nn.Module):
    """Wraps a pretrained PlantGFM backbone and returns channels-first token embeddings."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Load the pretrained PlantGFM causal LM and keep only its decoder stack.
        glm_model = PlantGFMForCausalLM.from_pretrained(self.config.pre_trained_path)
        self.glm_decoder = glm_model.get_decoder()

    def forward(self, input_ids):
        # (bsz, L, hidden) hidden states from the backbone
        embd = self.glm_decoder(input_ids, return_dict=True)["last_hidden_state"]
        # Drop the first and last positions and move channels first: (bsz, hidden, L - 2)
        embd = embd[:, 1:-1, :].transpose(1, 2)
        return embd


class DilatedConvLayer(nn.Module):
    """Two stacked dilated Conv1d layers followed by a SiLU activation and Dropout1d."""
def __init__(self, in_channels, out_channels, kernel_size, padding, dilation, dropout_rate):
super().__init__()
self.dilated_conv = nn.Sequential(
nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
padding=padding,
dilation=dilation,
),
nn.Conv1d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
padding=padding,
dilation=dilation,
),
nn.SiLU(),
nn.Dropout1d(p=dropout_rate),
)
def forward(self, x: torch.Tensor):
return self.dilated_conv(x)
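

# Length-preservation note (sketch): with kernel_size=3 and padding equal to the
# dilation (the defaults wired up by SegmentGFMConfig), each Conv1d in
# DilatedConvLayer keeps the sequence length, since
# L_out = L_in + 2*padding - dilation*(kernel_size - 1) at stride 1. For example:
#
#   layer = DilatedConvLayer(16, 32, kernel_size=3, padding=6, dilation=6, dropout_rate=0.25)
#   layer(torch.randn(2, 16, 100)).shape  # -> torch.Size([2, 32, 100])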


class DilatedUNetHead(nn.Module):
    """1D U-Net head: a dilated-convolution encoder/decoder with skip connections."""
def __init__(
self,
embd_dim,
kernel_size,
padding,
dilation,
layer_dropout=0.25,
out_embd_dim=256,
out_k=1,
):
super().__init__()
self.out_k = out_k
        # Encoder: dilated conv blocks; the first two are followed by 2x average pooling in forward
        self.down_conv1 = DilatedConvLayer(embd_dim[0], embd_dim[1], kernel_size, padding[0], dilation[0], layer_dropout)
        self.down_conv2 = DilatedConvLayer(embd_dim[1], embd_dim[2], kernel_size, padding[1], dilation[1], layer_dropout)
        self.down_conv3 = DilatedConvLayer(embd_dim[2], embd_dim[3], kernel_size, padding[2], dilation[2], layer_dropout)
        # Decoder: grouped transposed convs upsample by 2x, then conv blocks fuse the skip connections
        self.up_trans1 = nn.ConvTranspose1d(embd_dim[3], embd_dim[2], kernel_size=2, stride=2, groups=64)
        self.up_trans2 = nn.ConvTranspose1d(embd_dim[2], embd_dim[1], kernel_size=2, stride=2, groups=64)
        self.up_conv1 = DilatedConvLayer(2 * embd_dim[2], embd_dim[2], kernel_size, padding[1], dilation[1], layer_dropout)
        self.up_conv2 = DilatedConvLayer(2 * embd_dim[1], out_embd_dim, kernel_size, padding[0], dilation[0], layer_dropout)
        # 1x1 conv maps each position to out_k output channels
        self.output = nn.Conv1d(in_channels=out_embd_dim, out_channels=out_k, kernel_size=1, padding=0)
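
    # Shape sketch for forward, assuming the default dims [1024, 1536, 2560, 4096],
    # out_embd_dim=256, and input x of shape (bsz, 1024, L) with L divisible by 4:
    #   down_conv1 -> (bsz, 1536, L);    avg_pool -> (bsz, 1536, L/2)
    #   down_conv2 -> (bsz, 2560, L/2);  avg_pool -> (bsz, 2560, L/4)
    #   down_conv3 -> (bsz, 4096, L/4)
    #   up_trans1  -> (bsz, 2560, L/2);  cat skip -> (bsz, 5120, L/2); up_conv1 -> (bsz, 2560, L/2)
    #   up_trans2  -> (bsz, 1536, L);    cat skip -> (bsz, 3072, L);   up_conv2 -> (bsz, 256, L)
    #   output (1x1 conv) -> (bsz, out_k, L)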
    def forward(self, x: torch.Tensor):
        # Encoder: keep the pre-pooling activations as skip connections
        x = self.down_conv1(x)
        skip1 = x
        x = F.avg_pool1d(x, kernel_size=2, stride=2)
        x = self.down_conv2(x)
        skip2 = x
        x = F.avg_pool1d(x, kernel_size=2, stride=2)
        x = self.down_conv3(x)
        # Decoder: upsample and concatenate the matching skip connection
        x = self.up_trans1(x)
        x = torch.cat([x, skip2], 1)
        x = self.up_conv1(x)
        x = self.up_trans2(x)
        x = torch.cat([x, skip1], 1)
        x = self.up_conv2(x)
        x = self.output(x)
        if self.out_k == 1:
            return x.squeeze(1)  # (bsz, L) when out_k == 1
        return x  # (bsz, out_k, L)


class IoULoss(nn.Module):
    """Soft (differentiable) IoU loss computed on sigmoid probabilities."""

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth
def forward(self, inputs, targets):
inputs = torch.sigmoid(inputs)
inputs = inputs.view(inputs.size(0), -1) # (batch_size, *)
targets = targets.view(targets.size(0), -1) # (batch_size, *)
intersection = (inputs * targets).sum(dim=1)
total = (inputs + targets).sum(dim=1)
union = total - intersection
iou = (intersection + self.smooth) / (union + self.smooth)
return 1 - iou.mean()


class CombinedLoss(nn.Module):
    """Weighted sum of BCE-with-logits and soft-IoU losses over per-position logits."""

    def __init__(self, smooth=1e-6, bce_weight=0.5, iou_weight=0.5):
        super().__init__()
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.iou_loss = IoULoss(smooth=smooth)
        self.bce_weight = bce_weight
        self.iou_weight = iou_weight
def forward(self, inputs, targets):
bce = self.bce_loss(inputs, targets)
iou = self.iou_loss(inputs, targets)
combined_loss = self.bce_weight * bce + self.iou_weight * iou
return combined_loss
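
# Example (sketch): CombinedLoss expects raw logits and float targets of the same
# shape; sigmoid is applied internally by both terms. Shapes below are illustrative.
#
#   loss_fn = CombinedLoss(bce_weight=0.5, iou_weight=0.5)
#   logits = torch.randn(2, 100)                       # (bsz, L) head outputs
#   targets = torch.randint(0, 2, (2, 100)).float()    # per-position 0/1 labels
#   loss = loss_fn(logits, targets)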


class SegmentGFMModel(PreTrainedModel):
    """PlantGFM backbone plus dilated U-Net head for per-position sequence segmentation."""

    config_class = SegmentGFMConfig
    _no_split_modules = ["DilatedUNetHead"]
    supports_gradient_checkpointing = True
def __init__(self, config):
super().__init__(config)
self.config = config
self.glm_embd = PlantGFMEmbd(config=config)
self.unet_head = DilatedUNetHead(
self.config.unet_embd_dim,
self.config.unet_kernel_size,
self.config.unet_padding,
self.config.unet_dilation,
self.config.unet_layer_dropout,
self.config.out_embd_dim,
self.config.out_k
)
self.loss_funct = CombinedLoss(bce_weight=0.5, iou_weight=0.5)

    def forward(self, input_ids: torch.LongTensor = None, labels: Optional[torch.FloatTensor] = None):
        x = self.glm_embd(input_ids)  # (bsz, hidden, L - 2) embeddings from the backbone
        x = self.unet_head(x)         # (bsz, L - 2) logits, or (bsz, out_k, L - 2) when out_k > 1
        if labels is None:
            return x
        return {
            "loss": self.loss_funct(x, labels),
            "predictions": x,
        }
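

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumptions): "path/to/plantgfm" is a placeholder
    # for a real pretrained PlantGFM checkpoint whose hidden size matches
    # config.unet_embd_dim[0], and the random ids below stand in for ids from the
    # PlantGFM tokenizer (the real vocabulary size may differ). After trimming the
    # first and last positions, the embedded length (here 512) must be divisible
    # by 4 so that the U-Net skip connections line up.
    config = SegmentGFMConfig(pre_trained_path="path/to/plantgfm")
    model = SegmentGFMModel(config)
    input_ids = torch.randint(0, 8, (1, 514))         # placeholder token ids
    labels = torch.zeros(1, 512, dtype=torch.float)   # per-position 0/1 targets
    out = model(input_ids=input_ids, labels=labels)
    print(out["loss"].item(), out["predictions"].shape)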