# Epik / Model / COSMIC / erc_training / predict_epik.py
# Author: Minh Q. Le
# Fixed running on CPU (commit 0935f1b)
import torch, argparse
from commonsense_model import CommonsenseGRUModel
from dataloader import RobertaCometDataset
from torch.utils.data import DataLoader
def load_model(model_path, args):
    """Load a trained CommonsenseGRUModel checkpoint, ready for inference.

    Args:
        model_path: Path to a state-dict checkpoint produced by ``torch.save``.
        args: Parsed CLI namespace; reads ``cuda``, ``active_listener``,
            ``attention``, ``rec_dropout``, ``dropout``, ``mode1``, ``norm``
            and ``residual``.

    Returns:
        The model in ``eval()`` mode, on GPU when ``args.cuda`` else on CPU.
    """
    # Fixed architecture hyperparameters — must match the training run
    # that produced the checkpoint, otherwise load_state_dict will fail.
    emo_gru = True
    n_classes = 15
    D_m = 1024  # RoBERTa utterance feature size
    D_s = 768   # COMET commonsense feature size
    D_g = 150
    D_p = 150
    D_r = 150
    D_i = 150
    D_h = 100
    D_a = 100
    D_e = D_p + D_r + D_i  # emotion state: concatenation of the p/r/i states

    model = CommonsenseGRUModel(
        D_m,
        D_s,
        D_g,
        D_p,
        D_r,
        D_i,
        D_e,
        D_h,
        D_a,
        n_classes=n_classes,
        listener_state=args.active_listener,
        context_attention=args.attention,
        dropout_rec=args.rec_dropout,
        dropout=args.dropout,
        emo_gru=emo_gru,
        mode1=args.mode1,
        norm=args.norm,
        residual=args.residual,
    )

    # Load on CPU unconditionally, then move if needed: this works whether or
    # not CUDA is available and removes the duplicated load logic per branch.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    if args.cuda:
        model.cuda()
    model.eval()
    return model
def get_valid_dataloader(
    roberta_features_path: str,
    comet_features_path: str,
    batch_size=1,
    num_workers=0,
    pin_memory=False,
):
    """Build a DataLoader over the validation split of the feature files.

    Args:
        roberta_features_path: Pickled RoBERTa utterance features.
        comet_features_path: Pickled COMET commonsense features.
        batch_size: Conversations per batch.
        num_workers: DataLoader worker processes.
        pin_memory: Whether to pin host memory for faster GPU transfer.

    Returns:
        A ``(loader, keys)`` pair: the DataLoader and the dataset's
        conversation identifiers, in dataset order.
    """
    dataset = RobertaCometDataset("valid", roberta_features_path, comet_features_path)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=dataset.collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return loader, dataset.keys
def predict(model, data_loader, args):
    """Run the model over every batch and collect predicted class indices.

    Args:
        model: Callable returning ``(log_prob, _, alpha, alpha_f, alpha_b, _)``
            where ``log_prob`` has shape ``(seq_len, batch, n_classes)``.
        data_loader: Yields batches whose last element is the label tensor;
            the preceding 16 tensors are the model/meta inputs.
        args: Namespace; only ``args.cuda`` is read here.

    Returns:
        A list with one 1-D numpy array of predicted class indices per batch.
    """
    predictions = []
    # Fix: inference previously ran with autograd enabled, building graphs
    # and wasting memory for every batch.
    with torch.no_grad():
        for data in data_loader:
            r1, r2, r3, r4, x1, x2, x3, x4, x5, x6, o1, o2, o3, qmask, umask, label = (
                [d.cuda() for d in data[:-1]] if args.cuda else data[:-1]
            )
            # NOTE(review): the argument order below (x5, x6, x1, o2, o3)
            # deliberately differs from the unpack order — preserved as-is;
            # presumably it matches CommonsenseGRUModel's signature. Confirm
            # against the model definition before changing.
            log_prob, _, alpha, alpha_f, alpha_b, _ = model(
                r1, r2, r3, r4, x5, x6, x1, o2, o3, qmask, umask
            )
            # Flatten (seq_len, batch, n_classes) -> (batch*seq_len, n_classes).
            lp_ = log_prob.transpose(0, 1).contiguous().view(-1, log_prob.size()[2])
            preds = torch.argmax(lp_, dim=-1)
            predictions.append(preds.cpu().numpy())
    return predictions
