author    Francois Chollet <fchollet@google.com>          2016-12-02 12:30:57 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-12-02 12:45:12 -0800
commit    f56c0abfdb1fe0e4812ac490e68cb58a3761586c (patch)
tree      8139f963abdf604953c9e2a6c4da73109cf7fbee
parent    314c5c0d50689936e694cda1ac2731daa2e6a423 (diff)
Add BatchNormalization layer class and its functional interface.
Change: 140880753
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers.py        62
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers_test.py    4
-rw-r--r--  tensorflow/python/BUILD                                   14
-rw-r--r--  tensorflow/python/layers/normalization.py                333
-rw-r--r--  tensorflow/python/layers/normalization_test.py           479
5 files changed, 887 insertions(+), 5 deletions(-)
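
Usage overview (a minimal sketch assembled from the tests later in this commit, not part of the diff itself; assumes TF 1.x graph mode and the module path introduced below):

    # Functional interface added by this commit, driven by a boolean
    # `training` placeholder as in normalization_test.py.
    import numpy as np
    import tensorflow as tf
    from tensorflow.python.layers import normalization as normalization_layers

    inputs = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32)
    training = tf.placeholder(dtype='bool')

    outputs = normalization_layers.batch_norm(
        inputs, axis=-1, momentum=0.9, epsilon=1e-3, training=training, name='bn')

    # The moving-average updates are collected in UPDATE_OPS and must be run
    # explicitly alongside the training step.
    updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run([outputs] + updates, feed_dict={training: True})
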
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 9a9ab6e117..12dbf05b4c 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.layers import convolutional as convolutional_layers
from tensorflow.python.layers import core as core_layers
+from tensorflow.python.layers import normalization as normalization_layers
from tensorflow.python.layers import pooling as pooling_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@@ -462,9 +463,64 @@ def batch_norm(
if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
raise ValueError('data_format has to be either NCHW or NHWC.')
- with variable_scope.variable_scope(scope, 'BatchNorm', [inputs],
- reuse=reuse) as sc:
+ layer_variable_getter = _build_variable_getter()
+ with variable_scope.variable_scope(
+ scope, 'BatchNorm', [inputs], reuse=reuse,
+ custom_getter=layer_variable_getter) as sc:
inputs = ops.convert_to_tensor(inputs)
+
+ # Determine whether we can use the core layer class.
+ if (batch_weights is None and
+ updates_collections is ops.GraphKeys.UPDATE_OPS):
+ # Use the core layer class.
+ axis = 1 if data_format == DATA_FORMAT_NCHW else -1
+ if not param_initializers:
+ param_initializers = {}
+ beta_initializer = param_initializers.get('beta',
+ init_ops.zeros_initializer)
+ gamma_initializer = param_initializers.get('gamma',
+ init_ops.ones_initializer())
+ moving_mean_initializer = param_initializers.get(
+ 'moving_mean', init_ops.zeros_initializer)
+ moving_variance_initializer = param_initializers.get(
+ 'moving_variance', init_ops.ones_initializer())
+ layer = normalization_layers.BatchNormalization(
+ axis=axis,
+ momentum=decay,
+ epsilon=epsilon,
+ center=center,
+ scale=scale,
+ beta_initializer=beta_initializer,
+ gamma_initializer=gamma_initializer,
+ moving_mean_initializer=moving_mean_initializer,
+ moving_variance_initializer=moving_variance_initializer,
+ trainable=trainable,
+ name=sc.name,
+ _scope=sc,
+ _reuse_weights=reuse)
+ outputs = layer.apply(inputs, training=is_training)
+
+ # Add variables to collections.
+ _add_variable_to_collections(
+ layer.moving_mean, variables_collections, 'moving_mean')
+ _add_variable_to_collections(
+ layer.moving_variance, variables_collections, 'moving_variance')
+ if layer.beta:
+ _add_variable_to_collections(layer.beta, variables_collections, 'beta')
+ if layer.gamma:
+        _add_variable_to_collections(layer.gamma, variables_collections, 'gamma')
+
+ if activation_fn is not None:
+ outputs = activation_fn(outputs)
+ return utils.collect_named_outputs(outputs_collections,
+ sc.original_name_scope, outputs)
+
+  # The core layer class does not support the batch_weights argument or
+  # custom updates_collections; in those cases, fall back to the legacy BN
+  # implementation below.
+  # Custom updates collections are not supported because the update logic
+  # differs in that case, in particular w.r.t. "forced updates" and
+  # update op reuse.
inputs_shape = inputs.get_shape()
inputs_rank = inputs_shape.ndims
if inputs_rank is None:
@@ -1230,7 +1286,7 @@ def _model_variable_getter(getter, name, shape=None, dtype=None,
custom_getter=getter)
-def _build_variable_getter(rename):
+def _build_variable_getter(rename=None):
"""Build a model variable getter that respects scope getter and renames."""
# Respect current getter, if one is set.
current_custom_getter = variable_scope.get_variable_scope().custom_getter
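
Note on the hunk above: when `batch_weights` is unset and `updates_collections` is left at its default (`GraphKeys.UPDATE_OPS`), contrib's `batch_norm` now delegates to the core `BatchNormalization` layer; otherwise it falls through to the legacy implementation. A hedged sketch of a call that takes the new path (argument names as in the existing contrib API; shapes are illustrative only):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=(None, 28, 28, 16))
    is_training = tf.placeholder(tf.bool)

    # batch_weights=None plus the default updates_collections routes through
    # normalization_layers.BatchNormalization (axis=-1 for NHWC, 1 for NCHW).
    y = tf.contrib.layers.batch_norm(
        x,
        decay=0.99,               # forwarded as `momentum` to the core layer
        epsilon=1e-3,
        is_training=is_training,  # forwarded as `training` to layer.apply
        data_format='NHWC')

    # Moving-mean/variance updates still land in UPDATE_OPS, as before.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
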
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 4715f05c78..befbf33f11 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1570,14 +1570,14 @@ class BatchNormTest(tf.test.TestCase):
with tf.Graph().as_default() as g, self.test_session(g):
inputs = tf.placeholder(dtype=tf.float32)
inputs.set_shape(tf.TensorShape((5, 3, 3, None)))
- with self.assertRaisesRegexp(ValueError, 'undefined channels dimension'):
+ with self.assertRaisesRegexp(ValueError, 'undefined'):
tf.contrib.layers.batch_norm(inputs, data_format='NHWC')
def testUnknownChannelsDimNCHW(self):
with tf.Graph().as_default() as g, self.test_session(g):
inputs = tf.placeholder(dtype=tf.float32)
inputs.set_shape(tf.TensorShape((5, None, 3, 3)))
- with self.assertRaisesRegexp(ValueError, 'undefined channels dimension'):
+ with self.assertRaisesRegexp(ValueError, 'undefined'):
tf.contrib.layers.batch_norm(inputs, data_format='NCHW')
def testWeightedMomentsFused(self):
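
The regexps above are loosened because the error can now be raised by the core layer, whose message ("Input has undefined `axis` dimension. ...") differs from the legacy "undefined channels dimension" wording. A minimal repro of the behavior both messages describe, mirroring the test setup (sketch only):

    import tensorflow as tf

    inputs = tf.placeholder(dtype=tf.float32)
    inputs.set_shape(tf.TensorShape((5, 3, 3, None)))  # channels dim unknown (NHWC)

    try:
      tf.contrib.layers.batch_norm(inputs, data_format='NHWC')
    except ValueError as e:
      print(e)  # both code paths raise ValueError mentioning the undefined dimension
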
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index c5c2f77378..a97898d1f5 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2363,6 +2363,20 @@ py_test(
],
)
+py_test(
+ name = "layers_normalization_test",
+ size = "small",
+ srcs = [
+ "layers/normalization_test.py",
+ ],
+ main = "layers/normalization_test.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":layers",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
py_library(
name = "docs",
srcs = [
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
new file mode 100644
index 0000000000..60e91cd7a0
--- /dev/null
+++ b/tensorflow/python/layers/normalization.py
@@ -0,0 +1,333 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# pylint: disable=unused-import,g-bad-import-order
+"""Contains the normalization layer classes and their functional aliases.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+from six.moves import xrange # pylint: disable=redefined-builtin
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import standard_ops
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.training import moving_averages
+from tensorflow.python.framework import tensor_util
+
+from tensorflow.python.layers import base
+
+
+class BatchNormalization(base._Layer): # pylint: disable=protected-access
+ """Batch Normalization layer from http://arxiv.org/abs/1502.03167.
+
+ "Batch Normalization: Accelerating Deep Network Training by Reducing
+ Internal Covariate Shift"
+
+ Sergey Ioffe, Christian Szegedy
+
+ Arguments:
+ axis: Integer, the axis that should be normalized (typically the features
+ axis). For instance, after a `Convolution2D` layer with
+ `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
+ momentum: Momentum for the moving average.
+ epsilon: Small float added to variance to avoid dividing by zero.
+    center: If True, add offset of `beta` to the normalized tensor.
+      If False, `beta` is ignored.
+    scale: If True, multiply by `gamma`. If False, `gamma` is not used.
+      When the next layer is linear (this also holds for e.g. `nn.relu`),
+      this can be disabled since the scaling can be done by the next layer.
+ beta_initializer: Initializer for the beta weight.
+ gamma_initializer: Initializer for the gamma weight.
+ moving_mean_initializer: Initializer for the moving mean.
+ moving_variance_initializer: Initializer for the moving variance.
+ beta_regularizer: Optional regularizer for the beta weight.
+ gamma_regularizer: Optional regularizer for the gamma weight.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
+ name: A string, the name of the layer.
+ """
+
+ def __init__(self,
+ axis=-1,
+ momentum=0.99,
+ epsilon=1e-3,
+ center=True,
+ scale=True,
+ beta_initializer=init_ops.zeros_initializer,
+ gamma_initializer=init_ops.ones_initializer(),
+ moving_mean_initializer=init_ops.zeros_initializer,
+ moving_variance_initializer=init_ops.ones_initializer(),
+ beta_regularizer=None,
+ gamma_regularizer=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ super(BatchNormalization, self).__init__(
+ name=name, trainable=trainable, **kwargs)
+ self.axis = axis
+ self.momentum = momentum
+ self.epsilon = epsilon
+ self.center = center
+ self.scale = scale
+ self.beta_initializer = beta_initializer
+ self.gamma_initializer = gamma_initializer
+ self.moving_mean_initializer = moving_mean_initializer
+ self.moving_variance_initializer = moving_variance_initializer
+ self.beta_regularizer = beta_regularizer
+ self.gamma_regularizer = gamma_regularizer
+
+ def build(self, input_shape):
+ input_shape = tensor_shape.TensorShape(input_shape)
+ if not input_shape.ndims:
+ raise ValueError('Input has undefined rank:', input_shape)
+ ndim = len(input_shape)
+ if self.axis < 0:
+ axis = ndim + self.axis
+ else:
+ axis = self.axis
+ if axis < 0 or axis >= ndim:
+ raise ValueError('Value of `axis` argument ' + str(self.axis) +
+ ' is out of range for input with rank ' + str(ndim))
+ param_dim = input_shape[axis]
+ if not param_dim.value:
+ raise ValueError('Input has undefined `axis` dimension. Input shape: ',
+ input_shape)
+
+ if self.center:
+ self.beta = vs.get_variable('beta',
+ shape=(param_dim,),
+ initializer=self.beta_initializer,
+ regularizer=self.beta_regularizer,
+ trainable=True)
+ else:
+ self.beta = None
+ if self.scale:
+ self.gamma = vs.get_variable('gamma',
+ shape=(param_dim,),
+ initializer=self.gamma_initializer,
+ regularizer=self.gamma_regularizer,
+ trainable=True)
+ else:
+ self.gamma = None
+
+ # Disable variable partitioning when creating the moving mean and variance
+ partitioner = vs.get_variable_scope().partitioner
+ try:
+ vs.get_variable_scope().set_partitioner(None)
+ self.moving_mean = vs.get_variable(
+ 'moving_mean',
+ shape=(param_dim,),
+ initializer=self.moving_mean_initializer,
+ trainable=False)
+ self.moving_variance = vs.get_variable(
+ 'moving_variance',
+ shape=(param_dim,),
+ initializer=self.moving_variance_initializer,
+ trainable=False)
+ finally:
+ vs.get_variable_scope().set_partitioner(partitioner)
+
+ def call(self, inputs, training=False):
+ # First, compute the axes along which to reduce the mean / variance,
+ # as well as the broadcast shape to be used for all parameters.
+ input_shape = inputs.get_shape()
+ ndim = len(input_shape)
+ reduction_axes = list(range(len(input_shape)))
+ del reduction_axes[self.axis]
+ broadcast_shape = [1] * len(input_shape)
+ broadcast_shape[self.axis] = input_shape[self.axis].value
+
+    # Determine whether broadcasting is needed.
+    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])
+
+    # Determine the boolean training value. May be False, True, or None.
+    # If None, it is assumed that `training` is a tensor to be used in `cond`.
+ if isinstance(training, bool):
+ training_bool = training
+ else:
+ try:
+ training_bool = tensor_util.constant_value(training)
+ except TypeError:
+ training_bool = None
+
+    # Obtain the current batch mean and variance, if necessary.
+ if training_bool is not False:
+ # Use a copy of moving_mean as a shift to compute more reliable moments.
+ shift = math_ops.add(self.moving_mean, 0)
+ if needs_broadcasting:
+ shift = array_ops.reshape(shift, broadcast_shape)
+ broadcast_mean, broadcast_variance = nn.moments(
+ inputs, reduction_axes, shift=shift, keep_dims=True)
+ mean = array_ops.reshape(broadcast_mean, [-1])
+ variance = array_ops.reshape(broadcast_variance, [-1])
+ else:
+ mean, variance = nn.moments(inputs, reduction_axes, shift=shift)
+
+ # Prepare updates if necessary.
+ if training_bool is not False and not self.updates:
+ mean_update = moving_averages.assign_moving_average(
+ self.moving_mean, mean, self.momentum, zero_debias=False)
+ variance_update = moving_averages.assign_moving_average(
+ self.moving_variance, variance, self.momentum, zero_debias=False)
+      # In the future this should be refactored into a self.add_update
+      # method in order to allow for instance-based BN layer sharing
+      # across unrelated input streams (e.g. as in Keras).
+ self.updates.append(mean_update)
+ self.updates.append(variance_update)
+
+ # Normalize batch.
+ if needs_broadcasting:
+      # In this case we must explicitly broadcast all parameters.
+ broadcast_moving_mean = array_ops.reshape(self.moving_mean,
+ broadcast_shape)
+ broadcast_moving_variance = array_ops.reshape(self.moving_variance,
+ broadcast_shape)
+ if self.center:
+ broadcast_beta = array_ops.reshape(self.beta, broadcast_shape)
+ else:
+ broadcast_beta = None
+ if self.scale:
+ broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape)
+ else:
+ broadcast_gamma = None
+
+ if training_bool is not False:
+ normed_inputs_training = nn.batch_normalization(inputs,
+ broadcast_mean,
+ broadcast_variance,
+ broadcast_beta,
+ broadcast_gamma,
+ self.epsilon)
+ normed_inputs = nn.batch_normalization(inputs,
+ broadcast_moving_mean,
+ broadcast_moving_variance,
+ broadcast_beta,
+ broadcast_gamma,
+ self.epsilon)
+ else:
+ # No need for broadcasting.
+ if training_bool is not False:
+ normed_inputs_training = nn.batch_normalization(
+ inputs,
+ mean,
+ variance,
+ self.beta if self.center else None,
+ self.gamma if self.scale else None,
+ self.epsilon)
+ normed_inputs = nn.batch_normalization(inputs,
+ self.moving_mean,
+ self.moving_variance,
+ self.beta if self.center else None,
+ self.gamma if self.scale else None,
+ self.epsilon)
+
+ # Return the proper output depending on the boolean training phase.
+ if training_bool is True:
+ return normed_inputs_training
+ if training_bool is False:
+ return normed_inputs
+ return control_flow_ops.cond(training,
+ lambda: normed_inputs_training,
+ lambda: normed_inputs)
+
+
+def batch_normalization(inputs,
+ axis=-1,
+ momentum=0.99,
+ epsilon=1e-3,
+ center=True,
+ scale=True,
+ beta_initializer=init_ops.zeros_initializer,
+ gamma_initializer=init_ops.ones_initializer(),
+ moving_mean_initializer=init_ops.zeros_initializer,
+ moving_variance_initializer=init_ops.ones_initializer(),
+ beta_regularizer=None,
+ gamma_regularizer=None,
+ training=False,
+ trainable=True,
+ name=None,
+ reuse=False):
+ """Functional interface for the batch normalization layer.
+
+ Reference: http://arxiv.org/abs/1502.03167
+
+ "Batch Normalization: Accelerating Deep Network Training by Reducing
+ Internal Covariate Shift"
+
+ Sergey Ioffe, Christian Szegedy
+
+ Arguments:
+ inputs: Tensor input.
+ axis: Integer, the axis that should be normalized (typically the features
+ axis). For instance, after a `Convolution2D` layer with
+ `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
+ momentum: Momentum for the moving average.
+ epsilon: Small float added to variance to avoid dividing by zero.
+    center: If True, add offset of `beta` to the normalized tensor.
+      If False, `beta` is ignored.
+    scale: If True, multiply by `gamma`. If False, `gamma` is not used.
+      When the next layer is linear (this also holds for e.g. `nn.relu`),
+      this can be disabled since the scaling can be done by the next layer.
+ beta_initializer: Initializer for the beta weight.
+ gamma_initializer: Initializer for the gamma weight.
+ moving_mean_initializer: Initializer for the moving mean.
+ moving_variance_initializer: Initializer for the moving variance.
+ beta_regularizer: Optional regularizer for the beta weight.
+ gamma_regularizer: Optional regularizer for the gamma weight.
+ training: Either a Python boolean, or a TensorFlow boolean scalar tensor
+ (e.g. a placeholder). Whether to return the output in training mode
+ (normalized with statistics of the current batch) or in inference mode
+ (normalized with moving statistics).
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
+ name: String, the name of the layer.
+ reuse: Boolean, whether to reuse the weights of a previous layer
+ by the same name.
+
+ Returns:
+ Output tensor.
+ """
+ layer = BatchNormalization(
+ axis=axis,
+ momentum=momentum,
+ epsilon=epsilon,
+ center=center,
+ scale=scale,
+ beta_initializer=beta_initializer,
+ gamma_initializer=gamma_initializer,
+ moving_mean_initializer=moving_mean_initializer,
+ moving_variance_initializer=moving_variance_initializer,
+ beta_regularizer=beta_regularizer,
+ gamma_regularizer=gamma_regularizer,
+ trainable=trainable,
+ name=name,
+ _reuse_weights=reuse,
+ _scope=name)
+ return layer.apply(inputs, training=training)
+
+
+# Aliases
+
+BatchNorm = BatchNormalization
+batch_norm = batch_normalization
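
For reference, a small sketch of the object-oriented interface defined above, following the patterns used in normalization_test.py (a boolean placeholder drives the `cond` branch in `call`; the assertions mirror testCreateBN):

    import tensorflow as tf
    from tensorflow.python.layers import normalization as normalization_layers

    bn = normalization_layers.BatchNormalization(axis=1, momentum=0.9, epsilon=1e-3)
    inputs = tf.random_uniform((5, 4, 3), seed=1)
    training = tf.placeholder(dtype='bool')

    outputs = bn.apply(inputs, training=training)

    # After apply(): beta/gamma are trainable, the moving statistics are not,
    # and the two moving-average assign ops are exposed as bn.updates
    # (also added to the UPDATE_OPS collection).
    assert len(bn.trainable_weights) == 2
    assert len(bn.non_trainable_weights) == 2
    assert len(bn.updates) == 2
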
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
new file mode 100644
index 0000000000..3f5b4fa632
--- /dev/null
+++ b/tensorflow/python/layers/normalization_test.py
@@ -0,0 +1,479 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.layers.core."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.layers import normalization as normalization_layers
+
+
+class BNTest(tf.test.TestCase):
+
+ def testCreateBN(self):
+ # Call layer.
+ bn = normalization_layers.BatchNormalization(axis=1)
+ inputs = tf.random_uniform((5, 4, 3), seed=1)
+ training = tf.placeholder(dtype='bool')
+ outputs = bn.apply(inputs, training=training)
+
+ # Verify shape.
+ self.assertListEqual(outputs.get_shape().as_list(), [5, 4, 3])
+
+ # Verify layer attributes.
+ self.assertEqual(len(bn.updates), 2)
+ self.assertEqual(len(bn.weights), 4)
+ self.assertEqual(len(bn.trainable_weights), 2)
+ self.assertEqual(len(bn.non_trainable_weights), 2)
+
+ # Test that updates were created and added to UPDATE_OPS.
+ self.assertEqual(len(bn.updates), 2)
+ self.assertListEqual(
+ tf.get_collection(tf.GraphKeys.UPDATE_OPS), bn.updates)
+
+ # Test that weights were created and added to TRAINABLE_VARIABLES.
+ self.assertListEqual(
+ tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
+ bn.trainable_weights)
+
+ def test3DInputAxis1(self):
+ epsilon = 1e-3
+ bn = normalization_layers.BatchNormalization(axis=1,
+ epsilon=epsilon, momentum=0.9)
+ inputs = tf.Variable(np.random.random((5, 4, 3)), dtype=tf.float32)
+ training = tf.placeholder(dtype='bool')
+ outputs = bn.apply(inputs, training=training)
+
+ with self.test_session() as sess:
+ # Test training with placeholder learning phase.
+ sess.run(tf.global_variables_initializer())
+ for _ in range(100):
+ np_output, _, _ = sess.run([outputs] + bn.updates,
+ feed_dict={training: True})
+
+ # Verify that the statistics are updated during training.
+ moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ np_inputs = sess.run(inputs)
+ mean = np.mean(np_inputs, axis=(0, 2))
+ std = np.std(np_inputs, axis=(0, 2))
+ variance = np.square(std)
+ self.assertAllClose(mean, moving_mean, atol=1e-2)
+ self.assertAllClose(variance, moving_var, atol=1e-2)
+
+ # Verify that the axis is normalized during training.
+ np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma = np.reshape(np_gamma, (1, 4, 1))
+ np_beta = np.reshape(np_beta, (1, 4, 1))
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ # Test inference with placeholder learning phase.
+ np_output = sess.run(outputs, feed_dict={training: False})
+
+ # Verify that the axis is normalized during inference.
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ def test3DInputAxis2(self):
+ epsilon = 1e-3
+ bn = normalization_layers.BatchNormalization(axis=2,
+ epsilon=epsilon, momentum=0.9)
+ inputs = tf.Variable(np.random.random((5, 4, 3)), dtype=tf.float32)
+ training = tf.placeholder(dtype='bool')
+ outputs = bn.apply(inputs, training=training)
+
+ with self.test_session() as sess:
+ # Test training with placeholder learning phase.
+ sess.run(tf.global_variables_initializer())
+ for _ in range(100):
+ np_output, _, _ = sess.run([outputs] + bn.updates,
+ feed_dict={training: True})
+
+ # Verify that the statistics are updated during training.
+ moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ np_inputs = sess.run(inputs)
+ mean = np.mean(np_inputs, axis=(0, 1))
+ std = np.std(np_inputs, axis=(0, 1))
+ variance = np.square(std)
+ self.assertAllClose(mean, moving_mean, atol=1e-2)
+ self.assertAllClose(variance, moving_var, atol=1e-2)
+
+ # Verify that the axis is normalized during training.
+ np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma = np.reshape(np_gamma, (1, 1, 3))
+ np_beta = np.reshape(np_beta, (1, 1, 3))
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ # Test inference with placeholder learning phase.
+ np_output = sess.run(outputs, feed_dict={training: False})
+
+ # Verify that the axis is normalized during inference.
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ def test4DInputAxis1(self):
+ epsilon = 1e-3
+ bn = normalization_layers.BatchNormalization(axis=1,
+ epsilon=epsilon, momentum=0.9)
+ inputs = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32)
+ training = tf.placeholder(dtype='bool')
+ outputs = bn.apply(inputs, training=training)
+
+ with self.test_session() as sess:
+ # Test training with placeholder learning phase.
+ sess.run(tf.global_variables_initializer())
+ for _ in range(100):
+ np_output, _, _ = sess.run([outputs] + bn.updates,
+ feed_dict={training: True})
+
+ # Verify that the statistics are updated during training.
+ moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ np_inputs = sess.run(inputs)
+ mean = np.mean(np_inputs, axis=(0, 2, 3))
+ std = np.std(np_inputs, axis=(0, 2, 3))
+ variance = np.square(std)
+ self.assertAllClose(mean, moving_mean, atol=1e-2)
+ self.assertAllClose(variance, moving_var, atol=1e-2)
+
+ # Verify that the axis is normalized during training.
+ np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
+ np_beta = np.reshape(np_beta, (1, 4, 1, 1))
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ # Test inference with placeholder learning phase.
+ np_output = sess.run(outputs, feed_dict={training: False})
+
+ # Verify that the axis is normalized during inference.
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ def test4DInputAxis2(self):
+ epsilon = 1e-3
+ bn = normalization_layers.BatchNormalization(axis=2,
+ epsilon=epsilon, momentum=0.9)
+ inputs = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32)
+ training = tf.placeholder(dtype='bool')
+ outputs = bn.apply(inputs, training=training)
+
+ with self.test_session() as sess:
+ # Test training with placeholder learning phase.
+ sess.run(tf.global_variables_initializer())
+ for _ in range(100):
+ np_output, _, _ = sess.run([outputs] + bn.updates,
+ feed_dict={training: True})
+
+ # Verify that the statistics are updated during training.
+ moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ np_inputs = sess.run(inputs)
+ mean = np.mean(np_inputs, axis=(0, 1, 3))
+ std = np.std(np_inputs, axis=(0, 1, 3))
+ variance = np.square(std)
+ self.assertAllClose(mean, moving_mean, atol=1e-2)
+ self.assertAllClose(variance, moving_var, atol=1e-2)
+
+ # Verify that the axis is normalized during training.
+ np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma = np.reshape(np_gamma, (1, 1, 3, 1))
+ np_beta = np.reshape(np_beta, (1, 1, 3, 1))
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ # Test inference with placeholder learning phase.
+ np_output = sess.run(outputs, feed_dict={training: False})
+
+ # Verify that the axis is normalized during inference.
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ def test4DInputAxis3(self):
+ epsilon = 1e-3
+ bn = normalization_layers.BatchNormalization(axis=3,
+ epsilon=epsilon, momentum=0.9)
+ inputs = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32)
+ training = tf.placeholder(dtype='bool')
+ outputs = bn.apply(inputs, training=training)
+
+ with self.test_session() as sess:
+ # Test training with placeholder learning phase.
+ sess.run(tf.global_variables_initializer())
+ for _ in range(100):
+ np_output, _, _ = sess.run([outputs] + bn.updates,
+ feed_dict={training: True})
+
+ # Verify that the statistics are updated during training.
+ moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ np_inputs = sess.run(inputs)
+ mean = np.mean(np_inputs, axis=(0, 1, 2))
+ std = np.std(np_inputs, axis=(0, 1, 2))
+ variance = np.square(std)
+ self.assertAllClose(mean, moving_mean, atol=1e-2)
+ self.assertAllClose(variance, moving_var, atol=1e-2)
+
+ # Verify that the axis is normalized during training.
+ np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
+ np_beta = np.reshape(np_beta, (1, 1, 1, 6))
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ # Test inference with placeholder learning phase.
+ np_output = sess.run(outputs, feed_dict={training: False})
+
+ # Verify that the axis is normalized during inference.
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ def testNegativeAxis(self):
+ epsilon = 1e-3
+ bn = normalization_layers.BatchNormalization(axis=-1,
+ epsilon=epsilon, momentum=0.9)
+ inputs = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32)
+ training = tf.placeholder(dtype='bool')
+ outputs = bn.apply(inputs, training=training)
+
+ with self.test_session() as sess:
+ # Test training with placeholder learning phase.
+ sess.run(tf.global_variables_initializer())
+ for _ in range(100):
+ np_output, _, _ = sess.run([outputs] + bn.updates,
+ feed_dict={training: True})
+
+ # Verify that the statistics are updated during training.
+ moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ np_inputs = sess.run(inputs)
+ mean = np.mean(np_inputs, axis=(0, 1, 2))
+ std = np.std(np_inputs, axis=(0, 1, 2))
+ variance = np.square(std)
+ self.assertAllClose(mean, moving_mean, atol=1e-2)
+ self.assertAllClose(variance, moving_var, atol=1e-2)
+
+ # Verify that the axis is normalized during training.
+ np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
+ np_beta = np.reshape(np_beta, (1, 1, 1, 6))
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ # Test inference with placeholder learning phase.
+ np_output = sess.run(outputs, feed_dict={training: False})
+
+ # Verify that the axis is normalized during inference.
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ def testBooleanLearningPhase(self):
+ epsilon = 1e-3
+ bn = normalization_layers.BatchNormalization(axis=-1,
+ epsilon=epsilon, momentum=0.9)
+ inputs = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32)
+ outputs_training = bn.apply(inputs, training=True)
+ outputs_infer = bn.apply(inputs, training=False)
+
+ with self.test_session() as sess:
+      # Test training with a Python boolean learning phase.
+ sess.run(tf.global_variables_initializer())
+ for _ in range(100):
+ np_output, _, _ = sess.run([outputs_training] + bn.updates)
+
+ # Verify that the statistics are updated during training.
+ moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ np_inputs = sess.run(inputs)
+ mean = np.mean(np_inputs, axis=(0, 1, 2))
+ std = np.std(np_inputs, axis=(0, 1, 2))
+ variance = np.square(std)
+ self.assertAllClose(mean, moving_mean, atol=1e-2)
+ self.assertAllClose(variance, moving_var, atol=1e-2)
+
+ # Verify that the axis is normalized during training.
+ np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
+ np_beta = np.reshape(np_beta, (1, 1, 1, 6))
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+      # Test inference with a Python boolean learning phase.
+ np_output = sess.run(outputs_infer)
+
+ # Verify that the axis is normalized during inference.
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ def testFunctionalNoReuse(self):
+ inputs = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32)
+ epsilon = 1e-3
+ training = tf.placeholder(dtype='bool')
+ outputs = normalization_layers.batch_norm(
+ inputs, axis=-1, momentum=0.9, epsilon=epsilon,
+ training=training, name='bn')
+
+ updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+ all_vars = dict([(v.name, v) for v in tf.global_variables()])
+ moving_mean = all_vars['bn/moving_mean:0']
+ moving_variance = all_vars['bn/moving_variance:0']
+ beta = all_vars['bn/beta:0']
+ gamma = all_vars['bn/gamma:0']
+
+ with self.test_session() as sess:
+ # Test training with placeholder learning phase.
+ sess.run(tf.global_variables_initializer())
+ for _ in range(100):
+ np_output, _, _ = sess.run([outputs] + updates,
+ feed_dict={training: True})
+
+ # Verify that the statistics are updated during training.
+ np_moving_mean, np_moving_var = sess.run([moving_mean, moving_variance])
+ np_inputs = sess.run(inputs)
+ np_mean = np.mean(np_inputs, axis=(0, 1, 2))
+ np_std = np.std(np_inputs, axis=(0, 1, 2))
+ np_variance = np.square(np_std)
+ self.assertAllClose(np_mean, np_moving_mean, atol=1e-2)
+ self.assertAllClose(np_variance, np_moving_var, atol=1e-2)
+
+ # Verify that the axis is normalized during training.
+ np_gamma, np_beta = sess.run([gamma, beta])
+ np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
+ np_beta = np.reshape(np_beta, (1, 1, 1, 6))
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ # Test inference with placeholder learning phase.
+ np_output = sess.run(outputs, feed_dict={training: False})
+
+ # Verify that the axis is normalized during inference.
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ def testFunctionalReuse(self):
+ inputs1 = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32)
+ inputs2 = tf.Variable(np.random.random((5, 4, 3, 6)), dtype=tf.float32)
+ epsilon = 1e-3
+ training = tf.placeholder(dtype='bool')
+ _ = normalization_layers.batch_norm(
+ inputs1, axis=-1, momentum=0.9, epsilon=epsilon,
+ training=training, name='bn')
+ outputs2 = normalization_layers.batch_norm(
+ inputs2, axis=-1, momentum=0.9, epsilon=epsilon,
+ training=training, name='bn', reuse=True)
+
+ # Last 2 update ops
+ updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)[-2:]
+ all_vars = dict([(v.name, v) for v in tf.global_variables()])
+ moving_mean = all_vars['bn/moving_mean:0']
+ moving_variance = all_vars['bn/moving_variance:0']
+ beta = all_vars['bn/beta:0']
+ gamma = all_vars['bn/gamma:0']
+
+ with self.test_session() as sess:
+ # Test training with placeholder learning phase.
+ sess.run(tf.global_variables_initializer())
+ for _ in range(100):
+ np_output, _, _ = sess.run([outputs2] + updates,
+ feed_dict={training: True})
+
+ # Verify that the statistics are updated during training.
+ np_moving_mean, np_moving_var = sess.run([moving_mean, moving_variance])
+ np_inputs = sess.run(inputs2)
+ np_mean = np.mean(np_inputs, axis=(0, 1, 2))
+ np_std = np.std(np_inputs, axis=(0, 1, 2))
+ np_variance = np.square(np_std)
+ self.assertAllClose(np_mean, np_moving_mean, atol=1e-2)
+ self.assertAllClose(np_variance, np_moving_var, atol=1e-2)
+
+ # Verify that the axis is normalized during training.
+ np_gamma, np_beta = sess.run([gamma, beta])
+ np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
+ np_beta = np.reshape(np_beta, (1, 1, 1, 6))
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ # Test inference with placeholder learning phase.
+ np_output = sess.run(outputs2, feed_dict={training: False})
+
+ # Verify that the axis is normalized during inference.
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ def testNoCenter(self):
+ bn = normalization_layers.BatchNormalization(axis=1, center=False)
+ inputs = tf.random_uniform((5, 4, 3), seed=1)
+ training = tf.placeholder(dtype='bool')
+ outputs = bn.apply(inputs, training=training)
+
+ # Verify shape.
+ self.assertListEqual(outputs.get_shape().as_list(), [5, 4, 3])
+
+ # Verify layer attributes.
+ self.assertEqual(len(bn.updates), 2)
+ self.assertEqual(len(bn.weights), 3)
+ self.assertEqual(len(bn.trainable_weights), 1)
+ self.assertEqual(len(bn.non_trainable_weights), 2)
+
+ def testNoScale(self):
+ bn = normalization_layers.BatchNormalization(axis=1, scale=False)
+ inputs = tf.random_uniform((5, 4, 3), seed=1)
+ training = tf.placeholder(dtype='bool')
+ outputs = bn.apply(inputs, training=training)
+
+ # Verify shape.
+ self.assertListEqual(outputs.get_shape().as_list(), [5, 4, 3])
+
+ # Verify layer attributes.
+ self.assertEqual(len(bn.updates), 2)
+ self.assertEqual(len(bn.weights), 3)
+ self.assertEqual(len(bn.trainable_weights), 1)
+ self.assertEqual(len(bn.non_trainable_weights), 2)
+
+ def testRegularizers(self):
+ reg = lambda x: 0.1 * tf.reduce_sum(x)
+ bn = normalization_layers.BatchNormalization(axis=1, beta_regularizer=reg)
+ inputs = tf.random_uniform((5, 4, 3), seed=1)
+ training = tf.placeholder(dtype='bool')
+ _ = bn.apply(inputs, training=training)
+ self.assertEqual(len(bn.losses), 1)
+
+ bn = normalization_layers.BatchNormalization(axis=1, gamma_regularizer=reg)
+ inputs = tf.random_uniform((5, 4, 3), seed=1)
+ training = tf.placeholder(dtype='bool')
+ _ = bn.apply(inputs, training=training)
+ self.assertEqual(len(bn.losses), 1)
+
+
+if __name__ == '__main__':
+ tf.test.main()
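
The repeated assertAllClose checks above rely on the moving averages converging toward the batch statistics over 100 steps with momentum 0.9. A numpy-only sketch of that convergence (variable names are illustrative; the recurrence matches assign_moving_average with zero_debias=False):

    import numpy as np

    momentum = 0.9
    batch_mean = 0.5    # stand-in for a per-axis batch mean
    moving_mean = 0.0   # zeros_initializer

    for _ in range(100):  # same number of steps as the tests run
      # moving <- moving * momentum + value * (1 - momentum)
      moving_mean = moving_mean * momentum + batch_mean * (1 - momentum)

    # After 100 steps the moving average is within the tests' atol of 1e-2.
    assert abs(moving_mean - batch_mean) < 1e-2
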