Spaces:
Running
Running
File size: 4,132 Bytes
afab310 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
# !/usr/bin/python
# -*- coding: utf-8 -*-
# @time : 2021/2/29 21:41
# @author : Mo
# @function: transformers直接加载bert类模型测试
import traceback
import time
import sys
import os
os.environ["USE_TORCH"] = "1"
from transformers import BertConfig, BertTokenizer, BertForMaskedLM
import gradio as gr
import torch
# Hugging Face checkpoint to load.
# Alternative checkpoints listed by the original author (swap one in to compare):
#   "shibing624/macbert4csc-base-chinese"
#   "Macropodus/macbert4csc_v1"
#   "Macropodus/macbert4csc_v2"
#   "Macropodus/bert4csc_v1"
pretrained_model_name_or_path = "Macropodus/macbert4mdcspell_v1"
# Inference device; use torch.device("cuda") instead to run on a GPU.
device = torch.device("cpu")
# Token budget (including [CLS]/[SEP]) handed to the tokenizer as max_length.
max_len = 128

print("load model, please wait a few minute!")
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)
# NOTE(review): bert_config is not referenced by the code shown here.
bert_config = BertConfig.from_pretrained(pretrained_model_name_or_path)
# nn.Module.to() moves parameters in place and returns the same module.
model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path).to(device)
print("load model success!")
# Two sample sentences with deliberate typos, corrected once at start-up as a smoke test.
texts = [
    "机七学习是人工智能领遇最能体现智能的一个分知",
    "我是练习时长两念半的鸽仁练习生蔡徐坤",
]
# Longest sentence plus the [CLS]/[SEP] tokens, capped at max_len.
len_mid = min(max_len, max([len(t) + 2 for t in texts]))
with torch.no_grad():
    # truncation=True is what makes max_length effective: with padding=True
    # and no truncation flag, the tokenizer ignores max_length entirely.
    outputs = model(**tokenizer(texts, padding=True, truncation=True,
                                max_length=len_mid,
                                return_tensors="pt").to(device))
def get_errors(source, target):
    """Return the positions where ``target`` differs from ``source``.

    A minimal character-by-character diff: characters are compared pairwise
    by index, and any extra characters in the longer string are ignored.

    Args:
        source: the original text.
        target: the corrected text.

    Returns:
        list: ``[[source_char, target_char, index], ...]`` for every index
        below ``min(len(source), len(target))`` where the characters differ.
    """
    # zip() stops at the shorter sequence, reproducing the min-length bound.
    return [[src, tgt, idx]
            for idx, (src, tgt) in enumerate(zip(source, target))
            if src != tgt]
# Turn the start-up batch logits into corrected sentences and print the diffs.
result = []
for logits_row, sentence in zip(outputs.logits, texts):
    # Greedy decoding: the highest-scoring vocabulary id at every position.
    best_ids = logits_row.argmax(dim=-1)
    # Drop the first/last positions, then remove the spaces decode() inserts
    # between tokens and keep only as many characters as the input had.
    decoded = tokenizer.decode(best_ids[1:-1], skip_special_tokens=False)
    corrected = decoded.replace(" ", "")[:len(sentence)]
    diffs = get_errors(sentence, corrected)
    print(sentence, " => ", corrected, diffs)
    result.append([corrected, diffs])
print(result)
def macro_correct(text):
    """Correct spelling errors in one Chinese sentence with the loaded masked-LM.

    The model is run once over ``text``. At every position the highest-scoring
    token is taken as the corrected character, and the result is compared to
    the input position by position with ``get_errors``.

    Input longer than ``max_len`` tokens (including [CLS]/[SEP]) is truncated,
    so only the kept prefix is corrected and returned.

    NOTE(review): the positional comparison assumes each input character maps
    to exactly one token. Spaces, Latin words or out-of-vocabulary characters
    may shift the alignment; confirm against the tokenizer if this matters.

    Args:
        text: the sentence to correct.

    Returns:
        str: ``"<corrected text> <errors>"``, where ``errors`` is the list
        ``[[wrong_char, corrected_char, index], ...]`` produced by ``get_errors``.
    """
    with torch.no_grad():
        # truncation=True is required for max_length to take effect; otherwise
        # long inputs overflow the model's position embeddings and crash.
        inputs = tokenizer([text], padding=True, truncation=True,
                           max_length=max_len, return_tensors="pt").to(device)
        outputs = model(**inputs)
    # Batch of one: take the only row of logits, then decode greedily.
    ids = torch.argmax(outputs.logits[0], dim=-1)
    # Drop the [CLS]/[SEP] positions and the spaces decode() puts between tokens.
    tokens_space = tokenizer.decode(ids[1:-1], skip_special_tokens=False)
    text_new = tokens_space.replace(" ", "")
    # Compare against the user's own input (previously the global demo `texts`
    # was used here by mistake, so every request was aligned with the first
    # demo sentence instead of `text`).
    target = text_new[:len(text)]
    errors = get_errors(text, target)
    print(text, " => ", target, errors)
    return target + " " + str(errors)
if __name__ == '__main__':
    # Quick console check before starting the web UI.
    print(macro_correct('少先队员因该为老人让坐'))
    text_sample = '机七学习是人工智能领遇最能体现智能的一个分知'
    gr.Interface(
        macro_correct,
        inputs='text',
        outputs='text',
        # Show the checkpoint that is actually loaded rather than a hard-coded name.
        title="Chinese Spelling Correction Model " + pretrained_model_name_or_path,
        description="Copy or input error Chinese text. Submit and the machine will correct text.",
        article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank'>Github REPO: macro-correct</a>",
        # Gradio expects a list of example rows (one value per input component);
        # a bare string would be interpreted as a path to an examples directory.
        examples=[[text_sample]],
    ).launch()
    # For a self-hosted deployment, launch with explicit server settings instead:
    # ).launch(server_name="0.0.0.0", server_port=8066, share=False, debug=True)
|