# File size: 4,987 Bytes
# Revision: 8ebda9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import argparse
from fengshen import UbertPipelines
import os
# Default the demo to GPU 6, but only when CUDA_VISIBLE_DEVICES is not already
# set: a value exported by the shell or a job scheduler must not be overwritten.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", '6')


def main():
    """Run the Ubert named-entity-recognition demo end to end.

    Parses the command line with the options registered by
    ``UbertPipelines.pipelines_args``, overrides a handful of training
    settings in code, builds one training, one validation and one test
    sample, then calls ``fit`` followed by ``predict`` and prints every item
    of the prediction result.
    """
    # Every sample offers the same candidate entity types, in this order.
    entity_types = (
        "地址",
        "书名",
        "公司",
        "游戏",
        "政府机构",
        "电影名称",
        "人物姓名",
        "组织机构",
        "岗位职位",
        "旅游景点",
    )

    def build_sample(text, mentions):
        """Return one NER sample dict for *text*.

        *mentions* maps an entity type to ``(name, start, end)`` tuples; the
        offsets index characters of *text* and the end offset is inclusive.
        Entity types missing from *mentions* get an empty ``entity_list``.
        """
        choices = []
        for entity_type in entity_types:
            entity_list = [
                {
                    "entity_name": name,
                    "entity_type": entity_type,
                    "entity_idx": [[start, end]],
                }
                for name, start, end in mentions.get(entity_type, ())
            ]
            choices.append({
                "entity_type": entity_type,
                "label": 0,
                "entity_list": entity_list,
            })
        return {
            "task_type": "抽取任务",
            "subtask_type": "实体识别",
            "text": text,
            "choices": choices,
            "id": 0,
        }

    parser = UbertPipelines.pipelines_args(argparse.ArgumentParser("TASK NAME"))
    args = parser.parse_args()

    # Training settings applied on top of whatever was parsed.
    overrides = {
        # Path of the pretrained model; the released checkpoints are hosted
        # on HuggingFace.
        "pretrained_model_path": 'IDEA-CCNL/Erlangshen-Ubert-110M-Chinese',
        # Default root directory, used for logs, tensorboard files, etc.
        "default_root_dir": './',
        "max_epochs": 5,
        "gpus": 1,
        "batch_size": 1,
    }
    for option, value in overrides.items():
        setattr(args, option, value)

    # Data only has to be shaped like the JSON-style samples below to train
    # and predict in one go; a single example is provided per split.
    train_data = [
        build_sample(
            "彭小军认为,国内银行现在走的是台湾的发卡模式,先通过跑马圈地再在圈的地里面选择客户,",
            {
                "地址": [("台湾", 15, 16)],
                "人物姓名": [("彭小军", 0, 2)],
            },
        ),
    ]
    dev_data = [
        build_sample(
            "就天涯网推出彩票服务频道是否是业内人士所谓的打政策“擦边球”,记者近日对此事求证彩票监管部门。",
            {
                "公司": [("天涯网", 1, 3)],
                "组织机构": [("彩票监管部门", 40, 45)],
                "岗位职位": [("记者", 31, 32)],
            },
        ),
    ]
    test_data = [
        build_sample(
            "这也让很多业主据此认为,雅清苑是政府公务员挤对了国家的经适房政策。",
            {
                "地址": [("雅清苑", 12, 14)],
                "岗位职位": [("公务员", 18, 20)],
            },
        ),
    ]

    pipeline = UbertPipelines(args)
    pipeline.fit(train_data, dev_data)
    predictions = pipeline.predict(test_data)
    for prediction in predictions:
        print(prediction)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()