GlowCheese committed on
Commit
02b98f5
·
1 Parent(s): 7587354

contrastive commit 2

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. classifier.py +0 -2
  3. unsup_simcse.py +57 -11
.gitignore CHANGED
@@ -163,4 +163,5 @@ cython_debug/
163
 
164
  nohup.out
165
  *.pt
166
- predictions/
 
 
163
 
164
  nohup.out
165
  *.pt
166
+ predictions/
167
+ zemo*.py
classifier.py CHANGED
@@ -362,8 +362,6 @@ def main():
362
  seed_everything(args.seed)
363
  torch.set_num_threads(args.num_cpu_cores)
364
 
365
- print(torch.get_num_threads())
366
-
367
  print('Training Sentiment Classifier on SST...')
368
  config = SimpleNamespace(
369
  filepath='sst-classifier.pt',
 
362
  seed_everything(args.seed)
363
  torch.set_num_threads(args.num_cpu_cores)
364
 
 
 
365
  print('Training Sentiment Classifier on SST...')
366
  config = SimpleNamespace(
367
  filepath='sst-classifier.pt',
unsup_simcse.py CHANGED
@@ -3,8 +3,10 @@ import torch
3
  import random
4
  import argparse
5
  import numpy as np
 
6
 
7
  from tqdm import tqdm
 
8
  from types import SimpleNamespace
9
  from torch.utils.data import Dataset, DataLoader
10
  from sklearn.metrics import f1_score, accuracy_score
@@ -52,18 +54,22 @@ def load_data(filename, flag='train'):
52
  - for Twitter dataset: list of sentences
53
  - for SST/CFIMDB dataset: list of (sent, [label], sent_id)
54
  '''
 
55
  num_labels = set()
56
  data = []
57
  with open(filename, 'r') as fp:
58
- for record in csv.DictReader(fp, delimiter = ',', ):
59
- if flag == 'twitter':
60
  sent = record['clean_text'].lower().strip()
61
  data.append(sent)
62
- elif flag == 'test':
 
 
63
  sent = record['sentence'].lower().strip()
64
  sent_id = record['id'].lower().strip()
65
  data.append((sent,sent_id))
66
- else:
 
67
  sent = record['sentence'].lower().strip()
68
  sent_id = record['id'].lower().strip()
69
  label = int(record['sentiment'].strip())
@@ -92,6 +98,35 @@ def save_model(model, optimizer, args, config, filepath):
92
  print(f"save the model to {filepath}")
93
 
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  def train(args):
96
  '''
97
  Training Pipeline
@@ -138,6 +173,7 @@ def train(args):
138
  optimizer_classifier = AdamW(model.parameters(), lr=args.lr_classifier)
139
  best_dev_acc = 0
140
 
 
141
  for epoch in range(args.epochs):
142
  model.bert.train()
143
  train_loss = num_batches = 0
@@ -146,11 +182,21 @@ def train(args):
146
  b_ids = b_ids.to(device)
147
  b_mask = b_mask.to(device)
148
 
149
- optimizer_cse.zero_grad()
150
- logits = model.bert.embed(b_ids)
151
- logits = model.bert.encode(logits, b_mask)
 
 
 
152
 
 
 
153
 
 
 
 
 
 
154
 
155
 
156
  def get_args():
@@ -177,18 +223,18 @@ if __name__ == "__main__":
177
  print('Finetuning minBERT with Unsupervised SimCSE...')
178
  config = SimpleNamespace(
179
  filepath='contrastive-nli.pt',
180
- lr=args.lr,
 
181
  num_cpu_cores=args.num_cpu_cores,
182
  use_gpu=args.use_gpu,
183
  epochs=args.epochs,
184
  batch_size_cse=args.batch_size_cse,
185
  batch_size_classifier=args.batch_size_classifier,
 
186
  train_bert='data/twitter-unsup.csv',
187
  train='data/ids-sst-train.csv',
188
  dev='data/ids-sst-dev.csv',
189
- test='data/ids-sst-test-student.csv',
190
- dev_out = 'predictions/' + args.fine_tune_mode + '-sst-dev-out.csv',
191
- test_out = 'predictions/' + args.fine_tune_mode + '-sst-test-out.csv'
192
  )
193
 
194
  train(config)
 
3
  import random
4
  import argparse
5
  import numpy as np
6
+ import torch.nn.functional as F
7
 
8
  from tqdm import tqdm
9
+ from torch import Tensor
10
  from types import SimpleNamespace
11
  from torch.utils.data import Dataset, DataLoader
12
  from sklearn.metrics import f1_score, accuracy_score
 
54
  - for Twitter dataset: list of sentences
55
  - for SST/CFIMDB dataset: list of (sent, [label], sent_id)
56
  '''
57
+
58
  num_labels = set()
59
  data = []
60
  with open(filename, 'r') as fp:
61
+ if flag == 'twitter':
62
+ for cnt, record in enumerate(csv.DictReader(fp, delimiter = ',')):
63
  sent = record['clean_text'].lower().strip()
64
  data.append(sent)
65
+ if cnt == 10000: break
66
+ elif flag == 'test':
67
+ for record in csv.DictReader(fp, delimiter = '\t'):
68
  sent = record['sentence'].lower().strip()
69
  sent_id = record['id'].lower().strip()
70
  data.append((sent,sent_id))
71
+ else:
72
+ for record in csv.DictReader(fp, delimiter = '\t'):
73
  sent = record['sentence'].lower().strip()
74
  sent_id = record['id'].lower().strip()
75
  label = int(record['sentiment'].strip())
 
98
  print(f"save the model to {filepath}")
99
 
100
 
101
+ # def model_eval(dataloader, model, device):
102
+ # model.eval()
103
+
104
+
105
+
106
def contrastive_loss(embeds_1: Tensor, embeds_2: Tensor, temp=0.05):
    '''
    Mean in-batch contrastive (InfoNCE) loss between two sets of embeddings.

    Row i of embeds_1 and row i of embeds_2 form the positive pair; every
    other row of embeds_2 serves as a negative for row i of embeds_1.

    embeds_1: [batch_size, hidden_size]
    embeds_2: [batch_size, hidden_size]
    temp: temperature the cosine similarities are divided by

    Returns a scalar tensor: the loss averaged over the batch.
    '''

    # [batch_size, batch_size]; entry (i, j) is cos(embeds_1[i], embeds_2[j]) / temp
    sim_matrix = F.cosine_similarity(embeds_1.unsqueeze(1), embeds_2.unsqueeze(0), dim=-1) / temp

    # [batch_size]; the positive for row i sits on the diagonal, i.e. class i
    targets = torch.arange(sim_matrix.size(0), device=sim_matrix.device)

    # -log(exp(s_ii) / sum_j exp(s_ij)) is cross entropy with target i.
    # Going through cross_entropy (log-softmax) instead of raw exp() keeps the
    # loss finite when 1 / temp is large enough for exp() to overflow.
    return F.cross_entropy(sim_matrix, targets)
128
+
129
+
130
  def train(args):
131
  '''
132
  Training Pipeline
 
173
  optimizer_classifier = AdamW(model.parameters(), lr=args.lr_classifier)
174
  best_dev_acc = 0
175
 
176
+ # ---- Training minBERT using SimCSE ---- #
177
  for epoch in range(args.epochs):
178
  model.bert.train()
179
  train_loss = num_batches = 0
 
182
  b_ids = b_ids.to(device)
183
  b_mask = b_mask.to(device)
184
 
185
+ # Get different embeddings with different dropout masks
186
+ logits_1 = model.bert(b_ids, b_mask)['pooler_output']
187
+ logits_2 = model.bert(b_ids, b_mask)['pooler_output']
188
+
189
+ # Calculate mean SimCSE loss function
190
+ loss = contrastive_loss(logits_1, logits_2)
191
 
192
+ loss.backward()
193
+ optimizer_cse.step()
194
 
195
+ train_loss += loss.item()
196
+ num_batches += 0
197
+
198
+ train_loss = train_loss / num_batches
199
+ print(f"Epoch {epoch}: train loss :: {train_loss :.3f}")
200
 
201
 
202
  def get_args():
 
223
  print('Finetuning minBERT with Unsupervised SimCSE...')
224
  config = SimpleNamespace(
225
  filepath='contrastive-nli.pt',
226
+ lr_cse=args.lr_cse,
227
+ lr_classifier=args.lr_classifier,
228
  num_cpu_cores=args.num_cpu_cores,
229
  use_gpu=args.use_gpu,
230
  epochs=args.epochs,
231
  batch_size_cse=args.batch_size_cse,
232
  batch_size_classifier=args.batch_size_classifier,
233
+ hidden_dropout_prob=args.hidden_dropout_prob,
234
  train_bert='data/twitter-unsup.csv',
235
  train='data/ids-sst-train.csv',
236
  dev='data/ids-sst-dev.csv',
237
+ test='data/ids-sst-test-student.csv'
 
 
238
  )
239
 
240
  train(config)