def parse_cosmic_args():
    """Parse the COSMIC predictor's command-line options.

    Returns:
        argparse.Namespace with the declared options plus a derived ``cuda``
        attribute (True only when CUDA is available and --no-cuda is unset).
    """
    cli = argparse.ArgumentParser()
    # Parse arguments input into the cosmic model.
    # NOTE(review): --no-cuda is store_true with default=True, so it is True
    # whether or not the flag is passed — GPU is effectively always disabled
    # ("Fixed running on CPU"). Preserved as-is.
    cli.add_argument("--no-cuda", action="store_true", default=True, help="does not use GPU")
    cli.add_argument("--lr", type=float, default=0.0001, metavar="LR", help="learning rate")
    cli.add_argument("--l2", type=float, default=0.00003, metavar="L2", help="L2 regularization weight")
    cli.add_argument("--rec-dropout", type=float, default=0.3, metavar="rec_dropout", help="rec_dropout rate")
    cli.add_argument("--dropout", type=float, default=0.5, metavar="dropout", help="dropout rate")
    cli.add_argument("--batch-size", type=int, default=1, metavar="BS", help="batch size")
    cli.add_argument("--epochs", type=int, default=10, metavar="E", help="number of epochs")
    cli.add_argument("--class-weight", action="store_true", default=True, help="use class weights")
    cli.add_argument("--active-listener", action="store_true", default=True, help="active listener")
    cli.add_argument("--attention", default="simple", help="Attention type in context GRU")
    cli.add_argument("--tensorboard", action="store_true", default=False, help="Enables tensorboard log")
    cli.add_argument("--mode1", type=int, default=2, help="Roberta features to use")
    cli.add_argument("--seed", type=int, default=500, metavar="seed", help="seed")
    cli.add_argument("--norm", type=int, default=0, help="normalization strategy")
    cli.add_argument("--mu", type=float, default=0, help="class_weight_mu")
    cli.add_argument("--residual", action="store_true", default=True, help="use residual connection")

    opts = cli.parse_args()
    opts.cuda = torch.cuda.is_available() and not opts.no_cuda
    print("Running on GPU" if opts.cuda else "Running on CPU")
    return opts
# Index -> sentiment label, matching the model's 15 output classes.
LABEL_MAPPING = {
    0: "Curiosity",
    1: "Obscene",
    2: "Informative",
    3: "Openness",
    4: "Acceptance",
    5: "Interest",
    6: "Greeting",
    7: "Disapproval",
    8: "Denial",
    9: "Anxious",
    10: "Uninterested",
    11: "Remorse",
    12: "Confused",
    13: "Accusatory",
    14: "Annoyed",
}

MODEL_PATH = "epik/best_model.pt"
# BUG FIX(review): the original called get_valid_dataloader() with no
# arguments, but it requires the two feature-file paths — that call raised
# TypeError unconditionally. Placeholder paths below; TODO confirm the actual
# feature-file locations before running.
ROBERTA_FEATURES_PATH = "epik/roberta_features.pkl"
COMET_FEATURES_PATH = "epik/comet_features.pkl"


def pred_to_labels(preds):
    """Map integer class predictions to sentiment label strings.

    Args:
        preds: Iterable of per-conversation sequences of class indices.

    Returns:
        A list with one list of label strings per conversation.
    """
    # map the prediction for each conversation
    return [[LABEL_MAPPING[int(label)] for label in pred] for pred in preds]


def main():
    """Load the model, predict over the validation split, print labels."""
    args = parse_cosmic_args()
    model = load_model(MODEL_PATH, args)
    test_dataloader, ids = get_valid_dataloader(
        ROBERTA_FEATURES_PATH, COMET_FEATURES_PATH
    )
    predicted_labels = pred_to_labels(predict(model, test_dataloader, args))
    for conv_id, labels in zip(ids, predicted_labels):
        print(f"Conversation ID: {conv_id}")
        print(f"Predicted Sentiment Labels: {labels}")
        print(len(labels))


if __name__ == "__main__":
    main()