|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
import torch |
|
import torch.nn as nn |
|
from unet import UNet |
|
|
|
|
|
def load_graph(frozen_graph_filename):
    """Read a serialized ``GraphDef`` from disk and import it into a new graph.

    Args:
        frozen_graph_filename: Path of the binary protobuf file to load.

    Returns:
        A newly created ``tf.Graph`` holding the imported nodes.
    """
    parsed_def = tf.compat.v1.GraphDef()
    with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as pb_file:
        parsed_def.ParseFromString(pb_file.read())

    imported = tf.Graph()
    with imported.as_default():
        # name="" is passed so the importer does not add its own scope prefix;
        # the rest of the script looks tensors up by names such as "waveform:0".
        tf.import_graph_def(parsed_def, name="")
    return imported
|
|
|
|
|
def generate_waveform():
    """Build a reproducible pseudo-random test signal.

    Seeds NumPy's global random generator with a fixed value, draws
    ``60 * 44100`` uniform samples, casts them to float32 and reshapes the
    result into two columns.

    Returns:
        A float32 array of shape ``(60 * 44100 // 2, 2)``.
    """
    total_samples = 60 * 44100
    # Fixed seed: every call yields the same signal (and reseeds the global RNG).
    np.random.seed(20230821)
    flat = np.random.rand(total_samples)
    return flat.astype(np.float32).reshape(-1, 2)
|
|
|
|
|
def get_param(graph, name):
    """Return the value of a named ``Const`` op in ``graph`` as a torch tensor.

    Args:
        graph: TensorFlow graph to search.
        name: Exact op name to look for, e.g. ``"conv2d/kernel"``.

    Returns:
        A ``torch.Tensor`` built from the evaluated constant via
        ``torch.from_numpy``.

    Raises:
        KeyError: If the graph has no ``Const`` op with that name. Previously
            this case fell through and returned ``None``, which only failed
            later (e.g. on ``.permute``) with a misleading error.
    """
    # Resolve the op first so a missing name fails before a session is opened.
    constant_op = next(
        (
            op
            for op in graph.get_operations()
            if op.type == "Const" and op.name == name
        ),
        None,
    )
    if constant_op is None:
        raise KeyError(f"no Const op named {name!r} found in graph")

    with tf.compat.v1.Session(graph=graph) as sess:
        value = sess.run(constant_op.outputs[0])
    return torch.from_numpy(value)
|
|
|
|
|
def _tf_layer_name(prefix, index):
    """Return the TF scope name of the ``index``-th layer; index 0 has no suffix."""
    return prefix if index == 0 else f"{prefix}_{index}"


def _copy_conv(state_dict, graph, torch_name, tf_name):
    """Copy one (transposed) convolution's kernel and bias into ``state_dict``."""
    # Every kernel gets the same axis permutation before being stored.
    state_dict[f"{torch_name}.weight"] = get_param(
        graph, f"{tf_name}/kernel"
    ).permute(3, 2, 0, 1)
    state_dict[f"{torch_name}.bias"] = get_param(graph, f"{tf_name}/bias")


def _copy_bn(state_dict, graph, torch_name, tf_name):
    """Copy batch-norm scale, shift and running statistics into ``state_dict``."""
    key_pairs = (
        ("weight", "gamma"),
        ("bias", "beta"),
        ("running_mean", "moving_mean"),
        ("running_var", "moving_variance"),
    )
    for torch_key, tf_key in key_pairs:
        state_dict[f"{torch_name}.{torch_key}"] = get_param(
            graph, f"{tf_name}/{tf_key}"
        )


@torch.no_grad()
def main(name):
    """Convert one stem of the frozen TF model into a PyTorch ``UNet`` checkpoint.

    Loads ``./2stems/frozen_{name}_model.pb``, copies the convolution and
    batch-norm constants into a fresh ``UNet`` state dict, runs both the TF
    graph and the PyTorch module on the same generated waveform, and saves the
    weights to ``2stems/{name}.pt`` only if the outputs agree within
    ``atol=1e-1``.

    Args:
        name: Stem name. ``"vocals"`` uses the un-offset TF layer names; any
            other value uses the offsets of the second network in the graph
            (the CLI restricts this to ``"accompaniment"``).

    Raises:
        AssertionError: If the PyTorch output differs from the TF output by
            more than the tolerance; the argument is the max absolute
            difference. Raised explicitly so the check also runs under
            ``python -O``.
        KeyError: Propagated from ``get_param`` if it signals a missing op.
    """
    graph = load_graph(f"./2stems/frozen_{name}_model.pb")

    x = graph.get_tensor_by_name("waveform:0")
    # y0 is what gets fed to the PyTorch UNet; y1 is the TF output it must match.
    y0 = graph.get_tensor_by_name("strided_slice_3:0")
    y1 = graph.get_tensor_by_name(f"{name}_spectrogram/mul:0")

    unet = UNet()
    unet.eval()
    state_dict = unet.state_dict()

    # TF numbers layers globally across the graph, so the non-vocals network
    # starts at these offsets for conv2d / batch_normalization /
    # conv2d_transpose respectively.
    if name == "vocals":
        conv_base, bn_base, deconv_base = 0, 0, 0
    else:
        conv_base, bn_base, deconv_base = 7, 12, 6

    # Down path: conv, conv1..conv5 and bn, bn1..bn4.
    for k in range(6):
        _copy_conv(
            state_dict,
            graph,
            f"conv{k}" if k else "conv",
            _tf_layer_name("conv2d", conv_base + k),
        )
        # The last down-path convolution has no batch-norm entry to copy
        # (TF batch_normalization index bn_base + 5 is intentionally skipped).
        if k < 5:
            _copy_bn(
                state_dict,
                graph,
                f"bn{k}" if k else "bn",
                _tf_layer_name("batch_normalization", bn_base + k),
            )

    # Up path: up1..up6 from the transposed convolutions, paired with bn5..bn10.
    for k in range(6):
        _copy_conv(
            state_dict,
            graph,
            f"up{k + 1}",
            _tf_layer_name("conv2d_transpose", deconv_base + k),
        )
        _copy_bn(
            state_dict,
            graph,
            f"bn{5 + k}",
            _tf_layer_name("batch_normalization", bn_base + 6 + k),
        )

    # up7 is filled from a regular conv2d op (conv2d_6 / conv2d_13).
    _copy_conv(state_dict, graph, "up7", _tf_layer_name("conv2d", conv_base + 6))

    unet.load_state_dict(state_dict)

    with tf.compat.v1.Session(graph=graph) as sess:
        y0_out, y1_out = sess.run([y0, y1], feed_dict={x: generate_waveform()})

    # NOTE(review): permute(0, 3, 1, 2) presumably converts channels-last TF
    # tensors to the channels-first layout UNet expects; confirm against UNet.
    torch_y1_out = unet(torch.from_numpy(y0_out).permute(0, 3, 1, 2))
    expected = torch.from_numpy(y1_out).permute(0, 3, 1, 2)

    # Explicit raise instead of `assert` so a mismatching checkpoint is never
    # saved, even when assertions are stripped.
    if not torch.allclose(torch_y1_out, expected, atol=1e-1):
        raise AssertionError((torch_y1_out - expected).abs().max())
    torch.save(unet.state_dict(), f"2stems/{name}.pt")
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: --name selects which stem to convert.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--name",
        type=str,
        required=True,
        choices=["vocals", "accompaniment"],
    )
    options = cli.parse_args()
    # Echo the parsed options before starting the conversion.
    print(vars(options))
    main(options.name)
|
|