author    Yao Zhang <yaozhang@google.com>                   2017-08-22 21:19:45 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-08-22 21:23:48 -0700
commit d2cf393807cb19ed7ed36e9036bf959c7f142090 (patch)
tree   e2bfaa1e502c281cec81b11f539069858d8a142d /tensorflow/python/layers/normalization.py
parent f4d7cddf421a646f6a62efb5da13adf7314cbd98 (diff)
Add an environment variable to test fused batch norm before we enable it by default.
PiperOrigin-RevId: 166155395
Diffstat (limited to 'tensorflow/python/layers/normalization.py')
-rw-r--r--  tensorflow/python/layers/normalization.py | 45
1 file changed, 19 insertions(+), 26 deletions(-)
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 151dad9524..d0aa018929 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -20,6 +20,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
import six
from six.moves import xrange # pylint: disable=redefined-builtin
import numpy as np
@@ -40,6 +41,9 @@ from tensorflow.python.ops import variables
from tensorflow.python.layers import base
from tensorflow.python.layers import utils
+_FUSED_DEFAULT = os.getenv('TF_DEFAULT_USES_FUSED_BATCH_NORM',
+ '').lower() in ('true', 't', '1')
+
class BatchNormalization(base.Layer):
"""Batch Normalization layer from http://arxiv.org/abs/1502.03167.
@@ -87,8 +91,8 @@ class BatchNormalization(base.Layer):
and should be neither too small (which would add noise) nor too large
(which would give stale estimates). Note that `momentum` is still applied
to get the means and variances for inference.
- fused: if `True`, use a faster, fused implementation based on
- nn.fused_batch_norm. If `None`, use the fused implementation if possible.
+ fused: if `True`, use a faster, fused implementation if possible.
+      If `None`, use the system-recommended implementation.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
name: A string, the name of the layer.
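With fused=None as the new default, eligibility for the fused kernel is decided later, once the layer sees its input shape. A construction sketch (the axis value here is a placeholder, not from the commit):

    from tensorflow.python.layers.normalization import BatchNormalization

    # fused=None: defer to the environment default now, and to the
    # input-shape eligibility check in build() below.
    bn = BatchNormalization(axis=3, fused=None)
    # fused=True / fused=False still force the choice explicitly.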
@@ -111,7 +115,7 @@ class BatchNormalization(base.Layer):
renorm=False,
renorm_clipping=None,
renorm_momentum=0.99,
- fused=False,
+ fused=None,
trainable=True,
name=None,
**kwargs):
@@ -131,15 +135,13 @@ class BatchNormalization(base.Layer):
self.beta_constraint = beta_constraint
self.gamma_constraint = gamma_constraint
self.renorm = renorm
+ # This environment variable is only used during the testing period of fused
+ # batch norm and will be removed after that.
+ if fused is None:
+ fused = _FUSED_DEFAULT
+
self.fused = fused
self._bessels_correction_test_only = True
- if self.fused and renorm:
- raise ValueError(
- 'Batch renorm is currently not supported with fused batch norm.')
- if self.fused and (beta_regularizer is not None or
- gamma_regularizer is not None):
- raise ValueError('Regularizers are not currently '
- 'supported for fused batch norm.')
if renorm:
renorm_clipping = renorm_clipping or {}
keys = ['rmax', 'rmin', 'dmax']
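After this change the constructor no longer raises for unsupported fused configurations; it only resolves None against the environment default and defers validation to build(). Schematically (a sketch; resolve_fused is a hypothetical helper, not in the module):

    def resolve_fused(fused, env_default):
        # None means "no preference": fall back to the testing-period
        # environment default (_FUSED_DEFAULT, read once at import).
        if fused is None:
            return env_default
        # Explicit True/False is kept as-is; the renorm/regularizer
        # checks that formerly raised here now happen in build().
        return fused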
@@ -154,13 +156,6 @@ class BatchNormalization(base.Layer):
if not input_shape.ndims:
raise ValueError('Input has undefined rank:', input_shape)
ndim = len(input_shape)
- # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
- # output back to its original shape accordingly.
- if self.fused and ndim != 4:
- raise ValueError(
- 'Only 4D inputs are currently supported with fused batch norm. '
- 'Consider reshaping the input to 4D and reshape the output back '
- 'to its original shape. Got input rank: ', ndim)
if self.axis < 0:
axis = ndim + self.axis
else:
@@ -169,10 +164,12 @@ class BatchNormalization(base.Layer):
raise ValueError('Value of `axis` argument ' + str(self.axis) +
' is out of range for input with rank ' + str(ndim))
- if self.fused is None:
+ if self.fused:
# Currently fused batch norm doesn't support renorm and beta/gamma
# regularizer; and only supports an input tensor of rank 4 and a channel
# dimension on axis 1 and 3.
+ # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
+ # output back to its original shape accordingly.
self.fused = not self.renorm and ndim == 4 and axis in [
1, 3
] and self.beta_regularizer is None and self.gamma_regularizer is None
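Read as a standalone predicate, the auto-selection above amounts to the following (a sketch only; can_use_fused does not exist in the module):

    def can_use_fused(renorm, ndim, axis,
                      beta_regularizer, gamma_regularizer):
        # Fused batch norm currently needs a rank-4 input, a channel axis
        # of 1 (NCHW) or 3 (NHWC), no renorm, and no beta/gamma
        # regularizers.
        return (not renorm and ndim == 4 and axis in (1, 3) and
                beta_regularizer is None and gamma_regularizer is None)

Note that an explicit fused=True can now be silently demoted to False by this expression instead of raising, since the constructor-time errors were removed.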
@@ -180,12 +177,8 @@ class BatchNormalization(base.Layer):
if self.fused:
if axis == 1:
self._data_format = 'NCHW'
- elif axis == 3:
- self._data_format = 'NHWC'
else:
- raise ValueError(
- 'Only axis 1 and 3 are currently supported dimensions for '
- 'fused batch norm. Got `axis` dimension: ', axis)
+ self._data_format = 'NHWC'
param_dim = input_shape[axis]
if not param_dim.value:
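By the time this branch runs, the eligibility check has already restricted axis to 1 or 3, so the former ValueError for other axes is unreachable and the elif collapses to an else. Equivalently (fused_data_format is an illustrative helper, not part of the module):

    def fused_data_format(axis):
        # axis 1 -> channels-first, axis 3 -> channels-last; no other
        # value can reach this point after the eligibility check above.
        return 'NCHW' if axis == 1 else 'NHWC'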
@@ -462,7 +455,7 @@ def batch_normalization(inputs,
renorm=False,
renorm_clipping=None,
renorm_momentum=0.99,
- fused=False):
+ fused=None):
"""Functional interface for the batch normalization layer.
Reference: http://arxiv.org/abs/1502.03167
@@ -532,8 +525,8 @@ def batch_normalization(inputs,
and should be neither too small (which would add noise) nor too large
(which would give stale estimates). Note that `momentum` is still applied
to get the means and variances for inference.
- fused: if `True`, use a faster, fused implementation based on
- nn.fused_batch_norm. If `None`, use the fused implementation if possible.
+ fused: if `True`, use a faster, fused implementation if possible.
+      If `None`, use the system-recommended implementation.
Returns:
Output tensor.
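A usage sketch of the functional interface with the new default (TF 1.x style; the placeholder shape is illustrative):

    import tensorflow as tf

    # NHWC, rank-4 input: eligible for the fused kernel on axis 3.
    inputs = tf.placeholder(tf.float32, [None, 32, 32, 64])
    # fused=None now defers to TF_DEFAULT_USES_FUSED_BATCH_NORM during
    # the testing period; pass fused=True or False to force a choice.
    outputs = tf.layers.batch_normalization(
        inputs, axis=3, training=True, fused=None)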