|
from collections.abc import Callable |
|
import traceback |
|
from typing import List, Optional, Tuple, Union
|
from datasets import Dataset |
|
import re |
|
import pickle |
|
import os |
|
from transformers.pipelines.pt_utils import KeyDataset |
|
from transformers import AutoTokenizer |
|
from tqdm.auto import tqdm |
|
|
|
URL_REGEX = r"\b(https?://\S+)\b" |
|
EMAIL_REGEX = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" |
|
TAG_REGEX = r"<[^>]+>" |
|
HANDLE_REGEX = r"[^a-zA-Z](@\w+)" |
|
|
|
|
|
class Translator: |
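    """Document-level wrapper around a translation pipeline.

    Long Chinese documents are split into chunks of at most `max_length`
    tokens, URLs / e-mail addresses / HTML tags / @handles are protected with
    placeholders, the chunks are translated in batches through `pipe`, and the
    translated text is stitched back into the original layout.  When
    `save_filename` is given, intermediate results are pickled every
    `save_every_step` translation steps and reused on the next run.
    """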
|
def __init__( |
|
self, |
|
pipe: Callable, |
|
        max_length: Optional[int] = 500,
|
batch_size: int = 16, |
|
        save_every_step: int = 100,

        text_key: str = "text",

        save_filename: Optional[str] = None,

        replace_chinese_puncts: bool = False,

        verbose: bool = False,
|
): |
|
self.pipe = pipe |
|
self.max_length = max_length |
|
self.batch_size = batch_size |
|
self.save_every_step = save_every_step |
|
self.save_filename = save_filename |
|
self.text_key = text_key |
|
self.replace_chinese_puncts = replace_chinese_puncts |
|
self.verbose = verbose |
|
|
|
        if max_length is None and hasattr(pipe.model.config, "max_length"):
|
self.max_length = pipe.model.config.max_length |
|
|
|
def _is_chinese(self, text: str) -> bool: |
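        """Return True if `text` contains at least one CJK ideograph."""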
|
return ( |
|
re.search( |
|
r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002ebef\U00030000-\U000323af\ufa0e\ufa0f\ufa11\ufa13\ufa14\ufa1f\ufa21\ufa23\ufa24\ufa27\ufa28\ufa29\u3006\u3007][\ufe00-\ufe0f\U000e0100-\U000e01ef]?", |
|
text, |
|
) |
|
is not None |
|
) |
|
|
|
def _split_sentences(self, text: str) -> List[str]: |
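        """Split `text` into chunks of at most `self.max_length` tokens.

        The split point is searched backwards from the token limit and placed
        after a sentence-final delimiter (。!?; and their ASCII counterparts);
        if no usable delimiter is found, the remaining text is kept as a
        single chunk.
        """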
|
tokens = self.pipe.tokenizer(text, add_special_tokens=False) |
|
token_size = len(tokens.input_ids) |
|
|
|
        if token_size <= self.max_length:
|
return [text] |
|
|
|
        delimiter = set("。!?;…!?;")
|
sent_list = [] |
|
sent = text |
|
|
|
while token_size > self.max_length: |
|
orig_sent_len = token_size |
|
|
|
|
|
for i in range(token_size - 2, 0, -1): |
|
token = tokens.token_to_chars(0, i) |
|
char = sent[token.start : token.end] |
|
|
|
if char in delimiter: |
|
split_char_index = token.end |
|
next_sent = sent[split_char_index:] |
|
|
|
if len(next_sent) == 1: |
|
continue |
|
|
|
sent_list = [next_sent] + sent_list |
|
sent = sent[0:split_char_index] |
|
break |
|
|
|
tokens = self.pipe.tokenizer(sent, add_special_tokens=False) |
|
token_size = len(tokens.input_ids) |
|
|
|
|
|
if token_size == orig_sent_len: |
|
sent_list = [sent] + sent_list |
|
sent = "" |
|
break |
|
|
|
if len(sent) > 0: |
|
sent_list = [sent] + sent_list |
|
|
|
return sent_list |
|
|
|
    def _preprocess(self, text: str) -> Tuple[List[str], str, List[int], List[str]]:
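        """Mask entities and cut `text` into translatable chunks.

        URLs, e-mail addresses, HTML tags and @handles are replaced by
        `$e[i]` placeholders, and every Chinese line is split into
        token-bounded chunks.  Returns `(sentences, template, num_tokens,
        entities)`, where `template` contains a `{i}` slot per chunk so the
        document can be reassembled afterwards.
        """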
|
|
|
tags = re.findall(TAG_REGEX, text) |
|
handles = re.findall(HANDLE_REGEX, text) |
|
urls = re.findall(URL_REGEX, text) |
|
emails = re.findall(EMAIL_REGEX, text) |
|
entities = urls + emails + tags + handles |
|
|
|
|
|
        # Temporarily rename any literal "$e[i]" already in the text so it does
        # not collide with the entity placeholders inserted below.
        text = re.sub(r"\$e\[(\d+)\]", r"$E[\1]", text)
|
|
|
for i, entity in enumerate(entities): |
|
text = text.replace(entity, "$e[%d]" % i, 1) |
|
|
|
lines = text.split("\n") |
|
sentences = [] |
|
num_tokens = [] |
|
template = text.replace("{", "{{").replace("}", "}}") |
|
chunk_index = 0 |
|
|
|
for line in lines: |
|
sentence = line.strip() |
|
|
|
if len(sentence) > 0 and self._is_chinese(sentence): |
|
chunks = self._split_sentences(sentence) |
|
|
|
for chunk in chunks: |
|
sentences.append(chunk) |
|
tokens = self.pipe.tokenizer(chunk, add_special_tokens=False) |
|
num_tokens.append(len(tokens.input_ids)) |
|
chunk = chunk.replace("{", "{{").replace("}", "}}") |
|
template = template.replace(chunk, "{%d}" % chunk_index, 1) |
|
chunk_index += 1 |
|
|
|
return sentences, template, num_tokens, entities |
|
|
|
def _postprocess( |
|
self, |
|
template: str, |
|
src_sentences: List[str], |
|
translations: List[str], |
|
entities: List[str], |
|
) -> str: |
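        """Reassemble a document from its template and translated chunks.

        Alphanumeric/ASCII spans that the model altered are copied back from
        the source sentence, `$e[i]` placeholders are replaced with the
        original entities, ASCII punctuation is optionally converted to
        full-width, repeated CJK punctuation is collapsed, and “ ” quotes are
        normalized to 「 」.
        """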
|
processed = [] |
|
        alphanumeric_regex = re.compile(

            r"([a-zA-Za-zA-Z0-9\d+'\",,(\()\)::;;“”。·\.\??\!!‘’$\[\]<>/]+)"

        )
|
|
|
def hash_text(text: List[str]) -> str: |
|
text = "|".join(text) |
|
puncts_map = str.maketrans(",;:()。?!“”‘’", ",;:().?!\"\"''") |
|
text = text.translate(puncts_map) |
|
return text.lower() |
|
|
|
for i, p in enumerate(translations): |
|
src_sentence = src_sentences[i] |
|
|
|
if self.replace_chinese_puncts: |
|
p = re.sub(",", ",", p) |
|
p = re.sub(";", ";", p) |
|
p = re.sub(":", ":", p) |
|
                p = re.sub(r"\(", "(", p)

                p = re.sub(r"\)", ")", p)
|
p = re.sub(r"([\d]),([\d])", r"\1,\2", p) |
|
|
|
src_matches = re.findall(alphanumeric_regex, src_sentence) |
|
tgt_matches = re.findall(alphanumeric_regex, p) |
|
|
|
|
|
if ( |
|
len(src_matches) != len(tgt_matches) |
|
or len(src_matches) == 0 |
|
or len(tgt_matches) == 0 |
|
): |
|
processed.append(p) |
|
continue |
|
|
|
|
|
src_hashes = hash_text(src_matches) |
|
translated_hashes = hash_text(tgt_matches) |
|
|
|
if src_hashes != translated_hashes: |
|
|
|
for j in range(len(src_matches)): |
|
if src_matches[j] != tgt_matches[j]: |
|
p = p.replace(tgt_matches[j], src_matches[j], 1) |
|
|
|
processed.append(p) |
|
|
|
output = template.format(*processed) |
|
|
|
|
|
for i, entity in enumerate(entities): |
|
output = output.replace("$e[%d]" % i, entity, 1) |
|
|
|
|
|
        # Restore any literal "$e[i]" that was renamed during preprocessing.
        output = re.sub(r"\$E\[(\d+)\]", r"$e[\1]", output)
|
|
|
|
|
output = re.sub(r"([「」()『』《》。,:])\1+", r"\1", output) |
|
|
|
|
|
        output = output.replace("“", "「")

        output = output.replace("”", "」")
|
|
|
return output |
|
|
|
def _save(self, translations): |
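        """Pickle the translations gathered so far to `self.save_filename`."""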
|
with open(self.save_filename, "wb") as f: |
|
pickle.dump(translations, f) |
|
|
|
    def __call__(self, inputs: Union[str, List[str], Dataset]) -> List[str]:
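        """Translate `inputs` and return one output string per document.

        Each document is preprocessed into sentence chunks, the chunks are
        translated in batches through `self.pipe`, and the results are merged
        back together via the stored templates.  When `self.save_filename` is
        set, partial results are checkpointed every `save_every_step` steps
        and translation resumes from that file if it already exists.
        """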
|
templates = [] |
|
sentences = [] |
|
num_tokens = [] |
|
sentence_indices = [] |
|
outputs = [] |
|
translations = [] |
|
entities_list = [] |
|
resume_from_file = None |
|
|
|
if isinstance(inputs, Dataset): |
|
ds = inputs |
|
else: |
|
if isinstance(inputs, str): |
|
inputs = [inputs] |
|
ds = Dataset.from_list([{"text": text} for text in inputs]) |
|
|
|
for i, text_input in tqdm( |
|
enumerate(ds), total=len(ds), desc="Preprocessing", disable=not self.verbose |
|
): |
|
            chunks, template, chunk_num_tokens, entities = self._preprocess(
|
text_input["text"] |
|
) |
|
templates.append(template) |
|
sentence_indices.append([]) |
|
entities_list.append(entities) |
|
|
|
for j, chunk in enumerate(chunks): |
|
sentences.append(chunk) |
|
                sentence_indices[-1].append(len(sentences) - 1)

                num_tokens.append(chunk_num_tokens[j])
|
|
|
if self.save_filename: |
|
resume_from_file = ( |
|
self.save_filename if os.path.isfile(self.save_filename) else None |
|
) |
|
|
|
        if resume_from_file is not None:

            with open(resume_from_file, "rb") as f:
                translations = pickle.load(f)
|
|
|
if self.verbose: |
|
print("translated:", len(translations)) |
|
print("to translate:", len(sentences) - len(translations)) |
|
|
|
            if resume_from_file is not None:
|
print( |
|
"Resuming from {}({} records)".format( |
|
resume_from_file, len(translations) |
|
) |
|
) |
|
|
|
ds = Dataset.from_list( |
|
[{"text": text} for text in sentences[len(translations) :]] |
|
) |
|
|
|
        max_token_length = max(num_tokens) if num_tokens else 0
|
|
|
if self.verbose: |
|
print("Max Length:", max_token_length) |
|
|
|
total_records = len(ds) |
|
|
|
if total_records > 0: |
|
step = 0 |
|
|
|
with tqdm( |
|
disable=not self.verbose, desc="Translating", total=total_records |
|
) as pbar: |
|
for out in self.pipe( |
|
KeyDataset(ds, self.text_key), |
|
batch_size=self.batch_size, |
|
max_length=self.max_length, |
|
): |
|
translations.append(out[0]) |
|
|
|
|
|
if ( |
|
step != 0 |
|
                        and self.save_filename is not None
|
and step % self.save_every_step == 0 |
|
): |
|
self._save(translations) |
|
|
|
step += 1 |
|
|
|
pbar.update(1) |
|
|
|
        if self.save_filename is not None and total_records > 0:
|
self._save(translations) |
|
|
|
for i, template in tqdm( |
|
enumerate(templates), |
|
total=len(templates), |
|
desc="Postprocessing", |
|
disable=not self.verbose, |
|
): |
|
try: |
|
src_sentences = [sentences[index] for index in sentence_indices[i]] |
|
tgt_sentences = [ |
|
translations[index]["translation_text"] |
|
for index in sentence_indices[i] |
|
] |
|
output = self._postprocess( |
|
template, src_sentences, tgt_sentences, entities_list[i] |
|
) |
|
outputs.append(output) |
|
except Exception as error: |
|
print(error) |
|
print(template) |
|
traceback.print_exc() |
|
|
|
|
|
return outputs |
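

# Example usage with a real pipeline (a minimal sketch, not exercised by the
# tests below; the model name is only an illustration, and any "translation"
# pipeline with a fast tokenizer whose outputs carry a "translation_text" key
# should work):
#
#   from transformers import pipeline
#
#   pipe = pipeline("translation", model="indiejoseph/bart-translation-zh-yue")
#   translator = Translator(pipe, max_length=500, batch_size=16, verbose=True)
#   print(translator(["你好,世界!"])[0])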
|
|
|
|
|
class Object(object): |
|
pass |
|
|
|
|
|
class FakePipe(object): |
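    """Stand-in for a transformers translation pipeline, used by the tests below.

    It exposes the attributes the Translator relies on (`model.config.max_length`
    and `tokenizer`) and yields `[{"translation_text": ...}]` items, deliberately
    mangling "123", "abc" and "Acetaminophen" so the repairs in `_postprocess`
    can be exercised.
    """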
|
def __init__(self, max_length: int = 500): |
|
self.model = Object() |
|
self.model.config = Object() |
|
self.model.config.max_length = max_length |
|
self.tokenizer = AutoTokenizer.from_pretrained( |
|
"indiejoseph/bart-translation-zh-yue" |
|
) |
|
|
|
    def __call__(self, text: List[str], batch_size: int, max_length: int):
|
for i in range(len(text)): |
|
sentence = text[i] |
|
|
|
tags = re.findall(TAG_REGEX, sentence) |
|
handles = re.findall(HANDLE_REGEX, sentence) |
|
urls = re.findall(URL_REGEX, sentence) |
|
emails = re.findall(EMAIL_REGEX, sentence) |
|
entities = urls + emails + tags + handles |
|
|
|
            for j, entity in enumerate(entities):

                sentence = sentence.replace(entity, "$e[%d]" % j, 1)
|
|
|
if "123" in sentence: |
|
yield [{"translation_text": sentence.replace("123", "123")}] |
|
continue |
|
if "abc" in sentence: |
|
yield [{"translation_text": sentence.replace("abc", "ABC")}] |
|
continue |
|
if "Acetaminophen" in sentence: |
|
yield [ |
|
{ |
|
"translation_text": sentence.replace( |
|
"Acetaminophen", "ACEtaminidien" |
|
) |
|
} |
|
] |
|
continue |
|
yield [{"translation_text": sentence}] |
|
|
|
|
|
if __name__ == "__main__": |
|
fake_pipe = FakePipe(60) |
|
|
|
translator = Translator(fake_pipe, max_length=60, batch_size=2, verbose=True) |
|
|
|
text1 = "对于编写聊天机器人的脚本,你可以采用不同的方法,包括使用基于规则的系统、自然语言处理(NLP)技术和机器学习模型。下面是一个简单的例子,展示如何使用基于规则的方法来构建一个简单的聊天机器人:" |
|
text2 = """对于编写聊天机器人的脚本,你可以采用不同的方法,包括使用基于规则的系统、自然语言处理(NLP)技术和机器学习模型。下面是一个简单的例子,展示如何使用基于规则的方法来构建一个简单的聊天机器人: |
|
|
|
``` |
|
# 设置用于匹配输入的关键字,并定义相应的回答数据字典。 |
|
keywords = {'你好': '你好!很高兴见到你。', |
|
'再见': '再见!有机会再聊。', |
|
'你叫什么': '我是一个聊天机器人。', |
|
'你是谁': '我是一个基于人工智能技术制作的聊天机器人。'} |
|
|
|
# 定义用于处理用户输入的函数。 |
|
def chatbot(input_text): |
|
# 遍历关键字数据字典,匹配用户的输入。 |
|
for key in keywords: |
|
if key in input_text: |
|
# 如果匹配到了关键字,返回相应的回答。 |
|
return keywords[key] |
|
# 如果没有找到匹配的关键字,返回默认回答。 |
|
return "对不起,我不知道你在说什么。" |
|
|
|
# 运行聊天机器人。 |
|
while True: |
|
# 获取用户输入。 |
|
user_input = input('用户: ') |
|
# 如果用户输入“再见”,退出程序。 |
|
if user_input == '再见': |
|
break |
|
# 处理用户输入,并打印回答。 |
|
print('机器人: ' + chatbot(user_input)) |
|
``` |
|
|
|
这是一个非常简单的例子。对于实用的聊天机器人,可能需要使用更复杂的 NLP 技术和机器学习模型,以更好地理解和回答用户的问题。""" |
|
text3 = "布洛芬(Ibuprofen)同撲熱息痛(Acetaminophen)係兩種常見嘅非處方藥,用於緩解疼痛、發燒同關節痛。" |
|
text4 = "123 “abc” def's http://www.google.com [email protected] @abc 網址:http://localhost/abc下載" |
|
text5 = "新力公司董事長盛田昭夫、自民黨國會議員石原慎太郎等人撰寫嘅《日本可以說「不」》、《日本還要說「不」》、《日本堅決說「不」》三本書中話道:「無啦啦挑起戰爭嘅好戰日本人,製造南京大屠殺嘅殘暴嘅日本人,呢d就係人地對日本人嘅兩個誤解,都係‘敲打日本’嘅兩個根由,我地必須採取措施消除佢。」" |
|
outputs = translator([text1, text2, text3, text4, text5]) |
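
    # FakePipe mangles "123", "abc" and "Acetaminophen" in its output;
    # _postprocess should restore them from the source text, so every document
    # round-trips exactly, apart from “ ” quotes being normalized to 「 」.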
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert outputs[0] == text1 |
|
assert outputs[1] == text2.replace("“", "「").replace("”", "」") |
|
assert outputs[2] == text3 |
|
assert outputs[3] == text4.replace("“", "「").replace("”", "」") |
|
assert outputs[4] == text5 |
|
|
|
|
|
assert ( |
|
len( |
|
translator._split_sentences( |
|
"新力公司董事長盛田昭夫、自民黨國會議員石原慎太郎等人撰寫嘅《日本可以說「不」》、《日本還要說「不」》、《日本堅決說「不」》三本書中話道:「無啦啦挑起戰爭嘅好戰日本人,製造南京大屠殺嘅殘暴嘅日本人,呢d就係人地對日本人嘅兩個誤解,都係‘敲打日本’嘅兩個根由,我地必須採取措施消除佢。」" |
|
) |
|
) |
|
== 1 |
|
) |
|
|
|
translator = Translator(fake_pipe, max_length=5, batch_size=2, verbose=True) |
|
|
|
assert ( |
|
len( |
|
translator._split_sentences("====。====。====。====。====。====。====。====。====。") |
|
) |
|
== 9 |
|
) |
|
|