import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Optional
import numpy as np
from pathlib import Path
import math
# from deepspeed.ops.adam import FusedAdam  # disabled due to compatibility issues
class MusicAudioClassifier(pl.LightningModule):
def __init__(self,
input_dim: int,
hidden_dim: int = 256,
learning_rate: float = 1e-4,
emb_model: Optional[nn.Module] = None,
backbone: str = 'segment_transformer',
num_classes: int = 2):
        super().__init__()
        self.save_hyperparameters(ignore=['emb_model'])  # keep the nn.Module out of hparams
        # save_hyperparameters only exposes these via self.hparams, but they are read
        # directly as attributes elsewhere in the class, so set them explicitly.
        self.learning_rate = learning_rate
        self.num_classes = num_classes
        self.emb_model = emb_model  # kept for upstream embedding extraction; unused when inputs are precomputed embeddings
if backbone == 'segment_transformer':
self.model = SegmentTransformer(
input_dim=input_dim,
hidden_dim=hidden_dim,
num_classes=num_classes,
                mode='both'
)
elif backbone == 'fusion_segment_transformer':
self.model = FusionSegmentTransformer(
input_dim=input_dim,
hidden_dim=hidden_dim,
num_classes=num_classes
)
# elif backbone == 'guided_segment_transformer':
# self.model = GuidedSegmentTransformer(
# input_dim=input_dim,
# hidden_dim=hidden_dim,
# num_classes=num_classes
# )
    def _process_audio_batch(self, x: torch.Tensor) -> torch.Tensor:
        # Input: [B, S, C, M, T] or [B, S, C, T] for waveforms, [B, S, 1(?), emb_size] for precomputed embeddings
        B, S = x.shape[:2]
        x = x.view(B * S, *x.shape[2:])  # flatten segments: [B*S, ...]
        embeddings = x
        if embeddings.dim() == 3:
            pooled_features = embeddings.mean(dim=1)  # token-level embeddings: mean-pool over tokens
        else:
            pooled_features = embeddings  # e.g. CCV(?): already one vector per segment, no pooling needed
        return pooled_features.view(B, S, -1)  # [B, S, emb_dim]
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self._process_audio_batch(x)  # in practice the embedding model is frozen, so this effectively matches the upstream version
x = x.half()
return self.model(x, mask)
def _compute_loss_and_probs(self, y_hat: torch.Tensor, y: torch.Tensor):
"""Compute loss and probabilities based on number of classes"""
        if y_hat.size(0) == 1:
            # Batch of size 1: flatten explicitly so squeeze() cannot drop the batch dim
            # (for multiclass, keep the [1, C] logits shape that cross_entropy expects)
            y_hat_flat = y_hat.flatten() if self.num_classes == 2 else y_hat
            y_flat = y.flatten()
else:
y_hat_flat = y_hat.squeeze() if self.num_classes == 2 else y_hat
y_flat = y
if self.num_classes == 2:
loss = F.binary_cross_entropy_with_logits(y_hat_flat, y_flat.float())
probs = torch.sigmoid(y_hat_flat)
preds = (probs > 0.5).long()
else:
loss = F.cross_entropy(y_hat_flat, y_flat.long())
probs = F.softmax(y_hat_flat, dim=-1)
preds = torch.argmax(y_hat_flat, dim=-1)
return loss, probs, preds, y_flat.long()
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
x, y, mask = batch
x = x.half()
y_hat = self(x, mask)
loss, probs, preds, y_true = self._compute_loss_and_probs(y_hat, y)
        # Log only the batch loss at step level
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
# ์ „์ฒด ์—ํญ์— ๋Œ€ํ•œ ๋ฉ”ํŠธ๋ฆญ ๊ณ„์‚ฐ์„ ์œ„ํ•ด ์˜ˆ์ธก๊ณผ ์‹ค์ œ๊ฐ’ ์ €์žฅ
if self.num_classes == 2:
self.training_step_outputs.append({'preds': probs, 'targets': y_true, 'binary_preds': preds})
else:
self.training_step_outputs.append({'probs': probs, 'preds': preds, 'targets': y_true})
return loss
def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> None:
x, y, mask = batch
x = x.half()
y_hat = self(x, mask)
loss, probs, preds, y_true = self._compute_loss_and_probs(y_hat, y)
        # Log only the batch loss at step level
self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
# ์ „์ฒด ์—ํญ์— ๋Œ€ํ•œ ๋ฉ”ํŠธ๋ฆญ ๊ณ„์‚ฐ์„ ์œ„ํ•ด ์˜ˆ์ธก๊ณผ ์‹ค์ œ๊ฐ’ ์ €์žฅ
if self.num_classes == 2:
self.validation_step_outputs.append({'preds': probs, 'targets': y_true, 'binary_preds': preds})
else:
self.validation_step_outputs.append({'probs': probs, 'preds': preds, 'targets': y_true})
def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> None:
x, y, mask = batch
x = x.half()
y_hat = self(x, mask)
loss, probs, preds, y_true = self._compute_loss_and_probs(y_hat, y)
        # Log only the batch loss
self.log('test_loss', loss, on_epoch=True, prog_bar=True)
# ์ „์ฒด ์—ํญ์— ๋Œ€ํ•œ ๋ฉ”ํŠธ๋ฆญ ๊ณ„์‚ฐ์„ ์œ„ํ•ด ์˜ˆ์ธก๊ณผ ์‹ค์ œ๊ฐ’ ์ €์žฅ
if self.num_classes == 2:
self.test_step_outputs.append({'preds': probs, 'targets': y_true, 'binary_preds': preds})
else:
self.test_step_outputs.append({'probs': probs, 'preds': preds, 'targets': y_true})
    def on_train_epoch_start(self):
        # Reset the buffer that collects step outputs for this epoch
        self.training_step_outputs = []
    def on_validation_epoch_start(self):
        # Reset the buffer that collects step outputs for this epoch
        self.validation_step_outputs = []
    def on_test_epoch_start(self):
        # Reset the buffer that collects step outputs for this epoch
        self.test_step_outputs = []
def _compute_binary_metrics(self, outputs, prefix):
"""Binary classification metrics computation"""
all_preds = torch.cat([x['preds'] for x in outputs])
all_targets = torch.cat([x['targets'] for x in outputs])
binary_preds = torch.cat([x['binary_preds'] for x in outputs])
# ์ •ํ™•๋„ ๊ณ„์‚ฐ
acc = (binary_preds == all_targets).float().mean()
# ํ˜ผ๋™ ํ–‰๋ ฌ ์š”์†Œ ๊ณ„์‚ฐ
tp = torch.sum((binary_preds == 1) & (all_targets == 1)).float()
fp = torch.sum((binary_preds == 1) & (all_targets == 0)).float()
tn = torch.sum((binary_preds == 0) & (all_targets == 0)).float()
fn = torch.sum((binary_preds == 0) & (all_targets == 1)).float()
        # Derived metrics (guarding against division by zero)
precision = tp / (tp + fp) if (tp + fp) > 0 else torch.tensor(0.0).to(tp.device)
recall = tp / (tp + fn) if (tp + fn) > 0 else torch.tensor(0.0).to(tp.device)
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(0.0).to(tp.device)
specificity = tn / (tn + fp) if (tn + fp) > 0 else torch.tensor(0.0).to(tn.device)
        # Logging
self.log(f'{prefix}_acc', acc, on_epoch=True, prog_bar=True, sync_dist=True)
self.log(f'{prefix}_precision', precision, on_epoch=True, sync_dist=True)
self.log(f'{prefix}_recall', recall, on_epoch=True, sync_dist=True)
self.log(f'{prefix}_f1', f1, on_epoch=True, prog_bar=True, sync_dist=True)
self.log(f'{prefix}_specificity', specificity, on_epoch=True, sync_dist=True)
if prefix in ['val', 'test']:
            # ROC-AUC via a trapezoidal approximation: sort by score, then integrate TPR over FPR
sorted_indices = torch.argsort(all_preds, descending=True)
sorted_targets = all_targets[sorted_indices]
n_pos = torch.sum(all_targets)
n_neg = len(all_targets) - n_pos
if n_pos > 0 and n_neg > 0:
tpr_curve = torch.cumsum(sorted_targets, dim=0) / n_pos
fpr_curve = torch.cumsum(1 - sorted_targets, dim=0) / n_neg
width = fpr_curve[1:] - fpr_curve[:-1]
height = (tpr_curve[1:] + tpr_curve[:-1]) / 2
auc_approx = torch.sum(width * height)
self.log(f'{prefix}_auc', auc_approx, on_epoch=True, sync_dist=True)
if prefix == 'test':
balanced_acc = (recall + specificity) / 2
self.log('test_balanced_acc', balanced_acc, on_epoch=True)
def _compute_multiclass_metrics(self, outputs, prefix):
"""Multi-class classification metrics computation"""
all_probs = torch.cat([x['probs'] for x in outputs])
all_preds = torch.cat([x['preds'] for x in outputs])
all_targets = torch.cat([x['targets'] for x in outputs])
# ์ „์ฒด ์ •ํ™•๋„
acc = (all_preds == all_targets).float().mean()
self.log(f'{prefix}_acc', acc, on_epoch=True, prog_bar=True, sync_dist=True)
# ํด๋ž˜์Šค๋ณ„ ๋ฉ”ํŠธ๋ฆญ ๊ณ„์‚ฐ
for class_idx in range(self.num_classes):
# ๊ฐ ํด๋ž˜์Šค์— ๋Œ€ํ•œ ์ด์ง„ ๋ถ„๋ฅ˜ ๋ฉ”ํŠธ๋ฆญ
class_targets = (all_targets == class_idx).long()
class_preds = (all_preds == class_idx).long()
tp = torch.sum((class_preds == 1) & (class_targets == 1)).float()
fp = torch.sum((class_preds == 1) & (class_targets == 0)).float()
tn = torch.sum((class_preds == 0) & (class_targets == 0)).float()
fn = torch.sum((class_preds == 0) & (class_targets == 1)).float()
precision = tp / (tp + fp) if (tp + fp) > 0 else torch.tensor(0.0).to(tp.device)
recall = tp / (tp + fn) if (tp + fn) > 0 else torch.tensor(0.0).to(tp.device)
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(0.0).to(tp.device)
self.log(f'{prefix}_class_{class_idx}_precision', precision, on_epoch=True)
self.log(f'{prefix}_class_{class_idx}_recall', recall, on_epoch=True)
self.log(f'{prefix}_class_{class_idx}_f1', f1, on_epoch=True)
# ๋งคํฌ๋กœ ํ‰๊ท  F1 ์Šค์ฝ”์–ด
class_f1_scores = []
for class_idx in range(self.num_classes):
class_targets = (all_targets == class_idx).long()
class_preds = (all_preds == class_idx).long()
tp = torch.sum((class_preds == 1) & (class_targets == 1)).float()
fp = torch.sum((class_preds == 1) & (class_targets == 0)).float()
fn = torch.sum((class_preds == 0) & (class_targets == 1)).float()
precision = tp / (tp + fp) if (tp + fp) > 0 else torch.tensor(0.0).to(tp.device)
recall = tp / (tp + fn) if (tp + fn) > 0 else torch.tensor(0.0).to(tp.device)
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(0.0).to(tp.device)
class_f1_scores.append(f1)
macro_f1 = torch.stack(class_f1_scores).mean()
self.log(f'{prefix}_macro_f1', macro_f1, on_epoch=True, prog_bar=True, sync_dist=True)
def on_train_epoch_end(self):
if not hasattr(self, 'training_step_outputs') or not self.training_step_outputs:
return
if self.num_classes == 2:
self._compute_binary_metrics(self.training_step_outputs, 'train')
else:
self._compute_multiclass_metrics(self.training_step_outputs, 'train')
def on_validation_epoch_end(self):
if not hasattr(self, 'validation_step_outputs') or not self.validation_step_outputs:
return
if self.num_classes == 2:
self._compute_binary_metrics(self.validation_step_outputs, 'val')
else:
self._compute_multiclass_metrics(self.validation_step_outputs, 'val')
def on_test_epoch_end(self):
if not hasattr(self, 'test_step_outputs') or not self.test_step_outputs:
return
if self.num_classes == 2:
self._compute_binary_metrics(self.test_step_outputs, 'test')
else:
self._compute_multiclass_metrics(self.test_step_outputs, 'test')
def configure_optimizers(self):
        # Use plain AdamW instead of FusedAdam (avoids GLIBC compatibility issues)
optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.learning_rate,
weight_decay=0.01
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=100, # Adjust based on your training epochs
eta_min=1e-6
)
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_loss',  # only consulted by metric-driven schedulers such as ReduceLROnPlateau
        }
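# --- Usage sketch (illustrative; not part of the original pipeline) ---
# A minimal sketch of wiring the classifier to a Lightning Trainer. The
# embedding dimension, batch size, and trainer flags here are assumptions;
# precision='16-mixed' is Lightning 2.x syntax.
def _demo_training_run(train_dataset: Dataset, val_dataset: Dataset) -> None:
    model = MusicAudioClassifier(input_dim=768, hidden_dim=256, num_classes=2)
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
                              collate_fn=pad_sequence_with_mask)
    val_loader = DataLoader(val_dataset, batch_size=8,
                            collate_fn=pad_sequence_with_mask)
    trainer = pl.Trainer(max_epochs=10, precision='16-mixed')
    trainer.fit(model, train_loader, val_loader)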
def pad_sequence_with_mask(batch: List[Tuple[torch.Tensor, int]]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Collate function for DataLoader that creates padded sequences and attention masks with a fixed length (48)."""
    embeddings, labels = zip(*batch)
    fixed_len = 48  # fixed sequence length
    batch_size = len(embeddings)
    feat_dim = embeddings[0].shape[-1]
    padded = torch.zeros((batch_size, fixed_len, feat_dim))  # fixed-length padded tensor
    mask = torch.ones((batch_size, fixed_len), dtype=torch.bool)  # True marks padding positions
    for i, emb in enumerate(embeddings):
        length = emb.shape[0]
        # Truncate sequences longer than the fixed length; pad shorter ones
        if length > fixed_len:
            padded[i, :] = emb[:fixed_len]  # keep only the first fixed_len steps
            mask[i, :] = False
        else:
            padded[i, :length] = emb  # copy the real data
            mask[i, :length] = False  # non-padded positions are False
    return padded, torch.tensor(labels), mask
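# Sketch of what the collate function produces for a toy batch; the feature
# dimension (64) and sequence lengths here are illustrative assumptions.
def _demo_collate() -> None:
    batch = [(torch.randn(30, 64), 0), (torch.randn(60, 64), 1)]
    padded, labels, mask = pad_sequence_with_mask(batch)
    assert padded.shape == (2, 48, 64)                    # padded/truncated to 48 steps
    assert labels.tolist() == [0, 1]
    assert mask[0, 30:].all() and not mask[0, :30].any()  # True marks padding
    assert not mask[1].any()                              # truncated sequence has no padding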
class SegmentTransformer(nn.Module):
def __init__(self,
input_dim: int,
hidden_dim: int = 256,
num_heads: int = 8,
num_layers: int = 4,
dropout: float = 0.1,
max_sequence_length: int = 1000,
mode: str = 'both',
share_parameter: bool = False,
num_classes: int = 2):
super().__init__()
# Original sequence processing
self.input_projection = nn.Linear(input_dim, hidden_dim)
self.mode = mode
self.share_parameter = share_parameter
self.num_classes = num_classes
# Positional encoding
position = torch.arange(max_sequence_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, hidden_dim, 2) * (-np.log(10000.0) / hidden_dim))
pos_encoding = torch.zeros(max_sequence_length, hidden_dim)
pos_encoding[:, 0::2] = torch.sin(position * div_term)
pos_encoding[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pos_encoding', pos_encoding)
# Transformer for original sequence
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=num_heads,
dim_feedforward=hidden_dim * 4,
dropout=dropout,
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
self.sim_transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# Self-similarity stream processing
self.similarity_projection = nn.Sequential(
nn.Conv1d(1, hidden_dim // 2, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1),
nn.ReLU(),
nn.Dropout(dropout)
)
# Transformer for similarity stream
self.similarity_transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# Final classification head
self.classification_head_dim = hidden_dim * 2 if mode == 'both' else hidden_dim
# Output dimension based on number of classes
output_dim = 1 if num_classes == 2 else num_classes
self.classification_head = nn.Sequential(
nn.Linear(self.classification_head_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LayerNorm(hidden_dim // 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, output_dim)
)
def forward(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, seq_len, _ = x.shape
# 1. Process original sequence
x = x.half()
x1 = self.input_projection(x)
x1 = x1 + self.pos_encoding[:seq_len].unsqueeze(0)
        x1 = self.transformer(x1, src_key_padding_mask=padding_mask)  # honor the padding mask
# 2. Calculate and process self-similarity
x_expanded = x.unsqueeze(2)
x_transposed = x.unsqueeze(1)
distances = torch.mean((x_expanded - x_transposed) ** 2, dim=-1)
similarity_matrix = torch.exp(-distances) # (batch_size, seq_len, seq_len)
# ์ž๊ธฐ ์œ ์‚ฌ๋„ ๋งˆ์Šคํฌ ์ƒ์„ฑ ๋ฐ ์ ์šฉ (๊ฐ ์‹œ์ ์— ๋Œ€ํ•œ ๋งˆ์Šคํฌ ๊ฐœ๋ณ„ ์ ์šฉ)
if padding_mask is not None:
similarity_mask = padding_mask.unsqueeze(1) | padding_mask.unsqueeze(2) # (batch_size, seq_len, seq_len)
similarity_matrix = similarity_matrix.masked_fill(similarity_mask, 0.0)
# Process similarity matrix row by row using Conv1d
x2 = similarity_matrix.unsqueeze(1) # (batch_size, 1, seq_len, seq_len)
x2 = x2.view(batch_size * seq_len, 1, seq_len) # Reshape for Conv1d
x2 = self.similarity_projection(x2) # (batch_size * seq_len, hidden_dim, seq_len)
x2 = x2.mean(dim=2) # Pool across sequence dimension
x2 = x2.view(batch_size, seq_len, -1) # Reshape back
x2 = x2 + self.pos_encoding[:seq_len].unsqueeze(0)
if self.share_parameter:
x2 = self.transformer(x2, src_key_padding_mask=padding_mask)
else:
            x2 = self.sim_transformer(x2, src_key_padding_mask=padding_mask)  # honor the padding mask
# 3. Global average pooling for both streams
if padding_mask is not None:
mask_expanded = (~padding_mask).float().unsqueeze(-1)
x1 = (x1 * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)
x2 = (x2 * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)
else:
x1 = x1.mean(dim=1)
x2 = x2.mean(dim=1)
# 4. Combine both streams and classify
        if self.mode == 'only_emb':
            x = x1
        elif self.mode == 'only_structure':
            x = x2
        elif self.mode == 'both':
            x = torch.cat([x1, x2], dim=-1)
        else:
            raise ValueError(f"Unknown mode: {self.mode!r}")
        x = x.half()
        return self.classification_head(x)
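# Smoke-test sketch for SegmentTransformer. The forward pass casts activations
# to half precision, so this assumes a device/PyTorch build with fp16 support
# (typically a GPU); all dimensions here are illustrative assumptions.
def _demo_segment_transformer() -> None:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = SegmentTransformer(input_dim=64, hidden_dim=128).half().to(device)
    x = torch.randn(2, 48, 64, device=device)                   # [B, S, input_dim]
    mask = torch.zeros(2, 48, dtype=torch.bool, device=device)  # no padding
    logits = model(x, mask)
    assert logits.shape == (2, 1)  # binary head emits a single logit per clip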
class PairwiseGuidedTransformer(nn.Module):
"""Pairwise similarity matrix๋ฅผ ํ™œ์šฉํ•œ ๋ฒ”์šฉ transformer layer
Vision: patch๊ฐ„ ์œ ์‚ฌ๋„, NLP: token๊ฐ„ ์œ ์‚ฌ๋„, Audio: segment๊ฐ„ ์œ ์‚ฌ๋„ ๋“ฑ์— ํ™œ์šฉ ๊ฐ€๋Šฅ
"""
def __init__(self, d_model: int, num_heads: int = 8):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
# Standard Q, K projections
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
# Pairwise-guided V projection
self.v_proj = nn.Linear(d_model, d_model)
self.output_proj = nn.Linear(d_model, d_model)
self.norm = nn.LayerNorm(d_model)
def forward(self, x, pairwise_matrix, padding_mask=None):
"""
Args:
x: (batch, seq_len, d_model) - sequence embeddings
pairwise_matrix: (batch, seq_len, seq_len) - pairwise similarity/distance matrix
padding_mask: (batch, seq_len) - padding mask
"""
batch_size, seq_len, d_model = x.shape
# Standard Q, K, V
Q = self.q_proj(x)
K = self.k_proj(x)
V = self.v_proj(x)
# Reshape for multi-head
Q = Q.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
K = K.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        # Standard scaled dot-product attention; scale by the per-head dimension
        head_dim = d_model // self.num_heads
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (head_dim ** 0.5)
        # โœ… Combining the pairwise matrix into the scores was tried and then dropped:
        # pairwise_expanded = pairwise_matrix.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        enhanced_scores = scores  # + pairwise_expanded (intentionally left out)
# Apply padding mask
if padding_mask is not None:
mask_4d = padding_mask.unsqueeze(1).unsqueeze(1).expand(-1, self.num_heads, seq_len, -1)
enhanced_scores = enhanced_scores.masked_fill(mask_4d, float('-inf'))
# Softmax and apply to V
attn_weights = F.softmax(enhanced_scores, dim=-1)
attended = torch.matmul(attn_weights, V)
# Reshape and project
attended = attended.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
output = self.output_proj(attended)
return self.norm(x + output)
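# Sketch: a single PairwiseGuidedTransformer layer on random features. Note
# that, per the decision recorded above, the pairwise matrix is accepted but
# not currently added to the attention scores. Dimensions are illustrative.
def _demo_pairwise_guided_layer() -> None:
    layer = PairwiseGuidedTransformer(d_model=128, num_heads=8)
    x = torch.randn(2, 48, 128)
    pairwise = torch.rand(2, 48, 48)  # any (batch, seq, seq) similarity matrix
    out = layer(x, pairwise)
    assert out.shape == x.shape       # residual + LayerNorm preserves the shape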
class MultiScaleAdaptivePooler(nn.Module):
"""Multi-scale adaptive pooling - ๋‹ค์–‘ํ•œ ๋„๋ฉ”์ธ์—์„œ ํ™œ์šฉ ๊ฐ€๋Šฅ"""
def __init__(self, hidden_dim: int, num_heads: int = 8):
super().__init__()
# Attention-based pooling
self.attention_pool = nn.MultiheadAttention(
hidden_dim, num_heads=num_heads, batch_first=True
)
self.query_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
# Complementary pooling strategies
self.max_pool_proj = nn.Linear(hidden_dim, hidden_dim)
self.fusion = nn.Linear(hidden_dim * 3, hidden_dim)
    def forward(self, x, padding_mask=None):
        """
        Args:
            x: (batch, seq_len, hidden_dim) - sequence features
            padding_mask: (batch, seq_len) - padding mask
        Note: in practice this was not better than plain average pooling,
        so only the global-average branch below is active.
        """
batch_size = x.size(0)
# 1. Global average pooling
if padding_mask is not None:
mask_expanded = (~padding_mask).float().unsqueeze(-1)
global_avg = (x * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)
else:
global_avg = x.mean(dim=1)
# # 2. Global max pooling
# if padding_mask is not None:
# x_masked = x.clone()
# x_masked[padding_mask] = float('-inf')
# global_max = x_masked.max(dim=1)[0]
# else:
# global_max = x.max(dim=1)[0]
# global_max = self.max_pool_proj(global_max)
# # 3. Attention-based pooling
# query = self.query_token.expand(batch_size, -1, -1)
# attn_pooled, _ = self.attention_pool(
# query, x, x,
# key_padding_mask=padding_mask
# )
# attn_pooled = attn_pooled.squeeze(1)
# # 4. Fuse all pooling results
# #combined = torch.cat([global_avg, global_max, attn_pooled], dim=-1)
# #output = self.fusion(combined)
output = global_avg
return output
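# Sketch: the pooler currently reduces to a masked global average, so padded
# positions must not influence the pooled vector. Shapes are illustrative.
def _demo_adaptive_pooler() -> None:
    pooler = MultiScaleAdaptivePooler(hidden_dim=16, num_heads=8)
    x = torch.randn(1, 4, 16)
    mask = torch.tensor([[False, False, True, True]])  # last two steps are padding
    pooled = pooler(x, mask)
    assert torch.allclose(pooled[0], x[0, :2].mean(dim=0), atol=1e-6)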
class GuidedSegmentTransformer(nn.Module):
def __init__(self,
input_dim: int,
hidden_dim: int = 256,
num_heads: int = 8,
num_layers: int = 4,
dropout: float = 0.1,
max_sequence_length: int = 1000,
mode: str = 'only_emb',
share_parameter: bool = False,
num_classes: int = 2):
super().__init__()
# Original sequence processing
self.input_projection = nn.Linear(input_dim, hidden_dim)
self.mode = mode
self.share_parameter = share_parameter
self.num_classes = num_classes
# Positional encoding
position = torch.arange(max_sequence_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, hidden_dim, 2) * (-np.log(10000.0) / hidden_dim))
pos_encoding = torch.zeros(max_sequence_length, hidden_dim)
pos_encoding[:, 0::2] = torch.sin(position * div_term)
pos_encoding[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pos_encoding', pos_encoding)
        # โœ… Pairwise-guided transformer layers (generic, domain-agnostic name)
self.pairwise_guided_layers = nn.ModuleList([
PairwiseGuidedTransformer(hidden_dim, num_heads)
for _ in range(num_layers)
])
        # Pairwise matrix processing (the original similarity-processing stream)
self.pairwise_projection = nn.Sequential(
nn.Conv1d(1, hidden_dim // 2, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1),
nn.ReLU(),
nn.Dropout(dropout)
)
        # โœ… Multi-scale adaptive pooling (generic, domain-agnostic name)
self.adaptive_pooler = MultiScaleAdaptivePooler(hidden_dim, num_heads)
# Final classification head
self.classification_head_dim = hidden_dim * 2 if mode == 'both' else hidden_dim
output_dim = 1 if num_classes == 2 else num_classes
self.classification_head = nn.Sequential(
nn.Linear(self.classification_head_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LayerNorm(hidden_dim // 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, output_dim)
)
def forward(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, seq_len, _ = x.shape
# 1. Process sequence
x1 = self.input_projection(x)
x1 = x1 + self.pos_encoding[:seq_len].unsqueeze(0)
# 2. Calculate pairwise matrix (can be similarity, distance, correlation, etc.)
x_expanded = x.unsqueeze(2)
x_transposed = x.unsqueeze(1)
distances = torch.mean((x_expanded - x_transposed) ** 2, dim=-1)
pairwise_matrix = torch.exp(-distances) # Convert distance to similarity
# Apply padding mask to pairwise matrix
if padding_mask is not None:
pairwise_mask = padding_mask.unsqueeze(1) | padding_mask.unsqueeze(2)
pairwise_matrix = pairwise_matrix.masked_fill(pairwise_mask, 0.0)
# โœ… Pairwise-guided processing
for layer in self.pairwise_guided_layers:
x1 = layer(x1, pairwise_matrix, padding_mask)
# 3. Process pairwise matrix as separate stream (optional)
if self.mode in ['only_structure', 'both']:
x2 = pairwise_matrix.unsqueeze(1)
x2 = x2.view(batch_size * seq_len, 1, seq_len)
x2 = self.pairwise_projection(x2)
x2 = x2.mean(dim=2)
x2 = x2.view(batch_size, seq_len, -1)
x2 = x2 + self.pos_encoding[:seq_len].unsqueeze(0)
# โœ… Multi-scale adaptive pooling
if self.mode == 'only_emb':
x = self.adaptive_pooler(x1, padding_mask)
elif self.mode == 'only_structure':
x = self.adaptive_pooler(x2, padding_mask)
elif self.mode == 'both':
x1_pooled = self.adaptive_pooler(x1, padding_mask)
x2_pooled = self.adaptive_pooler(x2, padding_mask)
x = torch.cat([x1_pooled, x2_pooled], dim=-1)
return self.classification_head(x)
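# Smoke-test sketch for GuidedSegmentTransformer (currently disabled in
# MusicAudioClassifier's backbone selection). This forward pass stays in
# float32, so it also runs on CPU; dimensions are illustrative assumptions.
def _demo_guided_segment_transformer() -> None:
    model = GuidedSegmentTransformer(input_dim=64, hidden_dim=128)
    x = torch.randn(2, 48, 64)     # [B, S, input_dim]
    logits = model(x)              # default mode='only_emb'
    assert logits.shape == (2, 1)  # binary head emits a single logit per clip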
class CrossModalFusionLayer(nn.Module):
"""Structure์™€ Embedding ์ •๋ณด๋ฅผ ์ ์ง„์ ์œผ๋กœ ์œตํ•ฉ"""
def __init__(self, d_model: int, num_heads: int = 8):
super().__init__()
        # Cross-attention in both directions: embeddings query structure, and structure queries embeddings
self.emb_to_struct_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
self.struct_to_emb_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        # Fusion gate (learns how much to trust each stream)
self.fusion_gate = nn.Sequential(
nn.Linear(d_model * 2, d_model),
nn.Sigmoid()
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
    def forward(self, emb_features, struct_features, padding_mask=None):
        """
        emb_features: (batch, seq_len, d_model) - main embedding stream
        struct_features: (batch, seq_len, d_model) - structure stream
        """
        # 1. The embedding stream attends to structure information
emb_enhanced, _ = self.emb_to_struct_attn(
emb_features, struct_features, struct_features,
key_padding_mask=padding_mask
)
emb_enhanced = self.norm1(emb_features + emb_enhanced)
        # 2. The structure stream attends to embedding information
struct_enhanced, _ = self.struct_to_emb_attn(
struct_features, emb_features, emb_features,
key_padding_mask=padding_mask
)
struct_enhanced = self.norm2(struct_features + struct_enhanced)
        # 3. Adaptive fusion (learn which stream to trust more)
combined = torch.cat([emb_enhanced, struct_enhanced], dim=-1)
gate_weight = self.fusion_gate(combined) # (batch, seq_len, d_model)
# Gated combination
fused = gate_weight * emb_enhanced + (1 - gate_weight) * struct_enhanced
return fused
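# Sketch: fusing an embedding stream with a structure stream of the same shape.
# The sigmoid gate mixes the two attended streams per position and channel.
def _demo_cross_modal_fusion() -> None:
    fusion = CrossModalFusionLayer(d_model=128, num_heads=8)
    emb = torch.randn(2, 48, 128)
    struct = torch.randn(2, 48, 128)
    fused = fusion(emb, struct)
    assert fused.shape == (2, 48, 128)  # same shape as either input stream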
class FusionSegmentTransformer(nn.Module):
def __init__(self,
input_dim: int,
hidden_dim: int = 256,
num_heads: int = 8,
num_layers: int = 4,
dropout: float = 0.1,
max_sequence_length: int = 1000,
                 mode: str = 'both',  # default to fusing both streams
share_parameter: bool = False,
num_classes: int = 2):
super().__init__()
self.input_projection = nn.Linear(input_dim, hidden_dim)
self.mode = mode
self.num_classes = num_classes
# Positional encoding
position = torch.arange(max_sequence_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, hidden_dim, 2) * (-np.log(10000.0) / hidden_dim))
pos_encoding = torch.zeros(max_sequence_length, hidden_dim)
pos_encoding[:, 0::2] = torch.sin(position * div_term)
pos_encoding[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pos_encoding', pos_encoding)
# โœ… Embedding stream: Pairwise-guided transformer
self.embedding_layers = nn.ModuleList([
PairwiseGuidedTransformer(hidden_dim, num_heads)
for _ in range(num_layers)
])
# โœ… Structure stream: Pairwise matrix processing
self.pairwise_projection = nn.Sequential(
nn.Conv1d(1, hidden_dim // 2, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1),
nn.ReLU(),
nn.Dropout(dropout)
)
# Structure transformer layers
self.structure_layers = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=num_heads,
dim_feedforward=hidden_dim * 4,
dropout=dropout,
batch_first=True
            ) for _ in range(num_layers // 2)  # use half as many layers for the structure stream
])
        # โœ… Cross-modal fusion layers (the key component!)
self.fusion_layers = nn.ModuleList([
CrossModalFusionLayer(hidden_dim, num_heads)
            for _ in range(1)  # a single fusion layer keeps the gate meaningful
])
# Adaptive pooling
self.adaptive_pooler = MultiScaleAdaptivePooler(hidden_dim, num_heads)
# Final classification head (์ด์ œ ๋‹จ์ผ ์ฐจ์›)
output_dim = 1 if num_classes == 2 else num_classes
self.classification_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),  # no concatenation anymore
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LayerNorm(hidden_dim // 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, output_dim)
)
def forward(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, seq_len, _ = x.shape
# 1. Initialize both streams
x_emb = self.input_projection(x)
x_emb = x_emb + self.pos_encoding[:seq_len].unsqueeze(0)
# 2. Calculate pairwise matrix
x_expanded = x.unsqueeze(2)
x_transposed = x.unsqueeze(1)
distances = torch.mean((x_expanded - x_transposed) ** 2, dim=-1)
pairwise_matrix = torch.exp(-distances)
if padding_mask is not None:
pairwise_mask = padding_mask.unsqueeze(1) | padding_mask.unsqueeze(2)
pairwise_matrix = pairwise_matrix.masked_fill(pairwise_mask, 0.0)
# 3. Process structure stream
x_struct = pairwise_matrix.unsqueeze(1)
x_struct = x_struct.view(batch_size * seq_len, 1, seq_len)
x_struct = self.pairwise_projection(x_struct)
x_struct = x_struct.mean(dim=2)
x_struct = x_struct.view(batch_size, seq_len, -1)
x_struct = x_struct + self.pos_encoding[:seq_len].unsqueeze(0)
for struct_layer in self.structure_layers:
x_struct = struct_layer(x_struct, src_key_padding_mask=padding_mask)
# 4. Process embedding stream (with pairwise guidance)
for emb_layer in self.embedding_layers:
x_emb = emb_layer(x_emb, pairwise_matrix, padding_mask)
        # โœ… 5. Progressive cross-modal fusion (the key step!)
        fused = x_emb  # start from the embedding stream
for fusion_layer in self.fusion_layers:
fused = fusion_layer(fused, x_struct, padding_mask)
        # fused now carries both embedding and structure information
# 6. Final pooling and classification
pooled = self.adaptive_pooler(fused, padding_mask)
        pooled = pooled.half()  # keep the classification head input in half precision
return self.classification_head(pooled)
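# End-to-end smoke-test sketch for FusionSegmentTransformer. The pooled
# features are cast to half precision before the head, so this assumes fp16
# support (typically a GPU); dimensions are illustrative assumptions.
def _demo_fusion_segment_transformer() -> None:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = FusionSegmentTransformer(input_dim=64, hidden_dim=128).half().to(device)
    x = torch.randn(2, 48, 64, device=device).half()
    logits = model(x)              # no padding mask: all positions are valid
    assert logits.shape == (2, 1)  # binary head emits a single logit per clip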