author     2017-08-22 21:19:45 -0700
committer  2017-08-22 21:23:48 -0700
commit     d2cf393807cb19ed7ed36e9036bf959c7f142090
tree       e2bfaa1e502c281cec81b11f539069858d8a142d /tensorflow/python/layers/normalization.py
parent     f4d7cddf421a646f6a62efb5da13adf7314cbd98
Add an environment variable to test fused batch norm before we enable it by default.
PiperOrigin-RevId: 166155395
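For anyone trying the flag out, here is a minimal opt-in sketch. The TF 1.x-era `tf.layers` API is assumed and the input shape is illustrative; the variable name and the accepted truthy values are taken from the diff below.

    import os

    # _FUSED_DEFAULT is evaluated when the layers module is imported, so the
    # variable must be set before importing TensorFlow. Accepted values per
    # this change: 'true', 't', '1' (case-insensitive).
    os.environ['TF_DEFAULT_USES_FUSED_BATCH_NORM'] = '1'

    import tensorflow as tf

    # With the new default fused=None, the layer resolves `fused` from the
    # environment variable instead of hard-coding False.
    inputs = tf.placeholder(tf.float32, [None, 32, 32, 3])  # rank 4, NHWC
    outputs = tf.layers.batch_normalization(inputs, training=True)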
Diffstat (limited to 'tensorflow/python/layers/normalization.py')
-rw-r--r--  tensorflow/python/layers/normalization.py | 45
1 file changed, 19 insertions(+), 26 deletions(-)
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 151dad9524..d0aa018929 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -20,6 +20,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import numpy as np
@@ -40,6 +41,9 @@ from tensorflow.python.ops import variables
 from tensorflow.python.layers import base
 from tensorflow.python.layers import utils
 
+_FUSED_DEFAULT = os.getenv('TF_DEFAULT_USES_FUSED_BATCH_NORM',
+                           '').lower() in ('true', 't', '1')
+
 
 class BatchNormalization(base.Layer):
   """Batch Normalization layer from http://arxiv.org/abs/1502.03167.
@@ -87,8 +91,8 @@ class BatchNormalization(base.Layer):
       and should be neither too small (which would add noise) nor too large
       (which would give stale estimates). Note that `momentum` is still applied
       to get the means and variances for inference.
-    fused: if `True`, use a faster, fused implementation based on
-      nn.fused_batch_norm. If `None`, use the fused implementation if possible.
+    fused: if `True`, use a faster, fused implementation if possible.
+      If `None`, use the system recommended implementation.
     trainable: Boolean, if `True` also add variables to the graph collection
       `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
     name: A string, the name of the layer.
@@ -111,7 +115,7 @@ class BatchNormalization(base.Layer):
                renorm=False,
                renorm_clipping=None,
                renorm_momentum=0.99,
-               fused=False,
+               fused=None,
                trainable=True,
                name=None,
                **kwargs):
@@ -131,15 +135,13 @@ class BatchNormalization(base.Layer):
     self.beta_constraint = beta_constraint
     self.gamma_constraint = gamma_constraint
     self.renorm = renorm
+    # This environment variable is only used during the testing period of fused
+    # batch norm and will be removed after that.
+    if fused is None:
+      fused = _FUSED_DEFAULT
+
     self.fused = fused
     self._bessels_correction_test_only = True
-    if self.fused and renorm:
-      raise ValueError(
-          'Batch renorm is currently not supported with fused batch norm.')
-    if self.fused and (beta_regularizer is not None or
-                       gamma_regularizer is not None):
-      raise ValueError('Regularizers are not currently '
-                       'supported for fused batch norm.')
     if renorm:
       renorm_clipping = renorm_clipping or {}
       keys = ['rmax', 'rmin', 'dmax']
@@ -154,13 +156,6 @@ class BatchNormalization(base.Layer):
     if not input_shape.ndims:
       raise ValueError('Input has undefined rank:', input_shape)
     ndim = len(input_shape)
-    # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
-    # output back to its original shape accordingly.
-    if self.fused and ndim != 4:
-      raise ValueError(
-          'Only 4D inputs are currently supported with fused batch norm. '
-          'Consider reshaping the input to 4D and reshape the output back '
-          'to its original shape. Got input rank: ', ndim)
     if self.axis < 0:
       axis = ndim + self.axis
     else:
@@ -169,10 +164,12 @@ class BatchNormalization(base.Layer):
       raise ValueError('Value of `axis` argument ' + str(self.axis) +
                        ' is out of range for input with rank ' + str(ndim))
 
-    if self.fused is None:
+    if self.fused:
       # Currently fused batch norm doesn't support renorm and beta/gamma
       # regularizer; and only supports an input tensor of rank 4 and a channel
       # dimension on axis 1 and 3.
+      # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
+      # output back to its original shape accordingly.
       self.fused = not self.renorm and ndim == 4 and axis in [
           1, 3
       ] and self.beta_regularizer is None and self.gamma_regularizer is None
@@ -180,12 +177,8 @@ class BatchNormalization(base.Layer):
     if self.fused:
       if axis == 1:
         self._data_format = 'NCHW'
-      elif axis == 3:
-        self._data_format = 'NHWC'
       else:
-        raise ValueError(
-            'Only axis 1 and 3 are currently supported dimensions for '
-            'fused batch norm. Got `axis` dimension: ', axis)
+        self._data_format = 'NHWC'
 
     param_dim = input_shape[axis]
     if not param_dim.value:
@@ -462,7 +455,7 @@ def batch_normalization(inputs,
                         renorm=False,
                         renorm_clipping=None,
                         renorm_momentum=0.99,
-                        fused=False):
+                        fused=None):
   """Functional interface for the batch normalization layer.
 
   Reference: http://arxiv.org/abs/1502.03167
@@ -532,8 +525,8 @@ def batch_normalization(inputs,
       and should be neither too small (which would add noise) nor too large
       (which would give stale estimates). Note that `momentum` is still applied
       to get the means and variances for inference.
-    fused: if `True`, use a faster, fused implementation based on
-      nn.fused_batch_norm. If `None`, use the fused implementation if possible.
+    fused: if `True`, use a faster, fused implementation if possible.
+      If `None`, use the system recommended implementation.
 
   Returns:
     Output tensor.
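One consequence of removing the explicit checks, visible in the hunks above: requesting `fused=True` in an unsupported configuration no longer raises; `build()` now recomputes `self.fused` and silently falls back to the non-fused implementation. A sketch under the same TF 1.x assumption, with a hypothetical rank-2 input:

    import tensorflow as tf

    # Fused batch norm only supports rank-4 inputs with the channel dimension
    # on axis 1 or 3; this input is rank 2.
    dense = tf.placeholder(tf.float32, [None, 128])

    # Before this change: ValueError ('Only 4D inputs are currently supported
    # with fused batch norm. ...').
    # After this change: build() sets self.fused back to False and the
    # non-fused implementation is used without warning.
    outputs = tf.layers.batch_normalization(dense, fused=True)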