|
from classifier_utils import *

TQDM_DISABLE = True


def unsup_contrastive_loss(embeds_1: Tensor, embeds_2: Tensor, temp=0.05):
    '''
    Unsupervised SimCSE loss: each sentence is paired with itself under a
    different dropout mask, and all other in-batch sentences act as negatives.

    embeds_1: [batch_size, hidden_size]
    embeds_2: [batch_size, hidden_size]
    '''
    # Pairwise cosine similarities between the two views, scaled by temperature.
    sim_matrix = F.cosine_similarity(embeds_1.unsqueeze(1), embeds_2.unsqueeze(0), dim=-1) / temp

    # Positive pairs lie on the diagonal (same sentence, two dropout masks).
    positive_sim = torch.diagonal(sim_matrix)

    nume = torch.exp(positive_sim)
    deno = torch.exp(sim_matrix).sum(1)

    # Negative log-softmax of each positive against all in-batch candidates.
    loss_per_batch = -torch.log(nume / deno)
    return loss_per_batch.sum()
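
# Illustrative sanity check (not part of the training pipeline): the unsupervised
# loss above is the softmax cross-entropy over the cosine-similarity matrix with
# diagonal targets, summed over the batch. The helper name and the random shapes
# below are made up for this sketch; it assumes torch and F come in via the
# wildcard import at the top of the file.
def _check_unsup_contrastive_loss(batch_size: int = 4, hidden_size: int = 8, temp: float = 0.05):
    torch.manual_seed(0)
    a = torch.randn(batch_size, hidden_size)
    b = torch.randn(batch_size, hidden_size)
    loss = unsup_contrastive_loss(a, b, temp)
    sim = F.cosine_similarity(a.unsqueeze(1), b.unsqueeze(0), dim=-1) / temp
    ref = F.cross_entropy(sim, torch.arange(batch_size), reduction='sum')
    assert torch.allclose(loss, ref, atol=1e-5)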
|
|
|
|
|
def sup_contrastive_loss(embeds_1: Tensor, embeds_2: Tensor, embeds_3: Tensor, temp=0.05):
    '''
    Supervised SimCSE loss over (anchor, positive, hard negative) triplets,
    with the remaining in-batch positives and negatives used as extra negatives.

    embeds_1: [batch_size, hidden_size] (anchor)
    embeds_2: [batch_size, hidden_size] (positive)
    embeds_3: [batch_size, hidden_size] (hard negative)
    '''
    # Anchor-positive and anchor-negative cosine similarities, scaled by temperature.
    pos_sim_matrix = F.cosine_similarity(embeds_1.unsqueeze(1), embeds_2.unsqueeze(0), dim=-1) / temp
    neg_sim_matrix = F.cosine_similarity(embeds_1.unsqueeze(1), embeds_3.unsqueeze(0), dim=-1) / temp

    # True anchor-positive pairs lie on the diagonal.
    positive_sim = torch.diagonal(pos_sim_matrix)

    nume = torch.exp(positive_sim)
    deno = (torch.exp(pos_sim_matrix) + torch.exp(neg_sim_matrix)).sum(1)

    # Negative log-softmax of each positive against every in-batch positive and negative.
    loss_per_batch = -torch.log(nume / deno)
    return loss_per_batch.sum()
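
# Illustrative sanity check (not part of the training pipeline): the supervised
# loss above equals cross-entropy over the concatenated [positive | negative]
# similarity logits with diagonal targets, summed over the batch. The helper
# name and shapes are made up for this sketch.
def _check_sup_contrastive_loss(batch_size: int = 4, hidden_size: int = 8, temp: float = 0.05):
    torch.manual_seed(0)
    anchor = torch.randn(batch_size, hidden_size)
    positive = torch.randn(batch_size, hidden_size)
    negative = torch.randn(batch_size, hidden_size)
    loss = sup_contrastive_loss(anchor, positive, negative, temp)
    pos = F.cosine_similarity(anchor.unsqueeze(1), positive.unsqueeze(0), dim=-1) / temp
    neg = F.cosine_similarity(anchor.unsqueeze(1), negative.unsqueeze(0), dim=-1) / temp
    logits = torch.cat([pos, neg], dim=1)  # [batch_size, 2 * batch_size]
    ref = F.cross_entropy(logits, torch.arange(batch_size), reduction='sum')
    assert torch.allclose(loss, ref, atol=1e-5)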
|
|
|
|
|
def sts_eval(dataloader, model: BertModel, device):
    model.eval()
    y_true = []
    y_pred = []
    sent_ids = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='eval', leave=False, disable=TQDM_DISABLE):
            token_ids_1 = batch['token_ids_1'].to(device)
            token_ids_2 = batch['token_ids_2'].to(device)
            attention_mask_1 = batch['attention_mask_1'].to(device)
            attention_mask_2 = batch['attention_mask_2'].to(device)

            scores = batch['score']
            b_sent_ids = batch['sent_ids']

            # [CLS] pooler embeddings for each sentence in the pair.
            embeds_1 = model(token_ids_1, attention_mask_1)['pooler_output']
            embeds_2 = model(token_ids_2, attention_mask_2)['pooler_output']

            # Predicted similarity is the cosine similarity of the two embeddings.
            sim = F.cosine_similarity(embeds_1, embeds_2)
            y_true.extend(scores)
            y_pred.extend(sim.cpu().tolist())
            sent_ids.extend(b_sent_ids)

    # Spearman's rank correlation between predicted and gold similarity scores.
    spearman_corr, _ = spearmanr(y_pred, y_true)
    return spearman_corr, sent_ids
|
|
|
|
|
def finetune_bert(args):
    '''
    Finetuning Baseline
    -------------------
    1. Load the training data (Amazon Polarity for 'unsup', NLI for 'sup') and the STS-B dev set.
    2. Initialize pretrained minBERT.
    3. Loop over EPOCHS training epochs (10 in this baseline).
    4. Compute each batch's SimCSE contrastive loss.
    5. Backpropagate and update with the AdamW optimizer.
    6. Evaluate on the dev set:
       6.1. Create a [CLS] embedding for each sentence of a given pair.
       6.2. Calculate their cosine similarity (-1 <= sim <= 1).
       6.3. Compute Spearman's correlation between the predicted and gold similarities.
    7. If Spearman's correlation improves (dev_acc > best_dev_acc), save the model checkpoint.
    '''
    assert args.mode in ['unsup', 'sup']

    seed_everything(SEED)
    torch.set_num_threads(NUM_CPU_CORES)

    # Unsupervised SimCSE trains on single sentences (Amazon Polarity);
    # supervised SimCSE trains on NLI (anchor, positive, hard negative) triplets.
    if args.mode == 'unsup':
        train_dataset = AmazonDataset(load_data(AMAZON_POLARITY, 'amazon'))
    else:
        train_dataset = InferenceDataset(load_data(NLI_TRAIN, 'nli'))

    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size_train,
                                  num_workers=NUM_CPU_CORES, collate_fn=train_dataset.collate_fn)

    sts_dataset = SemanticDataset(load_data(STSB_DEV, 'stsb'))
    sts_dataloader = DataLoader(sts_dataset, shuffle=False, batch_size=args.batch_size_dev,
                                num_workers=NUM_CPU_CORES, collate_fn=sts_dataset.collate_fn)

    device = torch.device('cuda') if USE_GPU else torch.device('cpu')
    model = BertModel.from_pretrained('bert-base-uncased')
    model.to(device)

    best_dev_acc = 0
    optimizer = AdamW(model.parameters(), lr=args.lr)

    method = 'unsupervised' if args.mode == 'unsup' else 'supervised'
    print(f'Finetuning minBERT with the {method} method...')

    for epoch in range(EPOCHS):
        model.train()
        train_loss = num_batches = 0

        for batch in tqdm(train_dataloader, f'train-{epoch}', leave=False, disable=TQDM_DISABLE):
            if args.mode == 'unsup':
                b_ids = batch['token_ids'].to(device)
                b_mask = batch['attention_mask'].to(device)

                # Two forward passes over the same batch: dropout (active in
                # train mode) yields two different embeddings per sentence,
                # which form the positive pairs.
                logits_1 = model(b_ids, b_mask)['pooler_output']
                logits_2 = model(b_ids, b_mask)['pooler_output']

                loss = unsup_contrastive_loss(logits_1, logits_2, args.temp)
            else:
                b_anchor_ids = batch['anchor_ids'].to(device)
                b_positive_ids = batch['positive_ids'].to(device)
                b_negative_ids = batch['negative_ids'].to(device)
                b_anchor_masks = batch['anchor_masks'].to(device)
                b_positive_masks = batch['positive_masks'].to(device)
                b_negative_masks = batch['negative_masks'].to(device)

                # Embed the anchor, positive, and hard-negative sentences.
                logits_1 = model(b_anchor_ids, b_anchor_masks)['pooler_output']
                logits_2 = model(b_positive_ids, b_positive_masks)['pooler_output']
                logits_3 = model(b_negative_ids, b_negative_masks)['pooler_output']

                loss = sup_contrastive_loss(logits_1, logits_2, logits_3, args.temp)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            num_batches += 1

        train_loss /= num_batches
        dev_acc, _ = sts_eval(sts_dataloader, model, device)

        if dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            torch.save(model.state_dict(), args.filepath)
            print(f"save the model to {args.filepath}")

        print(f"Epoch {epoch}: train loss :: {train_loss:.3f}, dev acc :: {dev_acc:.3f}")