
Migrate tf.contrib.layers.batch_norm To TensorFlow 2.0

I'm migrating TensorFlow code to TensorFlow 2.1.0. Here is the original code:

conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')
conv = tf.contrib.layer

Solution 1:

I ran into this problem while working with a trained model that I was going to fine-tune. Simply replacing tf.contrib.layers.batch_norm with tf.keras.layers.BatchNormalization, as the OP did, gave me an error; the fix is described below.

The old code looked like this:

tf.contrib.layers.batch_norm(
    tensor,
    scale=True,
    center=True,
    is_training=self.use_batch_statistics,
    trainable=True,
    data_format=self._data_format,
    updates_collections=None,
)
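For context, here is a simplified NumPy sketch of what the old call computes on an NHWC tensor. It assumes contrib's defaults decay=0.999 and epsilon=0.001, uses the biased batch variance throughout and ignores fused kernels; batch_norm_sketch is a hypothetical name, not a TensorFlow function. The is_training flag chooses between batch statistics (with an in-place moving-average update, which is what updates_collections=None requests) and the stored moving averages. That switch is what goes away when is_training is dropped.

```python
import numpy as np


def batch_norm_sketch(x, gamma, beta, moving_mean, moving_var,
                      is_training, decay=0.999, epsilon=1e-3):
    """Simplified batch norm over an NHWC array (channels on the last axis).

    Returns (output, new_moving_mean, new_moving_var). Illustrative only,
    not TensorFlow's actual implementation.
    """
    if is_training:
        # Per-channel statistics over batch, height and width.
        mean = x.mean(axis=(0, 1, 2))
        var = x.var(axis=(0, 1, 2))
        # Exponential moving-average update of the stored statistics.
        moving_mean = decay * moving_mean + (1.0 - decay) * mean
        moving_var = decay * moving_var + (1.0 - decay) * var
    else:
        # Inference: reuse the stored moving statistics unchanged.
        mean, var = moving_mean, moving_var
    out = gamma * (x - mean) / np.sqrt(var + epsilon) + beta
    return out, moving_mean, moving_var


# Usage: one training-mode step on random data with 3 channels.
rng = np.random.default_rng(0)
x = rng.normal(5.0, 2.0, size=(8, 4, 4, 3))
c = x.shape[-1]
gamma, beta = np.ones(c), np.zeros(c)
mm, mv = np.zeros(c), np.ones(c)
y, mm, mv = batch_norm_sketch(x, gamma, beta, mm, mv, is_training=True)
```

With decay=0.999 a single step moves the moving mean only 0.1% of the way toward the batch mean, which is why moving statistics restored from a checkpoint matter so much when fine-tuning.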

The updated working code looks like this:

tf.keras.layers.BatchNormalization(
    name="BatchNorm",  # must match the layer name stored in the checkpoint
    scale=True,
    center=True,
    trainable=True,
)(tensor)
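One way to keep the dropped keyword arguments from silently changing behavior is to translate them explicitly. The helper below is a sketch; contrib_bn_to_keras is a hypothetical name, not a TensorFlow API. It assumes the standard defaults of the two APIs: contrib's decay=0.999 and scale=False versus Keras' momentum=0.99 and scale=True. is_training becomes the training argument of the layer call, data_format becomes the channel axis, and updates_collections is dropped because in TF 2.x the Keras layer applies its own moving-average updates when called with training=True.

```python
def contrib_bn_to_keras(is_training=True, decay=0.999, center=True,
                        scale=False, epsilon=0.001, trainable=True,
                        data_format="NHWC", updates_collections=None):
    """Hypothetical helper: map tf.contrib.layers.batch_norm kwargs.

    Returns (layer_kwargs, call_kwargs): layer_kwargs are meant for
    tf.keras.layers.BatchNormalization(...), call_kwargs for layer(x, ...).
    Defaults mirror contrib's, except updates_collections, which is ignored.
    """
    if data_format not in ("NHWC", "NCHW"):
        raise ValueError(f"unsupported data_format: {data_format!r}")
    layer_kwargs = {
        "axis": -1 if data_format == "NHWC" else 1,  # channel axis
        "momentum": decay,   # same role as contrib's moving-average decay
        "epsilon": epsilon,
        "center": center,
        "scale": scale,      # contrib default False, Keras default True
        "trainable": trainable,
    }
    # is_training has no constructor equivalent; it becomes `training`
    # at call time. updates_collections has no Keras equivalent at all.
    call_kwargs = {"training": is_training}
    return layer_kwargs, call_kwargs


# Usage with the arguments from the old call (NCHW assumed for illustration):
layer_kwargs, call_kwargs = contrib_bn_to_keras(
    is_training=False, scale=True, center=True, trainable=True,
    data_format="NCHW", updates_collections=None)
# layer = tf.keras.layers.BatchNormalization(name="BatchNorm", **layer_kwargs)
# out = layer(tensor, **call_kwargs)
```

Passing momentum=0.999 keeps the moving averages adapting at the same rate as in the original model; with the Keras default of 0.99 they move ten times faster per step. When the model is driven by Keras fit() and predict(), training is set automatically, so call_kwargs mainly matters in custom training loops.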

I'm not sure whether removing the other keyword arguments will cause problems, but everything seems to work. Note the name="BatchNorm" argument: the Keras layers use a different naming scheme, so I used the inspect_checkpoint.py tool to look at the model and find the layer names, which turned out to be BatchNorm.
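If you prefer to look up the names from Python instead of running inspect_checkpoint.py, tf.train.list_variables(checkpoint_path) returns (name, shape) pairs. The sketch below only post-processes such a list, assuming a TF1-style name-based checkpoint like the one produced by tf.contrib; find_batch_norm_scopes and the example variable names are hypothetical. It groups variables whose last path component is one of the four batch-norm parameters, so the trailing scope component is the name to pass via name=. Whether the enclosing scopes also line up depends on how the rest of the model is built.

```python
from collections import defaultdict

# Variable names that tf.contrib / Keras batch norm layers create.
BN_SUFFIXES = ("beta", "gamma", "moving_mean", "moving_variance")


def find_batch_norm_scopes(variables):
    """Group checkpoint variables that look like batch-norm parameters.

    `variables` is a list of (name, shape) pairs, e.g. the output of
    tf.train.list_variables(checkpoint_path). Returns {scope: [suffixes]}.
    """
    scopes = defaultdict(list)
    for name, _shape in variables:
        scope, _, leaf = name.rpartition("/")
        if scope and leaf in BN_SUFFIXES:
            scopes[scope].append(leaf)
    return {scope: sorted(leaves) for scope, leaves in scopes.items()}


# Hypothetical checkpoint contents, for illustration only.
example = [
    ("conv1/weights", [3, 3, 3, 64]),
    ("conv1/BatchNorm/beta", [64]),
    ("conv1/BatchNorm/gamma", [64]),
    ("conv1/BatchNorm/moving_mean", [64]),
    ("conv1/BatchNorm/moving_variance", [64]),
    ("global_step", []),
]
print(find_batch_norm_scopes(example))
# {'conv1/BatchNorm': ['beta', 'gamma', 'moving_mean', 'moving_variance']}
```

TF2 object-based checkpoints use different keys (paths ending in .ATTRIBUTES/VARIABLE_VALUE), so this grouping would need adjusting for them.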
