# minBERT / classifier.py
# classifier_utils is expected to supply torch, numpy (np), F, tqdm, DataLoader,
# AdamW, BertModel, the sklearn metrics, and the constants used below.
from classifier_utils import *

TQDM_DISABLE = True
class BertSentimentClassifier(torch.nn.Module):
    def __init__(self, config, custom_bert=None):
        super().__init__()
        self.num_labels = config.num_labels
        self.bert: BertModel = custom_bert or BertModel.from_pretrained('bert-base-uncased')

        # Pretrain mode does not require updating BERT parameters.
        assert config.fine_tune_mode in ["last-linear-layer", "full-model"]
        for param in self.bert.parameters():
            if config.fine_tune_mode == 'last-linear-layer':
                param.requires_grad = False
            elif config.fine_tune_mode == 'full-model':
                param.requires_grad = True

        # Classifier = Dropout + Linear
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, self.num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask)
        pooler_output = outputs['pooler_output']
        return self.classifier(self.dropout(pooler_output))
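
# A minimal usage sketch (hypothetical values; the real `config` is the
# SimpleNamespace assembled in `train` below):
#
#   config = SimpleNamespace(hidden_dropout_prob=0.3, num_labels=5,
#                            hidden_size=768, fine_tune_mode='full-model')
#   model = BertSentimentClassifier(config)
#   logits = model(input_ids, attention_mask)  # shape: (batch_size, num_labels)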
# Evaluate the model on dev examples.
def model_eval(dataloader, model: BertSentimentClassifier, device):
    model.eval()  # Switch to eval mode, which turns off randomness like dropout.
    y_true = []
    y_pred = []
    sents = []
    sent_ids = []
    for batch in tqdm(dataloader, desc='eval', leave=False, disable=TQDM_DISABLE):
        b_labels, b_sents, b_sent_ids = batch['labels'], batch['sents'], batch['sent_ids']
        b_ids = batch['token_ids'].to(device)
        b_mask = batch['attention_mask'].to(device)

        logits = model(b_ids, b_mask)
        logits = logits.detach().cpu().numpy()
        preds = np.argmax(logits, axis=1).flatten()

        b_labels = b_labels.flatten()
        y_true.extend(b_labels)
        y_pred.extend(preds)
        sents.extend(b_sents)
        sent_ids.extend(b_sent_ids)

    f1 = f1_score(y_true, y_pred, average='macro')
    acc = accuracy_score(y_true, y_pred)
    return acc, f1, y_pred, y_true, sents, sent_ids
# Evaluate the model on test examples (unlabeled, so only predictions are returned).
def model_test_eval(dataloader, model, device):
    model.eval()  # Switch to eval mode, which turns off randomness like dropout.
    y_pred = []
    sents = []
    sent_ids = []
    for batch in tqdm(dataloader, desc='eval', leave=False, disable=TQDM_DISABLE):
        b_sents, b_sent_ids = batch['sents'], batch['sent_ids']
        b_ids = batch['token_ids'].to(device)
        b_mask = batch['attention_mask'].to(device)

        logits = model(b_ids, b_mask)
        logits = logits.detach().cpu().numpy()
        preds = np.argmax(logits, axis=1).flatten()

        y_pred.extend(preds)
        sents.extend(b_sents)
        sent_ids.extend(b_sent_ids)

    return y_pred, sents, sent_ids
def save_model(model, args, config, filepath):
    # Save the model weights along with the config and all RNG states, so a
    # training run can be reproduced exactly from the checkpoint.
    save_info = {
        'model': model.state_dict(),
        'args': args,
        'model_config': config,
        'system_rng': random.getstate(),
        'numpy_rng': np.random.get_state(),
        'torch_rng': torch.random.get_rng_state(),
    }
    torch.save(save_info, filepath)
    print(f"Saved the model to {filepath}")
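
# To restore a checkpoint written by save_model (a sketch; these are standard
# library/NumPy/PyTorch calls, mirroring the keys saved above):
#
#   saved = torch.load(filepath, weights_only=False)
#   model.load_state_dict(saved['model'])
#   random.setstate(saved['system_rng'])
#   np.random.set_state(saved['numpy_rng'])
#   torch.random.set_rng_state(saved['torch_rng'])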
def train(args, custom_bert=None):
    device = torch.device('cuda') if USE_GPU else torch.device('cpu')

    # Create the data and its corresponding datasets and dataloaders.
    train_data, num_labels = load_data(args.train, 'train')
    dev_data = load_data(args.dev, 'valid')

    train_dataset = SentimentDataset(train_data)
    dev_dataset = SentimentDataset(dev_data)

    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
                                  num_workers=NUM_CPU_CORES, collate_fn=train_dataset.collate_fn)
    dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
                                num_workers=NUM_CPU_CORES, collate_fn=dev_dataset.collate_fn)

    # Init model.
    config = {'hidden_dropout_prob': HIDDEN_DROPOUT_PROB,
              'num_labels': num_labels,
              'hidden_size': 768,
              'data_dir': '.',
              'fine_tune_mode': args.fine_tune_mode}
    config = SimpleNamespace(**config)

    model = BertSentimentClassifier(config, custom_bert)
    model = model.to(device)

    lr = args.lr
    optimizer = AdamW(model.parameters(), lr=lr)
    best_dev_acc = 0

    # Run for the specified number of epochs.
    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0
        num_batches = 0
        for batch in tqdm(train_dataloader, desc=f'train-{epoch}', leave=False, disable=TQDM_DISABLE):
            b_ids = batch['token_ids'].to(device)
            b_mask = batch['attention_mask'].to(device)
            b_labels = batch['labels'].to(device)

            optimizer.zero_grad()
            logits = model(b_ids, b_mask)
            # Sum the per-example losses and divide by the nominal batch size
            # (matches 'mean' everywhere except on a smaller final batch).
            loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            num_batches += 1

        train_loss = train_loss / num_batches

        train_acc, train_f1, *_ = model_eval(train_dataloader, model, device)
        dev_acc, dev_f1, *_ = model_eval(dev_dataloader, model, device)

        # Checkpoint only when the dev accuracy improves.
        if dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            save_model(model, args, config, args.filepath)

        print(f"Epoch {epoch}: train loss :: {train_loss:.3f}, train acc :: {train_acc:.3f}, dev acc :: {dev_acc:.3f}")
def test(args):
    with torch.no_grad():
        device = torch.device('cuda') if USE_GPU else torch.device('cpu')
        saved = torch.load(args.filepath, weights_only=False)
        config = saved['model_config']

        model = BertSentimentClassifier(config)
        model.load_state_dict(saved['model'])
        model = model.to(device)
        print(f"Loaded model from {args.filepath}")

        dev_data = load_data(args.dev, 'valid')
        dev_dataset = SentimentDataset(dev_data)
        dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
                                    num_workers=NUM_CPU_CORES, collate_fn=dev_dataset.collate_fn)

        dev_acc, dev_f1, dev_pred, dev_true, dev_sents, dev_sent_ids = model_eval(dev_dataloader, model, device)
        print('DONE DEV')
        print(f"dev acc :: {dev_acc:.3f}")
def classifier_run(args, custom_bert=None):
    seed_everything(SEED)
    torch.set_num_threads(NUM_CPU_CORES)

    print(f'Training Sentiment Classifier on {args.dataset}...')
    config = SimpleNamespace(
        filepath=f'{args.dataset}-classifier.pt',
        lr=args.lr,
        batch_size=args.batch_size,
        fine_tune_mode=args.fine_tune_mode,
        train=args.train, dev=args.dev, test=args.test,
        dev_out=f'/predictions/{args.fine_tune_mode}-{args.dataset}-dev-out.csv',
        test_out=f'/predictions/{args.fine_tune_mode}-{args.dataset}-test-out.csv'
    )

    train(config, custom_bert)

    print(f'Evaluating on {args.dataset}...')
    test(config)
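
# A hypothetical invocation sketch (the field names mirror what classifier_run
# reads; dataset name, paths, and hyperparameters are placeholders):
#
#   args = SimpleNamespace(dataset='sst', lr=1e-5, batch_size=8,
#                          fine_tune_mode='full-model',
#                          train='data/sst-train.csv',
#                          dev='data/sst-dev.csv',
#                          test='data/sst-test.csv')
#   classifier_run(args)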