aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/BUILD7
-rw-r--r--tensorflow/python/layers/normalization.py88
-rw-r--r--tensorflow/python/layers/normalization_test.py81
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.pbtxt2
4 files changed, 168 insertions, 10 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 93606ce4ce..dcce808e97 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3555,13 +3555,11 @@ py_test(
],
)
-py_test(
+cuda_py_test(
name = "layers_normalization_test",
size = "small",
srcs = ["layers/normalization_test.py"],
- main = "layers/normalization_test.py",
- srcs_version = "PY2AND3",
- deps = [
+ additional_deps = [
":array_ops",
":client_testlib",
":framework_for_generated_wrappers",
@@ -3571,6 +3569,7 @@ py_test(
":variables",
"//third_party/py/numpy",
],
+ main = "layers/normalization_test.py",
)
# -----------------------------------------------------------------------------
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index ea6f55281e..780d1c2b8e 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -66,9 +66,6 @@ class BatchNormalization(base.Layer):
moving_variance_initializer: Initializer for the moving variance.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
- name: A string, the name of the layer.
renorm: Whether to use Batch Renormalization
(https://arxiv.org/abs/1702.03275). This adds extra variables during
training. The inference is the same for either value of this parameter.
@@ -82,6 +79,11 @@ class BatchNormalization(base.Layer):
and should be neither too small (which would add noise) nor too large
(which would give stale estimates). Note that `momentum` is still applied
to get the means and variances for inference.
+ fused: if `True`, use a faster, fused implementation based on
+ nn.fused_batch_norm. If `None`, use the fused implementation if possible.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
+ name: A string, the name of the layer.
"""
def __init__(self,
@@ -99,6 +101,7 @@ class BatchNormalization(base.Layer):
renorm=False,
renorm_clipping=None,
renorm_momentum=0.99,
+ fused=False,
trainable=True,
name=None,
**kwargs):
@@ -116,6 +119,10 @@ class BatchNormalization(base.Layer):
self.beta_regularizer = beta_regularizer
self.gamma_regularizer = gamma_regularizer
self.renorm = renorm
+ self.fused = fused
+ if self.fused and renorm:
+ raise ValueError(
+ 'Batch renorm is currently not supported with fused batch norm.')
if renorm:
renorm_clipping = renorm_clipping or {}
keys = ['rmax', 'rmin', 'dmax']
@@ -130,6 +137,13 @@ class BatchNormalization(base.Layer):
if not input_shape.ndims:
raise ValueError('Input has undefined rank:', input_shape)
ndim = len(input_shape)
+ # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
+ # output back to its original shape accordingly.
+ if self.fused and ndim != 4:
+ raise ValueError(
+ 'Only 4D inputs are currently supported with fused batch norm. '
+ 'Consider reshaping the input to 4D and reshape the output back '
+ 'to its original shape. Got input rank: ', ndim)
if self.axis < 0:
axis = ndim + self.axis
else:
@@ -137,6 +151,20 @@ class BatchNormalization(base.Layer):
if axis < 0 or axis >= ndim:
raise ValueError('Value of `axis` argument ' + str(self.axis) +
' is out of range for input with rank ' + str(ndim))
+
+ if self.fused is None:
+ self.fused = not self.renorm and ndim == 4 and axis in [1, 3]
+
+ if self.fused:
+ if axis == 1:
+ self._data_format = 'NCHW'
+ elif axis == 3:
+ self._data_format = 'NHWC'
+ else:
+ raise ValueError(
+ 'Only axis 1 and 3 are currently supported dimensions for '
+ 'fused batch norm. Got `axis` dimension: ', axis)
+
param_dim = input_shape[axis]
if not param_dim.value:
raise ValueError('Input has undefined `axis` dimension. Input shape: ',
@@ -152,6 +180,8 @@ class BatchNormalization(base.Layer):
trainable=True)
else:
self.beta = None
+ if self.fused:
+ self._beta_const = array_ops.constant(0.0, shape=(param_dim,))
if self.scale:
self.gamma = self.add_variable(name='gamma',
shape=(param_dim,),
@@ -160,6 +190,8 @@ class BatchNormalization(base.Layer):
trainable=True)
else:
self.gamma = None
+ if self.fused:
+ self._gamma_const = array_ops.constant(1.0, shape=(param_dim,))
# Disable variable partitioning when creating the moving mean and variance
partitioner = self._scope.partitioner
@@ -205,6 +237,45 @@ class BatchNormalization(base.Layer):
self._scope.set_partitioner(partitioner)
self.built = True
+ def _fused_batch_norm(self, inputs, training):
+ """Returns the output of fused batch norm."""
+ beta = self.beta if self.center else self._beta_const
+ gamma = self.gamma if self.scale else self._gamma_const
+
+ def _fused_batch_norm_training():
+ return nn.fused_batch_norm(
+ inputs,
+ gamma,
+ beta,
+ epsilon=self.epsilon,
+ data_format=self._data_format)
+
+ def _fused_batch_norm_inference():
+ return nn.fused_batch_norm(
+ inputs,
+ gamma,
+ beta,
+ mean=self.moving_mean,
+ variance=self.moving_variance,
+ epsilon=self.epsilon,
+ is_training=False,
+ data_format=self._data_format)
+
+ output, mean, variance = utils.smart_cond(
+ training, _fused_batch_norm_training, _fused_batch_norm_inference)
+
+ training_value = utils.constant_value(training)
+ if training_value is not False:
+ decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
+ mean_update = moving_averages.assign_moving_average(
+ self.moving_mean, mean, decay, zero_debias=False)
+ variance_update = moving_averages.assign_moving_average(
+ self.moving_variance, variance, decay, zero_debias=False)
+ self.add_update(mean_update, inputs=inputs)
+ self.add_update(variance_update, inputs=inputs)
+
+ return output
+
def _renorm_correction_and_moments(self, mean, variance, training):
"""Returns the correction and update values for renorm."""
stddev = math_ops.sqrt(variance + self.epsilon)
@@ -265,6 +336,9 @@ class BatchNormalization(base.Layer):
return (r, d, new_mean, new_variance)
def call(self, inputs, training=False):
+ if self.fused:
+ return self._fused_batch_norm(inputs, training=training)
+
# First, compute the axes along which to reduce the mean / variance,
# as well as the broadcast shape to be used for all parameters.
input_shape = inputs.get_shape()
@@ -353,7 +427,8 @@ def batch_normalization(inputs,
reuse=None,
renorm=False,
renorm_clipping=None,
- renorm_momentum=0.99):
+ renorm_momentum=0.99,
+ fused=False):
"""Functional interface for the batch normalization layer.
Reference: http://arxiv.org/abs/1502.03167
@@ -415,6 +490,8 @@ def batch_normalization(inputs,
and should be neither too small (which would add noise) nor too large
(which would give stale estimates). Note that `momentum` is still applied
to get the means and variances for inference.
+ fused: if `True`, use a faster, fused implementation based on
+ nn.fused_batch_norm. If `None`, use the fused implementation if possible.
Returns:
Output tensor.
@@ -431,10 +508,11 @@ def batch_normalization(inputs,
moving_variance_initializer=moving_variance_initializer,
beta_regularizer=beta_regularizer,
gamma_regularizer=gamma_regularizer,
- trainable=trainable,
renorm=renorm,
renorm_clipping=renorm_clipping,
renorm_momentum=renorm_momentum,
+ fused=fused,
+ trainable=trainable,
name=name,
_reuse=reuse,
_scope=name)
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
index 933f196e01..fa6c9c4a5d 100644
--- a/tensorflow/python/layers/normalization_test.py
+++ b/tensorflow/python/layers/normalization_test.py
@@ -262,6 +262,87 @@ class BNTest(test.TestCase):
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+ def test4DInputAxis3Fused(self):
+ epsilon = 1e-3
+ bn = normalization_layers.BatchNormalization(
+ axis=3, epsilon=epsilon, momentum=0.9, fused=True)
+ inputs = variables.Variable(
+ np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
+ training = array_ops.placeholder(dtype='bool')
+ outputs = bn.apply(inputs, training=training)
+
+ with self.test_session() as sess:
+ # Test training with placeholder learning phase.
+ sess.run(variables.global_variables_initializer())
+ np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
+ np_beta = np.reshape(np_beta, (1, 1, 1, 6))
+ for _ in range(100):
+ np_output, _, _ = sess.run(
+ [outputs] + bn.updates, feed_dict={training: True})
+ # Verify that the axis is normalized during training.
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ # Verify that the statistics are updated during training.
+ moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ np_inputs = sess.run(inputs)
+ mean = np.mean(np_inputs, axis=(0, 1, 2))
+ std = np.std(np_inputs, axis=(0, 1, 2))
+ variance = np.square(std)
+ self.assertAllClose(mean, moving_mean, atol=1e-2)
+ self.assertAllClose(variance, moving_var, atol=1e-2)
+
+ # Test inference with placeholder learning phase.
+ np_output = sess.run(outputs, feed_dict={training: False})
+
+ # Verify that the axis is normalized during inference.
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ def test4DInputAxis1Fused(self):
+ if test.is_gpu_available(cuda_only=True):
+ epsilon = 1e-3
+ bn = normalization_layers.BatchNormalization(
+ axis=1, epsilon=epsilon, momentum=0.9, fused=True)
+ inputs = variables.Variable(
+ np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
+ training = array_ops.placeholder(dtype='bool')
+ outputs = bn.apply(inputs, training=training)
+
+ with self.test_session() as sess:
+ # Test training with placeholder learning phase.
+ sess.run(variables.global_variables_initializer())
+ np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
+ np_beta = np.reshape(np_beta, (1, 4, 1, 1))
+ for _ in range(100):
+ np_output, _, _ = sess.run(
+ [outputs] + bn.updates, feed_dict={training: True})
+ # Verify that the axis is normalized during training.
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+ # Verify that the statistics are updated during training.
+ moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ np_inputs = sess.run(inputs)
+ mean = np.mean(np_inputs, axis=(0, 2, 3))
+ std = np.std(np_inputs, axis=(0, 2, 3))
+ variance = np.square(std)
+ self.assertAllClose(mean, moving_mean, atol=1e-2)
+ self.assertAllClose(variance, moving_var, atol=1e-2)
+
+ # Test inference with placeholder learning phase.
+ np_output = sess.run(outputs, feed_dict={training: False})
+
+ # Verify that the axis is normalized during inference.
+ normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+ self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+ self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
def testNegativeAxis(self):
epsilon = 1e-3
bn = normalization_layers.BatchNormalization(
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
index 78b10c44a2..418ca3ea46 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
@@ -14,7 +14,7 @@ tf_module {
}
member_method {
name: "batch_normalization"
- argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\'], "
+ argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'False\'], "
}
member_method {
name: "conv1d"