# Inference script: loads a trained CommonsenseGRUModel checkpoint and prints
# predicted sentiment labels for each conversation in the validation split.
import argparse

import torch
from torch.utils.data import DataLoader

from commonsense_model import CommonsenseGRUModel
from dataloader import RobertaCometDataset
def load_model(model_path, args):
    """Build a CommonsenseGRUModel, load trained weights, and switch to eval mode.

    Args:
        model_path: Path to a state-dict checkpoint produced by ``torch.save``.
        args: Parsed CLI namespace; reads ``cuda``, ``active_listener``,
            ``attention``, ``rec_dropout``, ``dropout``, ``mode1``, ``norm``
            and ``residual``.

    Returns:
        The model in ``eval()`` mode, moved to GPU when ``args.cuda`` is True.
    """
    emo_gru = True
    n_classes = 15  # must match the label_mapping used when decoding predictions

    # Architecture dimensionalities.
    # NOTE(review): presumably D_m=1024 matches the RoBERTa feature size and
    # D_s=768 the COMET feature size — confirm against the dataset pipeline.
    D_m = 1024
    D_s = 768
    D_g = 150
    D_p = 150
    D_r = 150
    D_i = 150
    D_h = 100
    D_a = 100
    D_e = D_p + D_r + D_i  # emotion-state size is the concat of p/r/i states

    model = CommonsenseGRUModel(
        D_m,
        D_s,
        D_g,
        D_p,
        D_r,
        D_i,
        D_e,
        D_h,
        D_a,
        n_classes=n_classes,
        listener_state=args.active_listener,
        context_attention=args.attention,
        dropout_rec=args.rec_dropout,
        dropout=args.dropout,
        emo_gru=emo_gru,
        mode1=args.mode1,
        norm=args.norm,
        residual=args.residual,
    )

    if args.cuda:
        model.cuda()
    # Single load path (the original duplicated load_state_dict in both
    # branches): remap GPU-saved tensors onto CPU when not running on GPU.
    map_location = None if args.cuda else torch.device("cpu")
    model.load_state_dict(torch.load(model_path, map_location=map_location))
    model.eval()
    return model
def get_valid_dataloader(
    roberta_features_path: str,
    comet_features_path: str,
    batch_size=1,
    num_workers=0,
    pin_memory=False,
):
    """Build a DataLoader over the validation split of the RoBERTa+COMET features.

    Returns:
        A ``(loader, keys)`` pair, where ``keys`` are the dataset's
        conversation identifiers (one per batch when batch_size=1).
    """
    dataset = RobertaCometDataset("valid", roberta_features_path, comet_features_path)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=dataset.collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return loader, dataset.keys
def predict(model, data_loader, args):
    """Run the model over every batch and collect per-utterance argmax classes.

    Returns:
        A list with one numpy int array per batch (flattened over utterances).
    """
    all_predictions = []
    use_gpu = args.cuda
    for batch in data_loader:
        features = batch[:-1]  # drop the trailing label entry
        if use_gpu:
            features = [t.cuda() for t in features]
        (r1, r2, r3, r4,
         x1, x2, x3, x4, x5, x6,
         o1, o2, o3,
         qmask, umask, _label) = features
        # NOTE(review): the positional order below (x5, x6, x1, o2, o3)
        # deliberately differs from the unpack order above — presumably it
        # matches CommonsenseGRUModel.forward's signature; confirm there.
        log_prob, _, alpha, alpha_f, alpha_b, _ = model(
            r1, r2, r3, r4, x5, x6, x1, o2, o3, qmask, umask
        )
        # (seq, batch, classes) -> (batch*seq, classes), then argmax per row.
        flat_log_prob = log_prob.transpose(0, 1).contiguous().view(-1, log_prob.size()[2])
        batch_preds = torch.argmax(flat_log_prob, dim=-1)
        all_predictions.append(batch_preds.data.cpu().numpy())
    return all_predictions
def parse_cosmic_args():
    """Parse CLI arguments for COSMIC inference and derive ``args.cuda``.

    Returns:
        The parsed namespace, with ``cuda`` set to True only when a GPU is
        available and ``--no-cuda`` is False.
    """
    arg_parser = argparse.ArgumentParser()
    # NOTE(review): --no-cuda, --class-weight, --active-listener and --residual
    # pair action="store_true" with default=True, so passing the flag can never
    # change the value — they behave as constants. Kept as-is to preserve
    # behavior; consider argparse.BooleanOptionalAction if this is unintended.
    arg_parser.add_argument("--no-cuda", action="store_true", default=True, help="does not use GPU")
    arg_parser.add_argument("--lr", type=float, default=0.0001, metavar="LR", help="learning rate")
    arg_parser.add_argument("--l2", type=float, default=0.00003, metavar="L2", help="L2 regularization weight")
    arg_parser.add_argument("--rec-dropout", type=float, default=0.3, metavar="rec_dropout", help="rec_dropout rate")
    arg_parser.add_argument("--dropout", type=float, default=0.5, metavar="dropout", help="dropout rate")
    arg_parser.add_argument("--batch-size", type=int, default=1, metavar="BS", help="batch size")
    arg_parser.add_argument("--epochs", type=int, default=10, metavar="E", help="number of epochs")
    arg_parser.add_argument("--class-weight", action="store_true", default=True, help="use class weights")
    arg_parser.add_argument("--active-listener", action="store_true", default=True, help="active listener")
    arg_parser.add_argument("--attention", default="simple", help="Attention type in context GRU")
    arg_parser.add_argument("--tensorboard", action="store_true", default=False, help="Enables tensorboard log")
    arg_parser.add_argument("--mode1", type=int, default=2, help="Roberta features to use")
    arg_parser.add_argument("--seed", type=int, default=500, metavar="seed", help="seed")
    arg_parser.add_argument("--norm", type=int, default=0, help="normalization strategy")
    arg_parser.add_argument("--mu", type=float, default=0, help="class_weight_mu")
    arg_parser.add_argument("--residual", action="store_true", default=True, help="use residual connection")

    args = arg_parser.parse_args()
    # With --no-cuda fixed at True (see NOTE above) this is currently always False.
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    if args.cuda:
        print("Running on GPU")
    else:
        print("Running on CPU")
    return args
if __name__ == "__main__": | |
def pred_to_labels(preds): | |
mapped_predictions = [] | |
for pred in preds: | |
# map the prediction for each conversation | |
mapped_labels = [] | |
for label in pred: | |
mapped_labels.append(label_mapping[label]) | |
mapped_predictions.append(mapped_labels) | |
# return the mapped labels for each conversation | |
return mapped_predictions | |
label_mapping = { | |
0: "Curiosity", | |
1: "Obscene", | |
2: "Informative", | |
3: "Openness", | |
4: "Acceptance", | |
5: "Interest", | |
6: "Greeting", | |
7: "Disapproval", | |
8: "Denial", | |
9: "Anxious", | |
10: "Uninterested", | |
11: "Remorse", | |
12: "Confused", | |
13: "Accusatory", | |
14: "Annoyed", | |
} | |
args = parse_cosmic_args() | |
model = load_model("epik/best_model.pt", args) | |
test_dataloader, ids = get_valid_dataloader() | |
predicted_labels = pred_to_labels(predict(model, test_dataloader, args)) | |
for id, labels in zip(ids, predicted_labels): | |
print(f"Conversation ID: {id}") | |
print(f"Predicted Sentiment Labels: {labels}") | |
print(len(labels)) | |