Spaces:
Build error
Build error
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") | |
| model = AutoModelForSeq2SeqLM.from_pretrained("./checkpoint-2500/") | |
| def text_processing(text): | |
| text = text + ' ' if text[-2:] != ' ' else text # 在末尾加上空格有利于模型预测 | |
| inputs = [text] | |
| # Tokenize and prepare the inputs for model | |
| input_ids = tokenizer(inputs, return_tensors="pt", max_length=512, truncation=True, padding="max_length").input_ids | |
| attention_mask = tokenizer(inputs, return_tensors="pt", max_length=512, truncation=True, padding="max_length").attention_mask | |
| # Generate prediction | |
| output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=512) | |
| # Decode the prediction | |
| decoded_output = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output] | |
| return decoded_output[0] | |
| examples = [ | |
| ["我们的价值观是 富强 民主 文明 和谐"], | |
| ["都什么年代了 还在抽传统香烟"], | |
| ["今夕是何年"], | |
| [" 三国演义 全名为 三國志通俗演义 又稱作 三國志演義 三國志傳 三國傳 三國全傳 三國英雄志傳 "], | |
| ] | |
| inputs=[gr.inputs.Textbox(default=examples[0][0], label="输入文本")] | |
| iface = gr.Interface( | |
| fn=text_processing, | |
| inputs=[gr.inputs.Textbox(default=examples[0][0], label="输入文本")], | |
| outputs='text', | |
| title='Punctuation Mark Prediction', | |
| description='本模型主要用于语音识别模型输出的后处理。\n输入无符号句子,需要打标点处用空格隔开,返回带标点句子。\n仅支持中文,因为训练数据中只有中文。', | |
| examples=examples | |
| ) | |
| iface.launch(inline=False) |