# !/usr/bin/python
# -*- coding: utf-8 -*-
# @time    : 2021/2/29 21:41
# @author  : Mo
# @function: test loading BERT-style models directly with transformers
#            (Chinese spelling correction demo, served through gradio)


import traceback
import time
import sys
import os
# Set before transformers is imported (presumably to select the PyTorch
# backend -- confirm against the transformers version in use).
os.environ["USE_TORCH"] = "1"

from transformers import BertConfig, BertTokenizer, BertForMaskedLM
import gradio as gr
import torch


# Alternative checkpoints; swap the active line to try another model.
# pretrained_model_name_or_path = "shibing624/macbert4csc-base-chinese"
pretrained_model_name_or_path = "Macropodus/macbert4mdcspell_v1"
# pretrained_model_name_or_path = "Macropodus/macbert4csc_v1"
# pretrained_model_name_or_path = "Macropodus/macbert4csc_v2"
# pretrained_model_name_or_path = "Macropodus/bert4csc_v1"
device = torch.device("cpu")
# device = torch.device("cuda")
max_len = 128  # maximum number of tokens (special tokens included) per input

print("load model, please wait a few minute!")
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)
bert_config = BertConfig.from_pretrained(pretrained_model_name_or_path)
model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path)
model.to(device)
print("load model success!")


def get_errors(source, target):
    """Collect the positions where two strings differ.

    Only the overlapping prefix (the shorter of the two lengths) is compared.

    Args:
        source: the original text.
        target: the corrected text.

    Returns:
        A list of ``[source_char, target_char, index]`` entries, one per
        position at which the characters differ.
    """
    return [
        [src_char, tgt_char, idx]
        for idx, (src_char, tgt_char) in enumerate(zip(source, target))
        if src_char != tgt_char
    ]


def to_highlight(corrected_sent, errs):
    """Convert an error list into a text-plus-entities dict.

    Args:
        corrected_sent: the corrected sentence.
        errs: entries shaped ``[source_char, target_char, start]`` (as
            produced by ``get_errors``) or ``[..., start, end]``.

    Returns:
        ``{"text": corrected_sent, "entities": [...]}`` where every entity has
        the keys ``entity``, ``word``, ``start`` and ``end``.
    """
    entities = []
    for err in errs:
        start = err[2]
        # get_errors() yields 3-element entries without an explicit end
        # offset; derive it from the replacement instead of reading err[3].
        end = err[3] if len(err) > 3 else start + len(err[1])
        entities.append({"entity": "纠错", "word": err[1], "start": start, "end": end})
    return {"text": corrected_sent, "entities": entities}


def _run_model(batch, max_length):
    """Tokenize ``batch`` (a list of strings) and run the model without gradients."""
    encoded = tokenizer(
        batch,
        padding=True,
        # Without an explicit truncation strategy max_length is not enforced
        # when padding=True, so over-long input would reach the model as-is.
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        return model(**encoded)


def _decode_corrections(logits, sources):
    """Turn per-token logits into ``[corrected_text, errors]`` pairs.

    Args:
        logits: one row of token logits per entry of ``sources``.
        sources: the original texts, in the same order as ``logits``.

    Returns:
        A list with one ``[target, errors]`` entry per source text.
    """
    decoded = []
    for probs, source in zip(logits, sources):
        ids = torch.argmax(probs, dim=-1)
        # Drop the first and last positions (the special-token slots) before decoding.
        tokens_space = tokenizer.decode(ids[1:-1], skip_special_tokens=False)
        text_new = tokens_space.replace(" ", "")
        # Predictions at padded positions are cut off by the source length.
        target = text_new[:len(source)]
        if len(target) < len(source):
            # Text beyond the model window was never seen; keep it unchanged
            # rather than silently dropping it from the output.
            target += source[len(target):]
        errors = get_errors(source, target)
        print(source, " => ", target, errors)
        decoded.append([target, errors])
    return decoded


# Smoke test executed at import time on two sample sentences.
texts = [
    "机七学习是人工智能领遇最能体现智能的一个分知",
    "我是练习时长两念半的鸽仁练习生蔡徐坤",
]
# +2 leaves room for the two special tokens added around each sentence.
len_mid = min(max_len, max([len(t) + 2 for t in texts]))

outputs = _run_model(texts, len_mid)
result = _decode_corrections(outputs.logits, texts)
print(result)


def macro_correct(text):
    """Correct spelling errors in a single sentence.

    Args:
        text: the sentence to correct.

    Returns:
        The corrected sentence, a space, and the ``str()`` of the error list
        (entries of ``[source_char, target_char, index]``).
    """
    outputs = _run_model([text], max_len)
    # Pair the logits with the caller's text, not the module-level samples.
    target, errors = _decode_corrections(outputs.logits, [text])[0]
    return target + " " + str(errors)


if __name__ == '__main__':
    print(macro_correct('少先队员因该为老人让坐'))

    text_sample = '机七学习是人工智能领遇最能体现智能的一个分知'

    gr.Interface(
        macro_correct,
        inputs='text',
        outputs='text',
        # Show the checkpoint that is actually loaded.
        title=f"Chinese Spelling Correction Model {pretrained_model_name_or_path}",
        description="Copy or input error Chinese text. Submit and the machine will correct text.",
        article="Link to Github REPO: macro-correct",
        # A bare string is interpreted by gradio as a directory of examples;
        # sample inputs must be passed as a list.
        examples=[text_sample],
    ).launch()
    # ).launch(server_name="0.0.0.0", server_port=8066, share=False, debug=True)