author | 2017-08-22 21:19:45 -0700
---|---
committer | 2017-08-22 21:23:48 -0700
commit | d2cf393807cb19ed7ed36e9036bf959c7f142090 (patch)
tree | e2bfaa1e502c281cec81b11f539069858d8a142d /tensorflow/contrib/layers
parent | f4d7cddf421a646f6a62efb5da13adf7314cbd98 (diff)
Add an environment variable to test fused batch norm before we enable it by default.
PiperOrigin-RevId: 166155395
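
For illustration, here is a minimal sketch of how a test run might opt in to the new default. The variable name and its accepted values ('true', 't', '1') come from the change below; the input shape and the rest of the setup are illustrative TF 1.x-era usage, not part of the commit.

```python
# Opt in to fused batch norm as the default during the testing period.
# The flag must be set before tensorflow.contrib.layers is imported,
# since _FUSED_DEFAULT is evaluated once at module import time.
import os
os.environ['TF_DEFAULT_USES_FUSED_BATCH_NORM'] = '1'  # 'true' and 't' also work

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import layers as _layers

inputs = tf.placeholder(tf.float32, shape=(5, 3, 3, 7))
# `fused` is left at its new default of None, so it resolves to True here.
outputs = _layers.batch_norm(inputs)
```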
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers.py | 46
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers_test.py | 7
2 files changed, 22 insertions, 31 deletions
```diff
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index dffae60511..c73ec9fd57 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -22,6 +22,7 @@ from __future__ import division
 from __future__ import print_function
 
 import functools
+import os
 import six
 
 from tensorflow.contrib.framework.python.ops import add_arg_scope
@@ -97,6 +98,8 @@ DATA_FORMAT_NCHW = 'NCHW'
 DATA_FORMAT_NHWC = 'NHWC'
 DATA_FORMAT_NCDHW = 'NCDHW'
 DATA_FORMAT_NDHWC = 'NDHWC'
+_FUSED_DEFAULT = os.getenv('TF_DEFAULT_USES_FUSED_BATCH_NORM',
+                           '').lower() in ('true', 't', '1')
 
 
 @add_arg_scope
@@ -448,7 +451,7 @@ def batch_norm(inputs,
                outputs_collections=None,
                trainable=True,
                batch_weights=None,
-               fused=False,
+               fused=None,
                data_format=DATA_FORMAT_NHWC,
                zero_debias_moving_mean=False,
                scope=None,
@@ -518,8 +521,8 @@ def batch_norm(inputs,
       then the batch normalization uses weighted mean and
       variance. (This can be used to correct for bias in training
       example selection.)
-    fused: if `True`, use a faster, fused implementation based on
-      nn.fused_batch_norm. If `None`, use the fused implementation if possible.
+    fused: if `True`, use a faster, fused implementation if possible.
+      If `None`, use the system recommended implementation.
     data_format: A string. `NHWC` (default) and `NCHW` are supported.
     zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
       pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
@@ -542,33 +545,28 @@ def batch_norm(inputs,
     A `Tensor` representing the output of the operation.
 
   Raises:
-    ValueError: If `batch_weights` is not None and `fused` is True.
     ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
     ValueError: If the rank of `inputs` is undefined.
     ValueError: If rank or channels dimension of `inputs` is undefined.
   """
-  if fused:
-    if batch_weights is not None:
-      raise ValueError('Weighted mean and variance is not currently '
-                       'supported for fused batch norm.')
-    if param_regularizers is not None:
-      raise ValueError('Regularizers are not currently '
-                       'supported for fused batch norm.')
-    if renorm:
-      raise ValueError('Renorm is not supported for fused batch norm.')
-
-  # Only use _fused_batch_norm (1) if fused is set True or if it is
-  # possible to use (currently it doesn't support batch weights,
-  # renorm, and the case when rank is neither 2 nor 4),
-  # and (2) if used with zero_debias_moving_mean, or an input shape of rank 2,
-  # or non-default updates_collections (not implemented in
-  # normalization_layers.BatchNormalization yet); otherwise use the fused
-  # implementation in normalization_layers.BatchNormalization.
+  # This environment variable is only used during the testing period of fused
+  # batch norm and will be removed after that.
+  if fused is None:
+    fused = _FUSED_DEFAULT
+
+  # Only use _fused_batch_norm if all of the following three
+  # conditions are true:
+  # (1) fused is set True;
+  # (2) it is possible to use (currently it doesn't support batch weights,
+  #   renorm, and the case when rank is neither 2 nor 4);
+  # (3) it is used with zero_debias_moving_mean, or an input shape of rank 2,
+  #   or non-default updates_collections (not implemented in
+  #   normalization_layers.BatchNormalization yet); otherwise use the fused
+  #   implementation in normalization_layers.BatchNormalization.
   inputs = ops.convert_to_tensor(inputs)
   rank = inputs.get_shape().ndims
-  feature_supported = batch_weights is None and not renorm and rank in [2, 4]
-  possible_to_fuse = fused is None and feature_supported
-  if (fused or possible_to_fuse) and (
+  possible_to_fuse = batch_weights is None and not renorm and rank in [2, 4]
+  if fused and possible_to_fuse and (
       zero_debias_moving_mean or rank == 2 or
       updates_collections is not ops.GraphKeys.UPDATE_OPS):
     return _fused_batch_norm(
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 3b1af57d74..2b12990afd 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1766,13 +1766,6 @@ class BatchNormTest(test.TestCase):
       with self.assertRaisesRegexp(ValueError, 'undefined'):
         _layers.batch_norm(inputs, data_format='NCHW')
 
-  def testWeightedMomentsFused(self):
-    with ops.Graph().as_default() as g, self.test_session(g):
-      inputs = array_ops.placeholder(dtype=dtypes.float32, shape=(5, 3, 3, 7))
-      batch_weights = array_ops.placeholder(dtype=dtypes.float32)
-      with self.assertRaisesRegexp(ValueError, 'Weighted mean and variance'):
-        _layers.batch_norm(inputs, batch_weights=batch_weights, fused=True)
-
   def _testCreateOp(self, fused):
     height, width = 3, 3
     with self.test_session():
```
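
To make the new dispatch rule concrete, the sketch below re-implements the three-condition check in plain Python. Everything here is illustrative: `UPDATE_OPS` stands in for `ops.GraphKeys.UPDATE_OPS`, and `_should_use_fused_batch_norm` is a hypothetical helper, not a function in the commit.

```python
# Stand-in for ops.GraphKeys.UPDATE_OPS; identity comparison mirrors the
# `updates_collections is not ops.GraphKeys.UPDATE_OPS` test in the change.
UPDATE_OPS = 'update_ops'

def _should_use_fused_batch_norm(fused, batch_weights, renorm, rank,
                                 zero_debias_moving_mean,
                                 updates_collections):
  # (2) The fused kernel currently supports neither batch weights nor
  # renorm, and only inputs of rank 2 or 4.
  possible_to_fuse = batch_weights is None and not renorm and rank in [2, 4]
  # (1) `fused` must be True (after the env-var default is resolved), and
  # (3) a feature normalization_layers.BatchNormalization doesn't implement
  # yet must be requested; otherwise the caller falls through to
  # BatchNormalization, which has its own fused implementation.
  return bool(fused) and possible_to_fuse and (
      zero_debias_moving_mean or rank == 2 or
      updates_collections is not UPDATE_OPS)

# With the env var set, `fused` resolves to True, so a rank-2 input takes
# the _fused_batch_norm path even with the default updates collection:
assert _should_use_fused_batch_norm(
    fused=True, batch_weights=None, renorm=False, rank=2,
    zero_debias_moving_mean=False, updates_collections=UPDATE_OPS)
```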