Macropodus committed on
Commit afab310 · verified · 1 Parent(s): 0927cb9

Upload 2 files

Files changed (2)
  1. macro-correct.py +114 -0
  2. requirements.txt +3 -0
macro-correct.py ADDED
@@ -0,0 +1,114 @@
+ #!/usr/bin/python
+ # -*- coding: utf-8 -*-
+ # @time    : 2021/2/29 21:41
+ # @author  : Mo
+ # @function: load and test a BERT-style CSC model directly with transformers
+
+
+ import traceback
+ import time
+ import sys
+ import os
+ os.environ["USE_TORCH"] = "1"
+
+ from transformers import BertConfig, BertTokenizer, BertForMaskedLM
+ import gradio as gr
+ import torch
+
+
+
+ # pretrained_model_name_or_path = "shibing624/macbert4csc-base-chinese"
+ pretrained_model_name_or_path = "Macropodus/macbert4mdcspell_v1"
+ # pretrained_model_name_or_path = "Macropodus/macbert4csc_v1"
+ # pretrained_model_name_or_path = "Macropodus/macbert4csc_v2"
+ # pretrained_model_name_or_path = "Macropodus/bert4csc_v1"
+ device = torch.device("cpu")
+ # device = torch.device("cuda")
+ max_len = 128
+
+ print("Loading model, please wait a few minutes!")
+ tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)
+ bert_config = BertConfig.from_pretrained(pretrained_model_name_or_path)
+ model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path)
+ model.to(device)
+ print("Model loaded successfully!")
+
+ texts = [
+     "机七学习是人工智能领遇最能体现智能的一个分知",
+     "我是练习时长两念半的鸽仁练习生蔡徐坤",
+ ]
+ len_mid = min(max_len, max([len(t) + 2 for t in texts]))  # +2 for the [CLS]/[SEP] tokens
+
+ with torch.no_grad():
+     outputs = model(**tokenizer(texts, padding=True, truncation=True, max_length=len_mid,
+                                 return_tensors="pt").to(device))
+
+ def get_errors(source, target):
+     """ Minimal way to collect (source char, corrected char, index) differences """
+     len_min = min(len(source), len(target))
+     errors = []
+     for idx in range(len_min):
+         if source[idx] != target[idx]:
+             errors.append([source[idx], target[idx], idx])
+     return errors
+
+ result = []
+ for probs, source in zip(outputs.logits, texts):
+     ids = torch.argmax(probs, dim=-1)  # greedy argmax over the MLM vocabulary at each position
+     tokens_space = tokenizer.decode(ids[1:-1], skip_special_tokens=False)
+     text_new = tokens_space.replace(" ", "")
+     target = text_new[:len(source)]
+     errors = get_errors(source, target)
+     print(source, " => ", target, errors)
+     result.append([target, errors])
+ print(result)
+
+
+ def macro_correct(text):
+     with torch.no_grad():
+         outputs = model(**tokenizer([text], padding=True, truncation=True, max_length=max_len,
+                                     return_tensors="pt").to(device))
+
+     def to_highlight(corrected_sent, errs):
+         output = [{"entity": "纠错", "word": err[1], "start": err[2], "end": err[2] + 1} for i, err in
+                   enumerate(errs)]
+         return {"text": corrected_sent, "entities": output}
+
+     def get_errors(source, target):
+         """ Minimal way to collect (source char, corrected char, index) differences """
+         len_min = min(len(source), len(target))
+         errors = []
+         for idx in range(len_min):
+             if source[idx] != target[idx]:
+                 errors.append([source[idx], target[idx], idx])
+         return errors
+
+     result = []
+     for probs, source in zip(outputs.logits, [text]):
+         ids = torch.argmax(probs, dim=-1)
+         tokens_space = tokenizer.decode(ids[1:-1], skip_special_tokens=False)
+         text_new = tokens_space.replace(" ", "")
+         target = text_new[:len(source)]
+         errors = get_errors(source, target)
+         print(source, " => ", target, errors)
+         result.append([target, errors])
+     # print(result)
+     return target + " " + str(errors)
+
+
+ if __name__ == '__main__':
+     print(macro_correct('少先队员因该为老人让坐'))
+
+     text_sample = '机七学习是人工智能领遇最能体现智能的一个分知'
+
+     gr.Interface(
+         macro_correct,
+         inputs='text',
+         outputs='text',
+         title="Chinese Spelling Correction Model Macropodus/macbert4csc_v2",
+         description="Copy or input erroneous Chinese text, then submit and the model will correct it.",
+         article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank'>Github REPO: macro-correct</a>",
+         examples=[text_sample]
+     ).launch()
+     # ).launch(server_name="0.0.0.0", server_port=8066, share=False, debug=True)
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ gradio
+ transformers
+ torch