|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Defines the 'VGGish' model used to generate AudioSet embedding features. |
|
|
|
The public AudioSet release (https://research.google.com/audioset/download.html) |
|
includes 128-D features extracted from the embedding layer of a VGG-like model |
|
that was trained on a large Google-internal YouTube dataset. Here we provide |
|
a TF-Slim definition of the same model, without any dependencies on libraries
|
internal to Google. We call it 'VGGish'. |
|
|
|
Note that we only define the model up to the embedding layer, which is the |
|
penultimate layer before the final classifier layer. We also provide various |
|
hyperparameter values (in vggish_params.py) that were used to train this model |
|
internally. |
|
|
|
For comparison, here is TF-Slim's VGG definition: |
|
https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py |
|
""" |
|
|
|
import tensorflow._api.v2.compat.v1 as tf |
|
tf.disable_v2_behavior() |
|
import tf_slim as slim |
|
|
|
import vggish_params as params |
|
|
|
|
|
def define_vggish_slim(training=False):
    """Builds the VGGish TensorFlow model in the current default graph.

    All ops are created under the variable scope 'vggish'.

    The model input is a placeholder named 'vggish/input_features' of type
    float32 and shape [batch_size, num_frames, num_bands], where batch_size is
    left unspecified and num_frames / num_bands are the constants
    params.NUM_FRAMES / params.NUM_BANDS. Each [num_frames, num_bands] slice
    is a log-mel-scale spectrogram patch covering num_bands frequency bands
    and num_frames time frames (where each frame step is usually 10ms),
    produced by computing the stabilized
    log(mel-spectrogram + params.LOG_OFFSET).

    The model output is an op named 'vggish/embedding' which produces the
    activations of a 128-D embedding layer; this is usually the penultimate
    layer when the model is used with a final classifier layer on top.

    Args:
      training: If true, all parameters are marked trainable. The value is
        forwarded as `trainable` to every conv and fully-connected layer.

    Returns:
      The op 'vggish/embedding'.
    """
    # Layer defaults, installed through nested arg scopes:
    # - conv / fully-connected weights use a truncated normal initializer
    #   with stddev params.INIT_STDDEV;
    # - conv / fully-connected biases are initialized to zero;
    # - conv / fully-connected activations are ReLU;
    # - convolutions are 3x3 with stride 1 and SAME padding;
    # - max-pools are 2x2 with stride 2 and SAME padding.
    with slim.arg_scope(
            [slim.conv2d, slim.fully_connected],
            weights_initializer=tf.truncated_normal_initializer(
                stddev=params.INIT_STDDEV),
            biases_initializer=tf.zeros_initializer(),
            activation_fn=tf.nn.relu,
            trainable=training):
        with slim.arg_scope(
                [slim.conv2d],
                kernel_size=[3, 3], stride=1, padding='SAME'):
            with slim.arg_scope(
                    [slim.max_pool2d],
                    kernel_size=[2, 2], stride=2, padding='SAME'):
                with tf.compat.v1.variable_scope('vggish'):
                    # Input: a batch of 2-D log-mel-spectrogram patches.
                    patch_shape = (
                        None, params.NUM_FRAMES, params.NUM_BANDS)
                    features = tf.compat.v1.placeholder(
                        tf.float32, shape=patch_shape,
                        name='input_features')

                    # Append a size-1 channel axis so the patches can be fed
                    # to the 2-D convolutions.
                    image_shape = [
                        -1, params.NUM_FRAMES, params.NUM_BANDS, 1]
                    hidden = tf.reshape(features, image_shape)

                    # The VGG stack of alternating convolutions and
                    # max-pools.
                    hidden = slim.conv2d(hidden, 64, scope='conv1')
                    hidden = slim.max_pool2d(hidden, scope='pool1')
                    hidden = slim.conv2d(hidden, 128, scope='conv2')
                    hidden = slim.max_pool2d(hidden, scope='pool2')
                    hidden = slim.repeat(
                        hidden, 2, slim.conv2d, 256, scope='conv3')
                    hidden = slim.max_pool2d(hidden, scope='pool3')
                    hidden = slim.repeat(
                        hidden, 2, slim.conv2d, 512, scope='conv4')
                    hidden = slim.max_pool2d(hidden, scope='pool4')

                    # Flatten before entering the fully-connected layers.
                    hidden = slim.flatten(hidden)
                    hidden = slim.repeat(
                        hidden, 2, slim.fully_connected, 4096, scope='fc1')

                    # The embedding layer.
                    hidden = slim.fully_connected(
                        hidden, params.EMBEDDING_SIZE, scope='fc2')
                    embedding = tf.identity(hidden, name='embedding')
    return embedding
|
|
|
|
|
def load_vggish_slim_checkpoint(session, checkpoint_path):
    """Loads a pre-trained VGGish-compatible checkpoint.

    This function can be used as an initialization function (referred to as
    init_fn in TensorFlow documentation) which is called in a Session after
    initializing all variables. When used as an init_fn, this will load
    a pre-trained checkpoint that is compatible with the VGGish model
    definition. Only variables defined by VGGish will be loaded.

    Args:
      session: an active TensorFlow session.
      checkpoint_path: path to a file containing a checkpoint that is
        compatible with the VGGish model definition.
    """
    # Build an inference-mode VGGish in a scratch graph purely to enumerate
    # the names of the variables the model defines. The names are kept in a
    # set so the membership test below is O(1) per variable instead of a
    # linear scan of a list.
    with tf.Graph().as_default():
        define_vggish_slim(training=False)
        vggish_var_names = {
            v.name for v in tf.compat.v1.global_variables()}

    # From the variables that currently exist in the default graph, keep only
    # those whose names match a VGGish variable name.
    vggish_vars = [
        v for v in tf.compat.v1.global_variables()
        if v.name in vggish_var_names
    ]

    # Use a Saver restricted to the selected variables so nothing else in the
    # graph is touched by the restore.
    saver = tf.compat.v1.train.Saver(
        vggish_vars, name='vggish_load_pretrained')
    saver.restore(session, checkpoint_path)
|
|