|
from classifier_utils import * |
|
|
|
|
|
# Globally silence tqdm progress bars (passed as `disable=` to every tqdm call
# below) — keeps training/eval logs clean in non-interactive runs.
TQDM_DISABLE=True
|
|
|
|
|
class BertSentimentClassifier(torch.nn.Module):
    """Sentiment classifier: a BERT encoder + dropout + a linear scoring head.

    The head scores the encoder's pooled [CLS] representation
    (``outputs['pooler_output']``). Two fine-tuning modes are supported:
      - "last-linear-layer": BERT parameters are frozen; only the head trains.
      - "full-model": BERT parameters are updated together with the head.
    """

    def __init__(self, config, custom_bert=None):
        """
        Args:
            config: namespace providing ``num_labels``, ``hidden_size``,
                ``hidden_dropout_prob``, and ``fine_tune_mode``.
            custom_bert: optional pre-built encoder module; when None, the
                pretrained 'bert-base-uncased' weights are loaded.
        """
        super().__init__()
        self.num_labels = config.num_labels
        # Explicit `is None` test (instead of `custom_bert or ...`) so a
        # falsy-but-valid module would still be honored.
        if custom_bert is None:
            custom_bert = BertModel.from_pretrained('bert-base-uncased')
        self.bert = custom_bert  # BertModel (or compatible custom encoder)

        assert config.fine_tune_mode in ["last-linear-layer", "full-model"]
        # The mode is loop-invariant: decide once, then apply to every parameter.
        trainable = config.fine_tune_mode == 'full-model'
        for param in self.bert.parameters():
            param.requires_grad = trainable

        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, self.num_labels)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits of shape (batch, num_labels)."""
        outputs = self.bert(input_ids, attention_mask)
        # Pooled [CLS] vector; assumes the encoder returns a dict-like output
        # containing 'pooler_output'.
        pooler_output = outputs['pooler_output']
        return self.classifier(self.dropout(pooler_output))
|
|
|
|
|
|
|
@torch.no_grad()  # evaluation never backpropagates; skip autograd graph building
def model_eval(dataloader, model: BertSentimentClassifier, device):
    """Evaluate `model` on a labeled dataloader.

    Args:
        dataloader: yields batches with 'token_ids', 'attention_mask',
            'labels', 'sents', 'sent_ids'.
        model: classifier to evaluate (switched to eval mode here).
        device: device the model lives on; input tensors are moved to it.

    Returns:
        (accuracy, macro-F1, y_pred, y_true, sents, sent_ids).
    """
    model.eval()  # disable dropout for deterministic predictions
    y_true = []
    y_pred = []
    sents = []
    sent_ids = []
    for batch in tqdm(dataloader, desc='eval', leave=False, disable=TQDM_DISABLE):
        b_labels, b_sents, b_sent_ids = batch['labels'], batch['sents'], batch['sent_ids']

        # Only tensors fed to the model need to live on `device`.
        b_ids = batch['token_ids'].to(device)
        b_mask = batch['attention_mask'].to(device)

        logits = model(b_ids, b_mask)
        logits = logits.detach().cpu().numpy()
        # Highest-scoring class per example.
        preds = np.argmax(logits, axis=1).flatten()

        b_labels = b_labels.flatten()
        y_true.extend(b_labels)
        y_pred.extend(preds)
        sents.extend(b_sents)
        sent_ids.extend(b_sent_ids)

    f1 = f1_score(y_true, y_pred, average='macro')
    acc = accuracy_score(y_true, y_pred)

    return acc, f1, y_pred, y_true, sents, sent_ids
|
|
|
|
|
|
|
@torch.no_grad()  # inference only; building the autograd graph would be wasted work
def model_test_eval(dataloader, model, device):
    """Run inference on an unlabeled (test) dataloader.

    Args:
        dataloader: yields batches with 'token_ids', 'attention_mask',
            'sents', 'sent_ids' (no labels).
        model: classifier to run (switched to eval mode here).
        device: device the model lives on; input tensors are moved to it.

    Returns:
        (y_pred, sents, sent_ids) — predicted label per example plus the
        sentence text and ids carried through from each batch.
    """
    model.eval()  # disable dropout for deterministic predictions
    y_pred = []
    sents = []
    sent_ids = []
    for batch in tqdm(dataloader, desc='eval', leave=False, disable=TQDM_DISABLE):
        b_sents, b_sent_ids = batch['sents'], batch['sent_ids']

        b_ids = batch['token_ids'].to(device)
        b_mask = batch['attention_mask'].to(device)

        logits = model(b_ids, b_mask)
        logits = logits.detach().cpu().numpy()
        # Highest-scoring class per example.
        preds = np.argmax(logits, axis=1).flatten()

        y_pred.extend(preds)
        sents.extend(b_sents)
        sent_ids.extend(b_sent_ids)

    return y_pred, sents, sent_ids
|
|
|
|
|
def save_model(model, args, config, filepath):
    """Serialize model weights, run configuration, and RNG states to `filepath`.

    The RNG states (Python, NumPy, torch) are captured so a restored run can
    reproduce subsequent sampling exactly.
    """
    rng_states = {
        'system_rng': random.getstate(),
        'numpy_rng': np.random.get_state(),
        'torch_rng': torch.random.get_rng_state(),
    }
    save_info = {
        'model': model.state_dict(),
        'args': args,
        'model_config': config,
        **rng_states,
    }

    torch.save(save_info, filepath)
    print(f"save the model to {filepath}")
|
|
|
|
|
def train(args, custom_bert=None):
    """Train a BertSentimentClassifier on `args.train`, checkpointing to
    `args.filepath` whenever dev accuracy improves.

    Args:
        args: namespace with train/dev paths, lr, batch_size, fine_tune_mode,
            and filepath (checkpoint destination).
        custom_bert: optional pre-built BERT encoder forwarded to the model.
    """
    device = torch.device('cuda') if USE_GPU else torch.device('cpu')

    # Datasets and loaders; only the training split is shuffled.
    raw_train, num_labels = load_data(args.train, 'train')
    raw_dev = load_data(args.dev, 'valid')

    train_ds = SentimentDataset(raw_train)
    dev_ds = SentimentDataset(raw_dev)

    train_loader = DataLoader(train_ds, shuffle=True, batch_size=args.batch_size,
                              num_workers=NUM_CPU_CORES, collate_fn=train_ds.collate_fn)
    dev_loader = DataLoader(dev_ds, shuffle=False, batch_size=args.batch_size,
                            num_workers=NUM_CPU_CORES, collate_fn=dev_ds.collate_fn)

    model_config = SimpleNamespace(
        hidden_dropout_prob=HIDDEN_DROPOUT_PROB,
        num_labels=num_labels,
        hidden_size=768,
        data_dir='.',
        fine_tune_mode=args.fine_tune_mode,
    )

    model = BertSentimentClassifier(model_config, custom_bert)
    model = model.to(device)

    optimizer = AdamW(model.parameters(), lr=args.lr)
    best_dev_acc = 0

    for epoch in range(EPOCHS):
        model.train()
        epoch_loss = 0
        batch_count = 0
        for batch in tqdm(train_loader, desc=f'train-{epoch}', leave=False, disable=TQDM_DISABLE):
            b_ids = batch['token_ids'].to(device)
            b_mask = batch['attention_mask'].to(device)
            b_labels = batch['labels'].to(device)

            optimizer.zero_grad()
            logits = model(b_ids, b_mask)
            # Summed loss normalized by the nominal batch size (not the actual
            # batch length) — keeps the objective identical across batches.
            loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            batch_count += 1

        epoch_loss = epoch_loss / batch_count

        train_acc, train_f1, *_ = model_eval(train_loader, model, device)
        dev_acc, dev_f1, *_ = model_eval(dev_loader, model, device)

        # Keep only the best checkpoint as measured by dev accuracy.
        if dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            save_model(model, args, model_config, args.filepath)

        print(f"Epoch {epoch}: train loss :: {epoch_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}")
|
|
|
|
|
def test(args):
    """Reload the checkpoint at `args.filepath` and report dev-set accuracy."""
    with torch.no_grad():  # pure evaluation — no gradients needed anywhere below
        device = torch.device('cuda') if USE_GPU else torch.device('cpu')
        # weights_only=False: the checkpoint stores args/config objects, not
        # just tensors, so full unpickling is required.
        checkpoint = torch.load(args.filepath, weights_only=False)
        config = checkpoint['model_config']
        model = BertSentimentClassifier(config)
        model.load_state_dict(checkpoint['model'])
        model = model.to(device)
        print(f"load model from {args.filepath}")

        dev_ds = SentimentDataset(load_data(args.dev, 'valid'))
        dev_loader = DataLoader(dev_ds, shuffle=False, batch_size=args.batch_size,
                                num_workers=NUM_CPU_CORES, collate_fn=dev_ds.collate_fn)

        dev_acc, dev_f1, dev_pred, dev_true, dev_sents, dev_sent_ids = model_eval(dev_loader, model, device)
        print('DONE DEV')
        print(f"dev acc :: {dev_acc :.3f}")
|
|
|
|
|
def classifier_run(args, custom_bert=None):
    """End-to-end driver: seed RNGs, train a sentiment classifier, evaluate it.

    Args:
        args: namespace with dataset name, data paths, lr, batch_size, and
            fine_tune_mode.
        custom_bert: optional pre-built BERT encoder forwarded to training.
    """
    seed_everything(SEED)
    torch.set_num_threads(NUM_CPU_CORES)

    print(f'Training Sentiment Classifier on {args.dataset}...')
    run_tag = f'{args.fine_tune_mode}-{args.dataset}'
    run_config = SimpleNamespace(
        filepath=f'{args.dataset}-classifier.pt',
        lr=args.lr,
        batch_size=args.batch_size,
        fine_tune_mode=args.fine_tune_mode,
        train=args.train, dev=args.dev, test=args.test,
        dev_out=f'/predictions/{run_tag}-dev-out.csv',
        test_out=f'/predictions/{run_tag}-test-out.csv',
    )

    train(run_config, custom_bert)

    print(f'Evaluating on {args.dataset}...')
    test(run_config)