-rw-r--r-- | tensorflow/python/BUILD                              |  7
-rw-r--r-- | tensorflow/python/layers/normalization.py            | 88
-rw-r--r-- | tensorflow/python/layers/normalization_test.py       | 81
-rw-r--r-- | tensorflow/tools/api/golden/tensorflow.layers.pbtxt  |  2
4 files changed, 168 insertions, 10 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 93606ce4ce..dcce808e97 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3555,13 +3555,11 @@ py_test(
     ],
 )
 
-py_test(
+cuda_py_test(
     name = "layers_normalization_test",
     size = "small",
     srcs = ["layers/normalization_test.py"],
-    main = "layers/normalization_test.py",
-    srcs_version = "PY2AND3",
-    deps = [
+    additional_deps = [
         ":array_ops",
         ":client_testlib",
         ":framework_for_generated_wrappers",
@@ -3571,6 +3569,7 @@ py_test(
         ":variables",
         "//third_party/py/numpy",
     ],
+    main = "layers/normalization_test.py",
 )
 
 # -----------------------------------------------------------------------------
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index ea6f55281e..780d1c2b8e 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -66,9 +66,6 @@ class BatchNormalization(base.Layer):
     moving_variance_initializer: Initializer for the moving variance.
     beta_regularizer: Optional regularizer for the beta weight.
     gamma_regularizer: Optional regularizer for the gamma weight.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
-    name: A string, the name of the layer.
     renorm: Whether to use Batch Renormalization
       (https://arxiv.org/abs/1702.03275). This adds extra variables during
       training. The inference is the same for either value of this parameter.
@@ -82,6 +79,11 @@ class BatchNormalization(base.Layer):
       and should be neither too small (which would add noise) nor too large
       (which would give stale estimates). Note that `momentum` is still applied
       to get the means and variances for inference.
+    fused: if `True`, use a faster, fused implementation based on
+      nn.fused_batch_norm. If `None`, use the fused implementation if possible.
+    trainable: Boolean, if `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
+    name: A string, the name of the layer.
   """
 
   def __init__(self,
@@ -99,6 +101,7 @@ class BatchNormalization(base.Layer):
                renorm=False,
                renorm_clipping=None,
                renorm_momentum=0.99,
+               fused=False,
                trainable=True,
                name=None,
                **kwargs):
@@ -116,6 +119,10 @@ class BatchNormalization(base.Layer):
     self.beta_regularizer = beta_regularizer
     self.gamma_regularizer = gamma_regularizer
     self.renorm = renorm
+    self.fused = fused
+    if self.fused and renorm:
+      raise ValueError(
+          'Batch renorm is currently not supported with fused batch norm.')
     if renorm:
       renorm_clipping = renorm_clipping or {}
       keys = ['rmax', 'rmin', 'dmax']
@@ -130,6 +137,13 @@ class BatchNormalization(base.Layer):
     if not input_shape.ndims:
       raise ValueError('Input has undefined rank:', input_shape)
     ndim = len(input_shape)
+    # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
+    # output back to its original shape accordingly.
+    if self.fused and ndim != 4:
+      raise ValueError(
+          'Only 4D inputs are currently supported with fused batch norm. '
+          'Consider reshaping the input to 4D and reshape the output back '
+          'to its original shape. Got input rank: ', ndim)
     if self.axis < 0:
       axis = ndim + self.axis
     else:
@@ -137,6 +151,20 @@ class BatchNormalization(base.Layer):
     if axis < 0 or axis >= ndim:
       raise ValueError('Value of `axis` argument ' + str(self.axis) +
                        ' is out of range for input with rank ' + str(ndim))
+
+    if self.fused is None:
+      self.fused = not self.renorm and ndim == 4 and axis in [1, 3]
+
+    if self.fused:
+      if axis == 1:
+        self._data_format = 'NCHW'
+      elif axis == 3:
+        self._data_format = 'NHWC'
+      else:
+        raise ValueError(
+            'Only axis 1 and 3 are currently supported dimensions for '
+            'fused batch norm. Got `axis` dimension: ', axis)
+
     param_dim = input_shape[axis]
     if not param_dim.value:
       raise ValueError('Input has undefined `axis` dimension. Input shape: ',
@@ -152,6 +180,8 @@ class BatchNormalization(base.Layer):
                                     trainable=True)
     else:
       self.beta = None
+      if self.fused:
+        self._beta_const = array_ops.constant(0.0, shape=(param_dim,))
     if self.scale:
       self.gamma = self.add_variable(name='gamma',
                                      shape=(param_dim,),
@@ -160,6 +190,8 @@ class BatchNormalization(base.Layer):
                                      trainable=True)
     else:
       self.gamma = None
+      if self.fused:
+        self._gamma_const = array_ops.constant(1.0, shape=(param_dim,))
 
     # Disable variable partitioning when creating the moving mean and variance
     partitioner = self._scope.partitioner
@@ -205,6 +237,45 @@ class BatchNormalization(base.Layer):
       self._scope.set_partitioner(partitioner)
     self.built = True
 
+  def _fused_batch_norm(self, inputs, training):
+    """Returns the output of fused batch norm."""
+    beta = self.beta if self.center else self._beta_const
+    gamma = self.gamma if self.scale else self._gamma_const
+
+    def _fused_batch_norm_training():
+      return nn.fused_batch_norm(
+          inputs,
+          gamma,
+          beta,
+          epsilon=self.epsilon,
+          data_format=self._data_format)
+
+    def _fused_batch_norm_inference():
+      return nn.fused_batch_norm(
+          inputs,
+          gamma,
+          beta,
+          mean=self.moving_mean,
+          variance=self.moving_variance,
+          epsilon=self.epsilon,
+          is_training=False,
+          data_format=self._data_format)
+
+    output, mean, variance = utils.smart_cond(
+        training, _fused_batch_norm_training, _fused_batch_norm_inference)
+
+    training_value = utils.constant_value(training)
+    if training_value is not False:
+      decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
+      mean_update = moving_averages.assign_moving_average(
+          self.moving_mean, mean, decay, zero_debias=False)
+      variance_update = moving_averages.assign_moving_average(
+          self.moving_variance, variance, decay, zero_debias=False)
+      self.add_update(mean_update, inputs=inputs)
+      self.add_update(variance_update, inputs=inputs)
+
+    return output
+
   def _renorm_correction_and_moments(self, mean, variance, training):
     """Returns the correction and update values for renorm."""
     stddev = math_ops.sqrt(variance + self.epsilon)
@@ -265,6 +336,9 @@ class BatchNormalization(base.Layer):
     return (r, d, new_mean, new_variance)
 
   def call(self, inputs, training=False):
+    if self.fused:
+      return self._fused_batch_norm(inputs, training=training)
+
     # First, compute the axes along which to reduce the mean / variance,
     # as well as the broadcast shape to be used for all parameters.
     input_shape = inputs.get_shape()
@@ -353,7 +427,8 @@ def batch_normalization(inputs,
                         reuse=None,
                         renorm=False,
                         renorm_clipping=None,
-                        renorm_momentum=0.99):
+                        renorm_momentum=0.99,
+                        fused=False):
   """Functional interface for the batch normalization layer.
 
   Reference: http://arxiv.org/abs/1502.03167
@@ -415,6 +490,8 @@ def batch_normalization(inputs,
       and should be neither too small (which would add noise) nor too large
       (which would give stale estimates). Note that `momentum` is still applied
       to get the means and variances for inference.
+    fused: if `True`, use a faster, fused implementation based on
+      nn.fused_batch_norm. If `None`, use the fused implementation if possible.
 
   Returns:
     Output tensor.
@@ -431,10 +508,11 @@ def batch_normalization(inputs,
       moving_variance_initializer=moving_variance_initializer,
       beta_regularizer=beta_regularizer,
       gamma_regularizer=gamma_regularizer,
-      trainable=trainable,
       renorm=renorm,
       renorm_clipping=renorm_clipping,
       renorm_momentum=renorm_momentum,
+      fused=fused,
+      trainable=trainable,
       name=name,
       _reuse=reuse,
       _scope=name)
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
index 933f196e01..fa6c9c4a5d 100644
--- a/tensorflow/python/layers/normalization_test.py
+++ b/tensorflow/python/layers/normalization_test.py
@@ -262,6 +262,87 @@ class BNTest(test.TestCase):
       self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
       self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
 
+  def test4DInputAxis3Fused(self):
+    epsilon = 1e-3
+    bn = normalization_layers.BatchNormalization(
+        axis=3, epsilon=epsilon, momentum=0.9, fused=True)
+    inputs = variables.Variable(
+        np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
+    training = array_ops.placeholder(dtype='bool')
+    outputs = bn.apply(inputs, training=training)
+
+    with self.test_session() as sess:
+      # Test training with placeholder learning phase.
+      sess.run(variables.global_variables_initializer())
+      np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+      np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
+      np_beta = np.reshape(np_beta, (1, 1, 1, 6))
+      for _ in range(100):
+        np_output, _, _ = sess.run(
+            [outputs] + bn.updates, feed_dict={training: True})
+        # Verify that the axis is normalized during training.
+        normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+        self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+        self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+      # Verify that the statistics are updated during training.
+      moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+      np_inputs = sess.run(inputs)
+      mean = np.mean(np_inputs, axis=(0, 1, 2))
+      std = np.std(np_inputs, axis=(0, 1, 2))
+      variance = np.square(std)
+      self.assertAllClose(mean, moving_mean, atol=1e-2)
+      self.assertAllClose(variance, moving_var, atol=1e-2)
+
+      # Test inference with placeholder learning phase.
+      np_output = sess.run(outputs, feed_dict={training: False})
+
+      # Verify that the axis is normalized during inference.
+      normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+      self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+      self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+  def test4DInputAxis1Fused(self):
+    if test.is_gpu_available(cuda_only=True):
+      epsilon = 1e-3
+      bn = normalization_layers.BatchNormalization(
+          axis=1, epsilon=epsilon, momentum=0.9, fused=True)
+      inputs = variables.Variable(
+          np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
+      training = array_ops.placeholder(dtype='bool')
+      outputs = bn.apply(inputs, training=training)
+
+      with self.test_session() as sess:
+        # Test training with placeholder learning phase.
+        sess.run(variables.global_variables_initializer())
+        np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+        np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
+        np_beta = np.reshape(np_beta, (1, 4, 1, 1))
+        for _ in range(100):
+          np_output, _, _ = sess.run(
+              [outputs] + bn.updates, feed_dict={training: True})
+          # Verify that the axis is normalized during training.
+          normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+          self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+          self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+        # Verify that the statistics are updated during training.
+        moving_mean, moving_var = sess.run(
+            [bn.moving_mean, bn.moving_variance])
+        np_inputs = sess.run(inputs)
+        mean = np.mean(np_inputs, axis=(0, 2, 3))
+        std = np.std(np_inputs, axis=(0, 2, 3))
+        variance = np.square(std)
+        self.assertAllClose(mean, moving_mean, atol=1e-2)
+        self.assertAllClose(variance, moving_var, atol=1e-2)
+
+        # Test inference with placeholder learning phase.
+        np_output = sess.run(outputs, feed_dict={training: False})
+
+        # Verify that the axis is normalized during inference.
+        normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+        self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+        self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
   def testNegativeAxis(self):
     epsilon = 1e-3
     bn = normalization_layers.BatchNormalization(
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
index 78b10c44a2..418ca3ea46 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
@@ -14,7 +14,7 @@ tf_module {
   }
   member_method {
     name: "batch_normalization"
-    argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\'], "
+    argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'False\'], "
   }
   member_method {
     name: "conv1d"
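For reference, the sketch below shows how the new `fused` argument is intended to be used through the functional `tf.layers.batch_normalization` interface. It is not part of the diff above: it assumes a TensorFlow 1.x graph-mode setup, and the placeholder names, shapes, optimizer, and toy loss are illustrative only.

# Illustrative usage sketch (not part of this change). Assumes TF 1.x graph
# mode; `images`, `is_training`, and the toy loss below are hypothetical.
import tensorflow as tf

images = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])  # 4D NHWC input
is_training = tf.placeholder(tf.bool)

# fused=True selects the nn.fused_batch_norm-based path added by this change;
# per the checks above it currently requires a 4D input and axis 1 (NCHW) or
# axis 3 (NHWC), and cannot be combined with renorm=True.
net = tf.layers.batch_normalization(
    images, axis=3, momentum=0.9, epsilon=1e-3,
    training=is_training, fused=True)

# The moving mean/variance updates are registered in GraphKeys.UPDATE_OPS,
# so they must be run together with the training step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
      tf.reduce_mean(tf.square(net)))

Passing `fused=None` instead of `True` lets the layer pick the fused implementation automatically when the conditions above hold (4D input, axis 1 or 3, no renorm), as implemented in the `build` method of this change.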