author    2016-12-02 12:30:57 -0800
committer 2016-12-02 12:45:12 -0800
commit    f56c0abfdb1fe0e4812ac490e68cb58a3761586c
tree      8139f963abdf604953c9e2a6c4da73109cf7fbee
parent    314c5c0d50689936e694cda1ac2731daa2e6a423
Add BatchNormalization layer class and its functional interface.
Change: 140880753
-rw-r--r-- tensorflow/contrib/layers/python/layers/layers.py | 62
-rw-r--r-- tensorflow/contrib/layers/python/layers/layers_test.py | 4
-rw-r--r-- tensorflow/python/BUILD | 14
-rw-r--r-- tensorflow/python/layers/normalization.py | 333
-rw-r--r-- tensorflow/python/layers/normalization_test.py | 479
5 files changed, 887 insertions(+), 5 deletions(-)
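
For orientation, here is a minimal usage sketch of the API this commit adds, in the style of the tests below (the import path is the one introduced by this change; the shapes, the boolean placeholder, and the name 'bn' are illustrative assumptions, not part of the commit):

    import tensorflow as tf
    from tensorflow.python.layers import normalization as normalization_layers

    # Layer-class interface: construct the layer once, then apply it to inputs.
    inputs = tf.placeholder(tf.float32, shape=(5, 4, 3))
    training = tf.placeholder(dtype='bool')
    bn = normalization_layers.BatchNormalization(axis=-1, momentum=0.9)
    outputs = bn.apply(inputs, training=training)

    # Functional interface: a single call; variables are created under `name`.
    outputs_fn = normalization_layers.batch_normalization(
        inputs, axis=-1, momentum=0.9, training=training, name='bn')
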
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 9a9ab6e117..12dbf05b4c 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.layers import convolutional as convolutional_layers from tensorflow.python.layers import core as core_layers +from tensorflow.python.layers import normalization as normalization_layers from tensorflow.python.layers import pooling as pooling_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -462,9 +463,64 @@ def batch_norm( if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): raise ValueError('data_format has to be either NCHW or NHWC.') - with variable_scope.variable_scope(scope, 'BatchNorm', [inputs], - reuse=reuse) as sc: + layer_variable_getter = _build_variable_getter() + with variable_scope.variable_scope( + scope, 'BatchNorm', [inputs], reuse=reuse, + custom_getter=layer_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) + + # Determine whether we can use the core layer class. + if (batch_weights is None and + updates_collections is ops.GraphKeys.UPDATE_OPS): + # Use the core layer class. + axis = 1 if data_format == DATA_FORMAT_NCHW else -1 + if not param_initializers: + param_initializers = {} + beta_initializer = param_initializers.get('beta', + init_ops.zeros_initializer) + gamma_initializer = param_initializers.get('gamma', + init_ops.ones_initializer()) + moving_mean_initializer = param_initializers.get( + 'moving_mean', init_ops.zeros_initializer) + moving_variance_initializer = param_initializers.get( + 'moving_variance', init_ops.ones_initializer()) + layer = normalization_layers.BatchNormalization( + axis=axis, + momentum=decay, + epsilon=epsilon, + center=center, + scale=scale, + beta_initializer=beta_initializer, + gamma_initializer=gamma_initializer, + moving_mean_initializer=moving_mean_initializer, + moving_variance_initializer=moving_variance_initializer, + trainable=trainable, + name=sc.name, + _scope=sc, + _reuse_weights=reuse) + outputs = layer.apply(inputs, training=is_training) + + # Add variables to collections. + _add_variable_to_collections( + layer.moving_mean, variables_collections, 'moving_mean') + _add_variable_to_collections( + layer.moving_variance, variables_collections, 'moving_variance') + if layer.beta: + _add_variable_to_collections(layer.beta, variables_collections, 'beta') + if layer.gamma: + _add_variable_to_collections(layer.gamma, variables_collections, 'gamma') + + if activation_fn is not None: + outputs = activation_fn(outputs) + return utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, outputs) + + # Not supported by layer class: batch_weights argument, + # and custom updates_collections. In that case, use the legacy BN + # implementation. + # Custom updates collections are not supported because the update logic + # is different in this case, in particular w.r.t. "forced updates" and + # update op reuse. 
inputs_shape = inputs.get_shape() inputs_rank = inputs_shape.ndims if inputs_rank is None: @@ -1230,7 +1286,7 @@ def _model_variable_getter(getter, name, shape=None, dtype=None, custom_getter=getter) -def _build_variable_getter(rename): +def _build_variable_getter(rename=None): """Build a model variable getter that respects scope getter and renames.""" # Respect current getter, if one is set. current_custom_getter = variable_scope.get_variable_scope().custom_getter diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 4715f05c78..befbf33f11 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1570,14 +1570,14 @@ class BatchNormTest(tf.test.TestCase): with tf.Graph().as_default() as g, self.test_session(g): inputs = tf.placeholder(dtype=tf.float32) inputs.set_shape(tf.TensorShape((5, 3, 3, None))) - with self.assertRaisesRegexp(ValueError, 'undefined channels dimension'): + with self.assertRaisesRegexp(ValueError, 'undefined'): tf.contrib.layers.batch_norm(inputs, data_format='NHWC') def testUnknownChannelsDimNCHW(self): with tf.Graph().as_default() as g, self.test_session(g): inputs = tf.placeholder(dtype=tf.float32) inputs.set_shape(tf.TensorShape((5, None, 3, 3))) - with self.assertRaisesRegexp(ValueError, 'undefined channels dimension'): + with self.assertRaisesRegexp(ValueError, 'undefined'): tf.contrib.layers.batch_norm(inputs, data_format='NCHW') def testWeightedMomentsFused(self): diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index c5c2f77378..a97898d1f5 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2363,6 +2363,20 @@ py_test( ], ) +py_test( + name = "layers_normalization_test", + size = "small", + srcs = [ + "layers/normalization_test.py", + ], + main = "layers/normalization_test.py", + srcs_version = "PY2AND3", + deps = [ + ":layers", + "//tensorflow:tensorflow_py", + ], +) + py_library( name = "docs", srcs = [ diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py new file mode 100644 index 0000000000..60e91cd7a0 --- /dev/null +++ b/tensorflow/python/layers/normalization.py @@ -0,0 +1,333 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# pylint: disable=unused-import,g-bad-import-order +"""Contains the normalization layer classes and their functional aliases. 
+""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six +from six.moves import xrange # pylint: disable=redefined-builtin +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.training import moving_averages +from tensorflow.python.framework import tensor_util + +from tensorflow.python.layers import base + + +class BatchNormalization(base._Layer): # pylint: disable=protected-access + """Batch Normalization layer from http://arxiv.org/abs/1502.03167. + + "Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift" + + Sergey Ioffe, Christian Szegedy + + Arguments: + axis: Integer, the axis that should be normalized (typically the features + axis). For instance, after a `Convolution2D` layer with + `data_format="channels_first"`, set `axis=1` in `BatchNormalization`. + momentum: Momentum for the moving average. + epsilon: Small float added to variance to avoid dividing by zero. + center: If True, subtract `beta`. If False, `beta` is ignored. + scale: If True, multiply by `gamma`. If False, `gamma` is + not used. When the next layer is linear (also e.g. `nn.relu`), this can be + disabled since the scaling can be done by the next layer. + beta_initializer: Initializer for the beta weight. + gamma_initializer: Initializer for the gamma weight. + moving_mean_initializer: Initializer for the moving mean. + moving_variance_initializer: Initializer for the moving variance. + beta_regularizer: Optional regularizer for the beta weight. + gamma_regularizer: Optional regularizer for the gamma weight. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). + name: A string, the name of the layer. 
+ """ + + def __init__(self, + axis=-1, + momentum=0.99, + epsilon=1e-3, + center=True, + scale=True, + beta_initializer=init_ops.zeros_initializer, + gamma_initializer=init_ops.ones_initializer(), + moving_mean_initializer=init_ops.zeros_initializer, + moving_variance_initializer=init_ops.ones_initializer(), + beta_regularizer=None, + gamma_regularizer=None, + trainable=True, + name=None, + **kwargs): + super(BatchNormalization, self).__init__( + name=name, trainable=trainable, **kwargs) + self.axis = axis + self.momentum = momentum + self.epsilon = epsilon + self.center = center + self.scale = scale + self.beta_initializer = beta_initializer + self.gamma_initializer = gamma_initializer + self.moving_mean_initializer = moving_mean_initializer + self.moving_variance_initializer = moving_variance_initializer + self.beta_regularizer = beta_regularizer + self.gamma_regularizer = gamma_regularizer + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + if not input_shape.ndims: + raise ValueError('Input has undefined rank:', input_shape) + ndim = len(input_shape) + if self.axis < 0: + axis = ndim + self.axis + else: + axis = self.axis + if axis < 0 or axis >= ndim: + raise ValueError('Value of `axis` argument ' + str(self.axis) + + ' is out of range for input with rank ' + str(ndim)) + param_dim = input_shape[axis] + if not param_dim.value: + raise ValueError('Input has undefined `axis` dimension. Input shape: ', + input_shape) + + if self.center: + self.beta = vs.get_variable('beta', + shape=(param_dim,), + initializer=self.beta_initializer, + regularizer=self.beta_regularizer, + trainable=True) + else: + self.beta = None + if self.scale: + self.gamma = vs.get_variable('gamma', + shape=(param_dim,), + initializer=self.gamma_initializer, + regularizer=self.gamma_regularizer, + trainable=True) + else: + self.gamma = None + + # Disable variable partitioning when creating the moving mean and variance + partitioner = vs.get_variable_scope().partitioner + try: + vs.get_variable_scope().set_partitioner(None) + self.moving_mean = vs.get_variable( + 'moving_mean', + shape=(param_dim,), + initializer=self.moving_mean_initializer, + trainable=False) + self.moving_variance = vs.get_variable( + 'moving_variance', + shape=(param_dim,), + initializer=self.moving_variance_initializer, + trainable=False) + finally: + vs.get_variable_scope().set_partitioner(partitioner) + + def call(self, inputs, training=False): + # First, compute the axes along which to reduce the mean / variance, + # as well as the broadcast shape to be used for all parameters. + input_shape = inputs.get_shape() + ndim = len(input_shape) + reduction_axes = list(range(len(input_shape))) + del reduction_axes[self.axis] + broadcast_shape = [1] * len(input_shape) + broadcast_shape[self.axis] = input_shape[self.axis].value + + # Determines whether broadcasting is needed. + needs_broadcasting = (sorted(reduction_axes) != range(ndim)[:-1]) + + # Determine boolean training boolean value. May be False, True, None. + # If None, it is assumed that `training` is a variable to be used in `cond`. + if isinstance(training, bool): + training_bool = training + else: + try: + training_bool = tensor_util.constant_value(training) + except TypeError: + training_bool = None + + # Obtain current current batch mean, variance, if necessary. + if training_bool is not False: + # Use a copy of moving_mean as a shift to compute more reliable moments. 
+ shift = math_ops.add(self.moving_mean, 0) + if needs_broadcasting: + shift = array_ops.reshape(shift, broadcast_shape) + broadcast_mean, broadcast_variance = nn.moments( + inputs, reduction_axes, shift=shift, keep_dims=True) + mean = array_ops.reshape(broadcast_mean, [-1]) + variance = array_ops.reshape(broadcast_variance, [-1]) + else: + mean, variance = nn.moments(inputs, reduction_axes, shift=shift) + + # Prepare updates if necessary. + if training_bool is not False and not self.updates: + mean_update = moving_averages.assign_moving_average( + self.moving_mean, mean, self.momentum, zero_debias=False) + variance_update = moving_averages.assign_moving_average( + self.moving_variance, variance, self.momentum, zero_debias=False) + # In the future this should be refactored into a self.add_update + # method in order to allow for instance-based BN layer sharing + # across unrelated input streams (e.g. like in Keras). + self.updates.append(mean_update) + self.updates.append(variance_update) + + # Normalize batch. + if needs_broadcasting: + # In this case we must explicitly broadcast all parameters. + broadcast_moving_mean = array_ops.reshape(self.moving_mean, + broadcast_shape) + broadcast_moving_variance = array_ops.reshape(self.moving_variance, + broadcast_shape) + if self.center: + broadcast_beta = array_ops.reshape(self.beta, broadcast_shape) + else: + broadcast_beta = None + if self.scale: + broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape) + else: + broadcast_gamma = None + + if training_bool is not False: + normed_inputs_training = nn.batch_normalization(inputs, + broadcast_mean, + broadcast_variance, + broadcast_beta, + broadcast_gamma, + self.epsilon) + normed_inputs = nn.batch_normalization(inputs, + broadcast_moving_mean, + broadcast_moving_variance, + broadcast_beta, + broadcast_gamma, + self.epsilon) + else: + # No need for broadcasting. + if training_bool is not False: + normed_inputs_training = nn.batch_normalization( + inputs, + mean, + variance, + self.beta if self.center else None, + self.gamma if self.scale else None, + self.epsilon) + normed_inputs = nn.batch_normalization(inputs, + self.moving_mean, + self.moving_variance, + self.beta if self.center else None, + self.gamma if self.scale else None, + self.epsilon) + + # Return the proper output depending on the boolean training phase. + if training_bool is True: + return normed_inputs_training + if training_bool is False: + return normed_inputs + return control_flow_ops.cond(training, + lambda: normed_inputs_training, + lambda: normed_inputs) + + +def batch_normalization(inputs, + axis=-1, + momentum=0.99, + epsilon=1e-3, + center=True, + scale=True, + beta_initializer=init_ops.zeros_initializer, + gamma_initializer=init_ops.ones_initializer(), + moving_mean_initializer=init_ops.zeros_initializer, + moving_variance_initializer=init_ops.ones_initializer(), + beta_regularizer=None, + gamma_regularizer=None, + training=False, + trainable=True, + name=None, + reuse=False): + """Functional interface for the batch normalization layer. + + Reference: http://arxiv.org/abs/1502.03167 + + "Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift" + + Sergey Ioffe, Christian Szegedy + + Arguments: + inputs: Tensor input. + axis: Integer, the axis that should be normalized (typically the features + axis). For instance, after a `Convolution2D` layer with + `data_format="channels_first"`, set `axis=1` in `BatchNormalization`. + momentum: Momentum for the moving average. 
+ epsilon: Small float added to variance to avoid dividing by zero. + center: If True, subtract `beta`. If False, `beta` is ignored. + scale: If True, multiply by `gamma`. If False, `gamma` is + not used. When the next layer is linear (also e.g. `nn.relu`), this can be + disabled since the scaling can be done by the next layer. + beta_initializer: Initializer for the beta weight. + gamma_initializer: Initializer for the gamma weight. + moving_mean_initializer: Initializer for the moving mean. + moving_variance_initializer: Initializer for the moving variance. + beta_regularizer: Optional regularizer for the beta weight. + gamma_regularizer: Optional regularizer for the gamma weight. + training: Either a Python boolean, or a TensorFlow boolean scalar tensor + (e.g. a placeholder). Whether to return the output in training mode + (normalized with statistics of the current batch) or in inference mode + (normalized with moving statistics). + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). + name: String, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + """ + layer = BatchNormalization( + axis=axis, + momentum=momentum, + epsilon=epsilon, + center=center, + scale=scale, + beta_initializer=beta_initializer, + gamma_initializer=gamma_initializer, + moving_mean_initializer=moving_mean_initializer, + moving_variance_initializer=moving_variance_initializer, + beta_regularizer=beta_regularizer, + gamma_regularizer=gamma_regularizer, + trainable=trainable, + name=name, + _reuse_weights=reuse, + _scope=name) + return layer.apply(inputs, training=training) + + +# Aliases + +BatchNorm = BatchNormalization +batch_norm = batch_normalization diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py new file mode 100644 index 0000000000..3f5b4fa632 --- /dev/null +++ b/tensorflow/python/layers/normalization_test.py @@ -0,0 +1,479 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.layers.normalization.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tensorflow.python.layers import normalization as normalization_layers + + +class BNTest(tf.test.TestCase): + + def testCreateBN(self): + # Call layer. + bn = normalization_layers.BatchNormalization(axis=1) + inputs = tf.random_uniform((5, 4, 3), seed=1) + training = tf.placeholder(dtype='bool') + outputs = bn.apply(inputs, training=training) + + # Verify shape. + self.assertListEqual(outputs.get_shape().as_list(), [5, 4, 3]) + + # Verify layer attributes. 
+ self.assertEqual(len(bn.updates), 2) + self.assertEqual(len(bn.weights), 4) + self.assertEqual(len(bn.trainable_weights), 2) + self.assertEqual(len(bn.non_trainable_weights), 2) + + # Test that updates were created and added to UPDATE_OPS. + self.assertEqual(len(bn.updates), 2) + self.assertListEqual( + tf.get_collection(tf.GraphKeys.UPDATE_OPS), bn.updates) + + # Test that weights were created and added to TRAINABLE_VARIABLES. + self.assertListEqual( + tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), + bn.trainable_weights) + + def test3DInputAxis1(self): + epsilon = 1e-3 + bn = normalization_layers.BatchNormalization(axis=1, + epsilon=epsilon, momentum=0.9) + inputs = tf.Variable(np.random.random((5, 4, 3)), dtype=tf.float32) + training = tf.placeholder(dtype='bool') + outputs = bn.apply(inputs, training=training) + + with self.test_session() as sess: + # Test training with placeholder learning phase. + sess.run(tf.global_variables_initializer()) + for _ in range(100): + np_output, _, _ = sess.run([outputs] + bn.updates, + feed_dict={training: True}) + + # Verify that the statistics are updated during training. + moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance]) + np_inputs = sess.run(inputs) + mean = np.mean(np_inputs, axis=(0, 2)) + std = np.std(np_inputs, axis=(0, 2)) + variance = np.square(std) + self.assertAllClose(mean, moving_mean, atol=1e-2) + self.assertAllClose(variance, moving_var, atol=1e-2) + + # Verify that the axis is normalized during training. + np_gamma, np_beta = sess.run([bn.gamma, bn.beta]) + np_gamma = np.reshape(np_gamma, (1, 4, 1)) + np_beta = np.reshape(np_beta, (1, 4, 1)) + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + # Test inference with placeholder learning phase. + np_output = sess.run(outputs, feed_dict={training: False}) + + # Verify that the axis is normalized during inference. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + def test3DInputAxis2(self): + epsilon = 1e-3 + bn = normalization_layers.BatchNormalization(axis=2, + epsilon=epsilon, momentum=0.9) + inputs = tf.Variable(np.random.random((5, 4, 3)), dtype=tf.float32) + training = tf.placeholder(dtype='bool') + outputs = bn.apply(inputs, training=training) + + with self.test_session() as sess: + # Test training with placeholder learning phase. + sess.run(tf.global_variables_initializer()) + for _ in range(100): + np_output, _, _ = sess.run([outputs] + bn.updates, + feed_dict={training: True}) + + # Verify that the statistics are updated during training. + moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance]) + np_inputs = sess.run(inputs) + mean = np.mean(np_inputs, axis=(0, 1)) + std = np.std(np_inputs, axis=(0, 1)) + variance = np.square(std) + self.assertAllClose(mean, moving_mean, atol=1e-2) + self.assertAllClose(variance, moving_var, atol=1e-2) + + # Verify that the axis is normalized during training. 
+ np_gamma, np_beta = sess.run([bn.gamma, bn.beta]) + np_gamma = np.reshape(np_gamma, (1, 1, 3)) + np_beta = np.reshape(np_beta, (1, 1, 3)) + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + # Test inference with placeholder learning phase. + np_output = sess.run(outputs, feed_dict={training: False}) + + # Verify that the axis is normalized during inference. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + def test4DInputAxis1(self): + epsilon = 1e-3 + bn = normalization_layers.BatchNormalization(axis=1, + epsilon=epsilon, momentum=0.9) + inputs = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32) + training = tf.placeholder(dtype='bool') + outputs = bn.apply(inputs, training=training) + + with self.test_session() as sess: + # Test training with placeholder learning phase. + sess.run(tf.global_variables_initializer()) + for _ in range(100): + np_output, _, _ = sess.run([outputs] + bn.updates, + feed_dict={training: True}) + + # Verify that the statistics are updated during training. + moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance]) + np_inputs = sess.run(inputs) + mean = np.mean(np_inputs, axis=(0, 2, 3)) + std = np.std(np_inputs, axis=(0, 2, 3)) + variance = np.square(std) + self.assertAllClose(mean, moving_mean, atol=1e-2) + self.assertAllClose(variance, moving_var, atol=1e-2) + + # Verify that the axis is normalized during training. + np_gamma, np_beta = sess.run([bn.gamma, bn.beta]) + np_gamma = np.reshape(np_gamma, (1, 4, 1, 1)) + np_beta = np.reshape(np_beta, (1, 4, 1, 1)) + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + # Test inference with placeholder learning phase. + np_output = sess.run(outputs, feed_dict={training: False}) + + # Verify that the axis is normalized during inference. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + def test4DInputAxis2(self): + epsilon = 1e-3 + bn = normalization_layers.BatchNormalization(axis=2, + epsilon=epsilon, momentum=0.9) + inputs = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32) + training = tf.placeholder(dtype='bool') + outputs = bn.apply(inputs, training=training) + + with self.test_session() as sess: + # Test training with placeholder learning phase. + sess.run(tf.global_variables_initializer()) + for _ in range(100): + np_output, _, _ = sess.run([outputs] + bn.updates, + feed_dict={training: True}) + + # Verify that the statistics are updated during training. + moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance]) + np_inputs = sess.run(inputs) + mean = np.mean(np_inputs, axis=(0, 1, 3)) + std = np.std(np_inputs, axis=(0, 1, 3)) + variance = np.square(std) + self.assertAllClose(mean, moving_mean, atol=1e-2) + self.assertAllClose(variance, moving_var, atol=1e-2) + + # Verify that the axis is normalized during training. 
+ np_gamma, np_beta = sess.run([bn.gamma, bn.beta]) + np_gamma = np.reshape(np_gamma, (1, 1, 3, 1)) + np_beta = np.reshape(np_beta, (1, 1, 3, 1)) + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + # Test inference with placeholder learning phase. + np_output = sess.run(outputs, feed_dict={training: False}) + + # Verify that the axis is normalized during inference. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + def test4DInputAxis3(self): + epsilon = 1e-3 + bn = normalization_layers.BatchNormalization(axis=3, + epsilon=epsilon, momentum=0.9) + inputs = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32) + training = tf.placeholder(dtype='bool') + outputs = bn.apply(inputs, training=training) + + with self.test_session() as sess: + # Test training with placeholder learning phase. + sess.run(tf.global_variables_initializer()) + for _ in range(100): + np_output, _, _ = sess.run([outputs] + bn.updates, + feed_dict={training: True}) + + # Verify that the statistics are updated during training. + moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance]) + np_inputs = sess.run(inputs) + mean = np.mean(np_inputs, axis=(0, 1, 2)) + std = np.std(np_inputs, axis=(0, 1, 2)) + variance = np.square(std) + self.assertAllClose(mean, moving_mean, atol=1e-2) + self.assertAllClose(variance, moving_var, atol=1e-2) + + # Verify that the axis is normalized during training. + np_gamma, np_beta = sess.run([bn.gamma, bn.beta]) + np_gamma = np.reshape(np_gamma, (1, 1, 1, 6)) + np_beta = np.reshape(np_beta, (1, 1, 1, 6)) + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + # Test inference with placeholder learning phase. + np_output = sess.run(outputs, feed_dict={training: False}) + + # Verify that the axis is normalized during inference. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + def testNegativeAxis(self): + epsilon = 1e-3 + bn = normalization_layers.BatchNormalization(axis=-1, + epsilon=epsilon, momentum=0.9) + inputs = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32) + training = tf.placeholder(dtype='bool') + outputs = bn.apply(inputs, training=training) + + with self.test_session() as sess: + # Test training with placeholder learning phase. + sess.run(tf.global_variables_initializer()) + for _ in range(100): + np_output, _, _ = sess.run([outputs] + bn.updates, + feed_dict={training: True}) + + # Verify that the statistics are updated during training. + moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance]) + np_inputs = sess.run(inputs) + mean = np.mean(np_inputs, axis=(0, 1, 2)) + std = np.std(np_inputs, axis=(0, 1, 2)) + variance = np.square(std) + self.assertAllClose(mean, moving_mean, atol=1e-2) + self.assertAllClose(variance, moving_var, atol=1e-2) + + # Verify that the axis is normalized during training. 
+ np_gamma, np_beta = sess.run([bn.gamma, bn.beta]) + np_gamma = np.reshape(np_gamma, (1, 1, 1, 6)) + np_beta = np.reshape(np_beta, (1, 1, 1, 6)) + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + # Test inference with placeholder learning phase. + np_output = sess.run(outputs, feed_dict={training: False}) + + # Verify that the axis is normalized during inference. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + def testBooleanLearningPhase(self): + epsilon = 1e-3 + bn = normalization_layers.BatchNormalization(axis=-1, + epsilon=epsilon, momentum=0.9) + inputs = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32) + outputs_training = bn.apply(inputs, training=True) + outputs_infer = bn.apply(inputs, training=False) + + with self.test_session() as sess: + # Test training with placeholder learning phase. + sess.run(tf.global_variables_initializer()) + for _ in range(100): + np_output, _, _ = sess.run([outputs_training] + bn.updates) + + # Verify that the statistics are updated during training. + moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance]) + np_inputs = sess.run(inputs) + mean = np.mean(np_inputs, axis=(0, 1, 2)) + std = np.std(np_inputs, axis=(0, 1, 2)) + variance = np.square(std) + self.assertAllClose(mean, moving_mean, atol=1e-2) + self.assertAllClose(variance, moving_var, atol=1e-2) + + # Verify that the axis is normalized during training. + np_gamma, np_beta = sess.run([bn.gamma, bn.beta]) + np_gamma = np.reshape(np_gamma, (1, 1, 1, 6)) + np_beta = np.reshape(np_beta, (1, 1, 1, 6)) + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + # Test inference with placeholder learning phase. + np_output = sess.run(outputs_infer) + + # Verify that the axis is normalized during inference. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + def testFunctionalNoReuse(self): + inputs = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32) + epsilon = 1e-3 + training = tf.placeholder(dtype='bool') + outputs = normalization_layers.batch_norm( + inputs, axis=-1, momentum=0.9, epsilon=epsilon, + training=training, name='bn') + + updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + all_vars = dict([(v.name, v) for v in tf.global_variables()]) + moving_mean = all_vars['bn/moving_mean:0'] + moving_variance = all_vars['bn/moving_variance:0'] + beta = all_vars['bn/beta:0'] + gamma = all_vars['bn/gamma:0'] + + with self.test_session() as sess: + # Test training with placeholder learning phase. + sess.run(tf.global_variables_initializer()) + for _ in range(100): + np_output, _, _ = sess.run([outputs] + updates, + feed_dict={training: True}) + + # Verify that the statistics are updated during training. 
+ np_moving_mean, np_moving_var = sess.run([moving_mean, moving_variance]) + np_inputs = sess.run(inputs) + np_mean = np.mean(np_inputs, axis=(0, 1, 2)) + np_std = np.std(np_inputs, axis=(0, 1, 2)) + np_variance = np.square(np_std) + self.assertAllClose(np_mean, np_moving_mean, atol=1e-2) + self.assertAllClose(np_variance, np_moving_var, atol=1e-2) + + # Verify that the axis is normalized during training. + np_gamma, np_beta = sess.run([gamma, beta]) + np_gamma = np.reshape(np_gamma, (1, 1, 1, 6)) + np_beta = np.reshape(np_beta, (1, 1, 1, 6)) + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + # Test inference with placeholder learning phase. + np_output = sess.run(outputs, feed_dict={training: False}) + + # Verify that the axis is normalized during inference. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + def testFunctionalReuse(self): + inputs1 = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32) + inputs2 = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32) + epsilon = 1e-3 + training = tf.placeholder(dtype='bool') + _ = normalization_layers.batch_norm( + inputs1, axis=-1, momentum=0.9, epsilon=epsilon, + training=training, name='bn') + outputs2 = normalization_layers.batch_norm( + inputs2, axis=-1, momentum=0.9, epsilon=epsilon, + training=training, name='bn', reuse=True) + + # Last 2 update ops + updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)[-2:] + all_vars = dict([(v.name, v) for v in tf.global_variables()]) + moving_mean = all_vars['bn/moving_mean:0'] + moving_variance = all_vars['bn/moving_variance:0'] + beta = all_vars['bn/beta:0'] + gamma = all_vars['bn/gamma:0'] + + with self.test_session() as sess: + # Test training with placeholder learning phase. + sess.run(tf.global_variables_initializer()) + for _ in range(100): + np_output, _, _ = sess.run([outputs2] + updates, + feed_dict={training: True}) + + # Verify that the statistics are updated during training. + np_moving_mean, np_moving_var = sess.run([moving_mean, moving_variance]) + np_inputs = sess.run(inputs2) + np_mean = np.mean(np_inputs, axis=(0, 1, 2)) + np_std = np.std(np_inputs, axis=(0, 1, 2)) + np_variance = np.square(np_std) + self.assertAllClose(np_mean, np_moving_mean, atol=1e-2) + self.assertAllClose(np_variance, np_moving_var, atol=1e-2) + + # Verify that the axis is normalized during training. + np_gamma, np_beta = sess.run([gamma, beta]) + np_gamma = np.reshape(np_gamma, (1, 1, 1, 6)) + np_beta = np.reshape(np_beta, (1, 1, 1, 6)) + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + # Test inference with placeholder learning phase. + np_output = sess.run(outputs2, feed_dict={training: False}) + + # Verify that the axis is normalized during inference. 
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + def testNoCenter(self): + bn = normalization_layers.BatchNormalization(axis=1, center=False) + inputs = tf.random_uniform((5, 4, 3), seed=1) + training = tf.placeholder(dtype='bool') + outputs = bn.apply(inputs, training=training) + + # Verify shape. + self.assertListEqual(outputs.get_shape().as_list(), [5, 4, 3]) + + # Verify layer attributes. + self.assertEqual(len(bn.updates), 2) + self.assertEqual(len(bn.weights), 3) + self.assertEqual(len(bn.trainable_weights), 1) + self.assertEqual(len(bn.non_trainable_weights), 2) + + def testNoScale(self): + bn = normalization_layers.BatchNormalization(axis=1, scale=False) + inputs = tf.random_uniform((5, 4, 3), seed=1) + training = tf.placeholder(dtype='bool') + outputs = bn.apply(inputs, training=training) + + # Verify shape. + self.assertListEqual(outputs.get_shape().as_list(), [5, 4, 3]) + + # Verify layer attributes. + self.assertEqual(len(bn.updates), 2) + self.assertEqual(len(bn.weights), 3) + self.assertEqual(len(bn.trainable_weights), 1) + self.assertEqual(len(bn.non_trainable_weights), 2) + + def testRegularizers(self): + reg = lambda x: 0.1 * tf.reduce_sum(x) + bn = normalization_layers.BatchNormalization(axis=1, beta_regularizer=reg) + inputs = tf.random_uniform((5, 4, 3), seed=1) + training = tf.placeholder(dtype='bool') + _ = bn.apply(inputs, training=training) + self.assertEqual(len(bn.losses), 1) + + bn = normalization_layers.BatchNormalization(axis=1, gamma_regularizer=reg) + inputs = tf.random_uniform((5, 4, 3), seed=1) + training = tf.placeholder(dtype='bool') + _ = bn.apply(inputs, training=training) + self.assertEqual(len(bn.losses), 1) + + +if __name__ == '__main__': + tf.test.main() |
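
The moving statistics are only refreshed when the layer's update ops run, so a training loop has to execute them alongside the output (they are also collected in tf.GraphKeys.UPDATE_OPS, as testCreateBN checks). A minimal sketch in the style of the tests above, with illustrative shapes and iteration count:

    import numpy as np
    import tensorflow as tf
    from tensorflow.python.layers import normalization as normalization_layers

    inputs = tf.Variable(np.random.random((5, 4, 3)), dtype=tf.float32)
    training = tf.placeholder(dtype='bool')
    bn = normalization_layers.BatchNormalization(axis=-1, momentum=0.9)
    outputs = bn.apply(inputs, training=training)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for _ in range(100):
        # Run the moving-average update ops together with the output.
        sess.run([outputs] + bn.updates, feed_dict={training: True})
      # Inference mode: moving statistics are used instead of batch statistics.
      sess.run(outputs, feed_dict={training: False})
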