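# NOTE(review): the three lines below look like page-header residue captured
# together with this source (not program text); commented out so the module parses.
# Spaces:
# Runtime error
# Runtime error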
import numpy as np
import torch

from utils import bin_util
from utils import silent_util
# Fixed 100-bit pattern. create_parcel_message() uses its first `len_start_bit`
# bits as the start bits that prefix every watermark payload.
# NOTE(review): whatever detects the watermark presumably searches for these same
# bits, so editing the values likely breaks detection of existing watermarks.
fix_pattern = [1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0,
               0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,
               1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1,
               1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0,
               0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0]
def create_parcel_message(len_start_bit, num_bit, wm_text, verbose=False, *, pattern=None):
    """Build a watermark bit sequence: start bits followed by the payload bits.

    Args:
        len_start_bit: Number of leading bits of ``pattern`` used as the start
            marker. Must satisfy ``0 <= len_start_bit <= min(len(pattern), num_bit)``.
        num_bit: Total number of bits in the returned watermark.
        wm_text: Hex string converted to payload bits by
            ``bin_util.hexStr2BinArray``. If falsy (``None`` or ``""``), a random
            0/1 payload of ``num_bit - len_start_bit`` bits is drawn from NumPy's
            global RNG instead.
        verbose: If True, print the start-bit length and ``2**len_start_bit / 10000``.
        pattern: Sequence the start bits are sliced from. Defaults to the
            module-level ``fix_pattern``.

    Returns:
        ``(start_bit, msg_arr, watermark)``: the start bits (a slice of
        ``pattern``), the payload bits, and their NumPy concatenation, which has
        exactly ``num_bit`` elements.

    Raises:
        ValueError: If ``len_start_bit`` is negative or larger than
            ``len(pattern)`` or ``num_bit``, or if the payload does not fill
            exactly ``num_bit - len_start_bit`` bits.
    """
    if pattern is None:
        pattern = fix_pattern
    # Validate up front. Slicing past the end of `pattern` used to shorten the
    # start marker silently, and a negative length sliced from the wrong end.
    if not 0 <= len_start_bit <= len(pattern):
        raise ValueError("len_start_bit must be in [0, %d], got %d"
                         % (len(pattern), len_start_bit))
    if len_start_bit > num_bit:
        raise ValueError("len_start_bit (%d) exceeds num_bit (%d)"
                         % (len_start_bit, num_bit))

    # 2. Start bits
    start_bit = pattern[0:len_start_bit]
    # Assuming uniformly random bits, an n-bit window matches the start bits with
    # probability 1/2**n; 2**n / 10000 expresses "one accidental match per this
    # many ten-thousand windows", which is what the verbose message prints.
    error_prob = 2 ** len(start_bit) / 10000
    # TODO: what is the error rate once a detection threshold is taken into account?
    if verbose:
        print("起始bit长度:%d,错误率:%.1f万" % (len(start_bit), error_prob))

    # 3. Payload bits
    length_msg = num_bit - len(start_bit)
    if wm_text:
        msg_arr = bin_util.hexStr2BinArray(wm_text)
    else:
        msg_arr = np.random.choice([0, 1], size=length_msg)

    # 4. Assemble. Explicit check instead of `assert`, which `python -O` strips.
    watermark = np.concatenate([start_bit, msg_arr])
    if len(watermark) != num_bit:
        raise ValueError("watermark has %d bits but num_bit is %d "
                         "(payload has %d bits, expected %d)"
                         % (len(watermark), num_bit, len(msg_arr), length_msg))
    return start_bit, msg_arr, watermark
import time


def add_watermark(bir_array, data, num_point, shift_range, device, model, silence_check=False):
    """Embed ``bir_array`` into ``data`` chunk by chunk.

    ``data`` is split into consecutive chunks of
    ``num_point + int(num_point * shift_range)`` samples. Each full chunk is
    laid out as ``[cover area | shift area]``:

    - the first ``num_point`` samples (the cover area) go through
      ``encode_trunck_with_silence_check``;
    - the remaining samples (the shift area) are copied through unchanged.

    A trailing partial chunk is copied through without being watermarked.

    Args:
        bir_array: Watermark bits forwarded to the encoder (presumably
            "bit array"; name kept for compatibility).
        data: Sample array to watermark. Presumably 1-D audio samples
            (TODO confirm against callers).
        num_point: Number of samples in each cover area.
        shift_range: Length of the shift area as a fraction of ``num_point``.
        device: Torch device forwarded to the encoder.
        model: Model forwarded to the encoder.
        silence_check: If True, cover areas judged silent are left unmodified.

    Returns:
        ``(data, reconstructed_array, time_cost)``: the input object itself, the
        reassembled watermarked array (same length as ``data``), and the elapsed
        time in seconds measured with ``time.perf_counter``.

    Raises:
        AssertionError: If ``data`` is empty, or if the encoder changes the
            length of a cover area.
    """
    # perf_counter is monotonic; time.time() can jump when the clock is adjusted.
    started = time.perf_counter()
    # 1. Chunk size = cover area + shift area
    chunk_size = num_point + int(num_point * shift_range)
    output_chunks = []
    for idx_trunck, offset in enumerate(range(0, len(data), chunk_size)):
        current_chunk = data[offset:offset + chunk_size].copy()
        # Trailing partial chunk: too short for a full [cover | shift] layout,
        # so it is passed through unwatermarked.
        if len(current_chunk) < chunk_size:
            output_chunks.append(current_chunk)
            break
        cover_area = current_chunk[:num_point]
        shift_area = current_chunk[num_point:]
        cover_area_wmd = encode_trunck_with_silence_check(
            silence_check, idx_trunck, cover_area, bir_array, device, model)
        output = np.concatenate([cover_area_wmd, shift_area])
        # The encoder must preserve length, otherwise later chunk boundaries shift.
        assert output.shape == current_chunk.shape
        output_chunks.append(output)
    # Empty `data` yields no chunks (np.concatenate would fail on an empty list).
    assert len(output_chunks) > 0
    reconstructed_array = np.concatenate(output_chunks)
    time_cost = time.perf_counter() - started
    return data, reconstructed_array, time_cost
def encode_trunck_with_silence_check(silence_check, trunck_idx, trunck, wm, device, model):
    """Watermark one cover area, optionally skipping it when it is silent.

    Args:
        silence_check: When True, ``trunck`` is first tested with
            ``silent_util.is_silent``; a silent chunk is returned unchanged.
        trunck_idx: Index of the chunk, used only in the skip message.
        trunck: Samples of the cover area.
        wm: Watermark bits passed to ``encode_trunck``.
        device: Torch device passed to ``encode_trunck``.
        model: Model passed to ``encode_trunck``.

    Returns:
        ``trunck`` itself if it was skipped as silent, otherwise the result of
        ``encode_trunck``.
    """
    # 1. Leave silent chunks untouched (only checked when enabled).
    if silence_check and silent_util.is_silent(trunck):
        # Message text: "skipping silent chunk:"
        print("跳过静音区块:", trunck_idx)
        return trunck
    # 2. Embed the watermark.
    return encode_trunck(trunck, wm, device, model)
def encode_trunck(trunck, wm, device, model):
    """Run ``model.encode`` on one chunk and return the result as a NumPy array.

    Args:
        trunck: Array-like of samples (presumably 1-D; TODO confirm), converted
            to a float32 tensor on ``device``.
        wm: Watermark bits (array-like), converted to a float32 tensor on ``device``.
        device: Torch device the input tensors are created on.
        model: Object exposing ``encode(signal, message)`` that returns a tensor.

    Returns:
        ``model.encode``'s output moved to the CPU as a NumPy array, with all
        size-1 dimensions squeezed out.
    """
    with torch.no_grad():
        # torch.tensor creates the tensor directly on `device` and always copies,
        # replacing the legacy torch.FloatTensor(...).to(device) constructor.
        # [None] prepends a batch dimension of size 1 to both inputs.
        signal = torch.tensor(np.asarray(trunck), dtype=torch.float32, device=device)[None]
        message = torch.tensor(np.asarray(wm), dtype=torch.float32, device=device)[None]
        signal_wmd_tensor = model.encode(signal, message)
    return signal_wmd_tensor.detach().cpu().numpy().squeeze()