File size: 4,132 Bytes
afab310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# !/usr/bin/python
# -*- coding: utf-8 -*-
# @time    : 2021/2/29 21:41
# @author  : Mo
# @function: transformers直接加载bert类模型测试


import traceback
import time
import sys
import os
os.environ["USE_TORCH"] = "1"

from transformers import BertConfig, BertTokenizer, BertForMaskedLM
import gradio as gr
import torch



# Hugging Face Hub id of the checkpoint to load. The commented-out ids are
# alternative spelling-correction checkpoints that can be swapped in.
# pretrained_model_name_or_path = "shibing624/macbert4csc-base-chinese"
pretrained_model_name_or_path = "Macropodus/macbert4mdcspell_v1"
# pretrained_model_name_or_path = "Macropodus/macbert4csc_v1"
# pretrained_model_name_or_path = "Macropodus/macbert4csc_v2"
# pretrained_model_name_or_path = "Macropodus/bert4csc_v1"
# Inference device; use the commented-out "cuda" line to run on a GPU instead.
device = torch.device("cpu")
# device = torch.device("cuda")
# Sequence length limit handed to the tokenizer as `max_length` below.
max_len = 128

# Load tokenizer, config and masked-LM weights once at import time; the
# inference code below (demo batch and macro_correct) reuses these objects.
print("load model, please wait a few minute!")
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)
# NOTE(review): bert_config is not referenced anywhere in the visible code.
bert_config = BertConfig.from_pretrained(pretrained_model_name_or_path)
model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path)
model.to(device)
print("load model success!")

# Demo sentences containing deliberate spelling errors, corrected at import time.
texts = [
    "机七学习是人工智能领遇最能体现智能的一个分知",
    "我是练习时长两念半的鸽仁练习生蔡徐坤",
]
# Rough token budget: one token per Chinese character plus [CLS]/[SEP],
# capped at the global max_len.
len_mid = min(max_len, max([len(t)+2 for t in texts]))

with torch.no_grad():
    # truncation=True is needed for max_length to be enforced; with padding=True
    # and no truncation flag the tokenizer does not cut long inputs.
    outputs = model(**tokenizer(texts, padding=True, truncation=True,
                                max_length=len_mid,
                                return_tensors="pt").to(device))

def get_errors(source, target):
    """Return character-level substitutions between ``source`` and ``target``.

    Only the overlapping prefix (up to the shorter length) is compared. Each
    mismatch is reported as ``[source_char, target_char, index]``.
    """
    return [
        [src_char, tgt_char, position]
        for position, (src_char, tgt_char) in enumerate(zip(source, target))
        if src_char != tgt_char
    ]

# Decode the argmax prediction for every demo sentence and report the diffs.
result = []
for logits_row, source in zip(outputs.logits, texts):
    best_ids = logits_row.argmax(dim=-1)
    # Skip the leading [CLS] position (and the final position), strip the
    # spaces decode() puts between tokens, then trim to the source length.
    decoded = tokenizer.decode(best_ids[1:-1], skip_special_tokens=False).replace(" ", "")
    target = decoded[:len(source)]
    errors = get_errors(source, target)
    print(source, " => ", target, errors)
    result.append([target, errors])
print(result)


def macro_correct(text):
    """Correct Chinese spelling errors in ``text`` with the loaded masked LM.

    Each position's argmax token is taken as the corrected character.

    Args:
        text: The sentence to correct.

    Returns:
        ``"<corrected text> <errors>"`` where ``errors`` is ``str()`` of the
        ``[source_char, target_char, index]`` list produced by ``get_errors``.
        Inputs longer than ``max_len`` tokens are truncated, so the corrected
        text then covers only the leading part of the input.
    """
    with torch.no_grad():
        # truncation=True so max_length is actually enforced; otherwise very
        # long inputs would exceed the model's maximum sequence length.
        inputs = tokenizer([text], padding=True, truncation=True,
                           max_length=max_len, return_tensors="pt").to(device)
        outputs = model(**inputs)

    # Single input -> single row of logits. Compare against `text` itself; the
    # previous version zipped with the global demo `texts` by mistake.
    ids = torch.argmax(outputs.logits[0], dim=-1)
    # Skip [CLS]/[SEP] and strip the spaces decode() inserts between tokens.
    tokens_space = tokenizer.decode(ids[1:-1], skip_special_tokens=False)
    target = tokens_space.replace(" ", "")[:len(text)]
    errors = get_errors(text, target)
    print(text, " => ", target, errors)
    return target + " " + str(errors)


if __name__ == '__main__':
    # Smoke test: correct one sentence before starting the web UI.
    print(macro_correct('少先队员因该为老人让坐'))

    text_sample = '机七学习是人工智能领遇最能体现智能的一个分知'

    gr.Interface(
        macro_correct,
        inputs='text',
        outputs='text',
        # Derive the title from the checkpoint actually loaded so it cannot
        # drift from pretrained_model_name_or_path.
        title=f"Chinese Spelling Correction Model {pretrained_model_name_or_path}",
        description="Copy or input error Chinese text. Submit and the machine will correct text.",
        article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank'>Github REPO: macro-correct</a>",
        # Gradio expects a list of example rows (one value per input
        # component); a bare string is treated as an examples directory path.
        examples=[[text_sample]],
    ).launch()
    # ).launch(server_name="0.0.0.0", server_port=8066, share=False, debug=True)