File size: 2,307 Bytes
8ebda9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import argparse
from fengshen.pipelines.multiplechoice import UniMCPipelines


def _demo_data():
    """Return the toy ``(train_data, dev_data, test_data)`` sample lists.

    Each sample is a dict with the keys ``texta``, ``textb``, ``question``,
    ``choice`` (candidate labels), ``answer`` (the correct choice text),
    ``label`` (the index of ``answer`` within ``choice``) and ``id``.
    """
    train_data = [    # training data
        {
            "texta": "凌云研发的国产两轮电动车怎么样,有什么惊喜?",
            "textb": "",
            "question": "下面新闻属于哪一个类别?",
            "choice": [
                "教育",
                "科技",
                "军事",
                "旅游",
                "国际",
                "股票",
                "农业",
                "电竞"
            ],
            "answer": "科技",
            "label": 1,
            "id": 0
        }
    ]
    dev_data = [     # validation data
        {
            "texta": "我四千一个月,老婆一千五一个月,存款八万且有两小孩,是先买房还是先买车?",
            "textb": "",
            "question": "下面新闻属于哪一个类别?",
            "choice": [
                "故事",
                "文化",
                "娱乐",
                "体育",
                "财经",
                "房产",
                "汽车"
            ],
            "answer": "汽车",
            "label": 6,
            "id": 0
        }
    ]
    test_data = [    # test data
        {"texta": "街头偶遇2018款长安CS35,颜值美炸!或售6万起,还买宝骏510?",
         "textb": "",
         "question": "下面新闻属于哪一个类别?",
         "choice": [
             "房产",
             "汽车",
             "教育",
             "军事"
         ],
         "answer": "汽车",
         "label": 1,
         "id": 7759}
    ]
    return train_data, dev_data, test_data


def main():
    """Fine-tune and run the UniMC multiple-choice pipeline on a toy example.

    Builds the command-line parser via ``UniMCPipelines.piplines_args``.
    The demo's hyper-parameters are installed as parser *defaults*, so any
    value given explicitly on the command line takes precedence. The
    function then trains on a one-example train/dev set when ``args.train``
    is truthy, predicts on a one-example test set, and prints each
    prediction.
    """
    total_parser = argparse.ArgumentParser("TASK NAME")
    total_parser = UniMCPipelines.piplines_args(total_parser)
    # Install the demo values as defaults *before* parsing. Assigning them to
    # ``args`` after parse_args() would silently discard any flags the user
    # passed; set_defaults() yields the same values when no flags are given.
    total_parser.set_defaults(
        learning_rate=2e-5,
        max_length=512,
        max_epochs=3,
        batchsize=8,
        train=True,
        default_root_dir='./',
    )
    args = total_parser.parse_args()

    pretrained_model_path = 'IDEA-CCNL/Erlangshen-UniMC-RoBERTa-110M-Chinese'
    model = UniMCPipelines(args, model_path=pretrained_model_path)

    train_data, dev_data, test_data = _demo_data()

    if args.train:
        model.train(train_data, dev_data)
    result = model.predict(test_data)
    for line in result:
        print(line)


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()