aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/layers
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2017-08-22 21:19:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-22 21:23:48 -0700
commitd2cf393807cb19ed7ed36e9036bf959c7f142090 (patch)
treee2bfaa1e502c281cec81b11f539069858d8a142d /tensorflow/contrib/layers
parentf4d7cddf421a646f6a62efb5da13adf7314cbd98 (diff)
Add an environment variable to test fused batch norm before we enable it by default.
PiperOrigin-RevId: 166155395
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py46
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py7
2 files changed, 22 insertions, 31 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index dffae60511..c73ec9fd57 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function
import functools
+import os
import six
from tensorflow.contrib.framework.python.ops import add_arg_scope
@@ -97,6 +98,8 @@ DATA_FORMAT_NCHW = 'NCHW'
DATA_FORMAT_NHWC = 'NHWC'
DATA_FORMAT_NCDHW = 'NCDHW'
DATA_FORMAT_NDHWC = 'NDHWC'
+_FUSED_DEFAULT = os.getenv('TF_DEFAULT_USES_FUSED_BATCH_NORM',
+ '').lower() in ('true', 't', '1')
@add_arg_scope
@@ -448,7 +451,7 @@ def batch_norm(inputs,
outputs_collections=None,
trainable=True,
batch_weights=None,
- fused=False,
+ fused=None,
data_format=DATA_FORMAT_NHWC,
zero_debias_moving_mean=False,
scope=None,
@@ -518,8 +521,8 @@ def batch_norm(inputs,
then the batch normalization uses weighted mean and
variance. (This can be used to correct for bias in training
example selection.)
- fused: if `True`, use a faster, fused implementation based on
- nn.fused_batch_norm. If `None`, use the fused implementation if possible.
+ fused: if `True`, use a faster, fused implementation if possible.
+ If `None`, use the system recommended implementation.
data_format: A string. `NHWC` (default) and `NCHW` are supported.
zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
@@ -542,33 +545,28 @@ def batch_norm(inputs,
A `Tensor` representing the output of the operation.
Raises:
- ValueError: If `batch_weights` is not None and `fused` is True.
ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
ValueError: If the rank of `inputs` is undefined.
ValueError: If rank or channels dimension of `inputs` is undefined.
"""
- if fused:
- if batch_weights is not None:
- raise ValueError('Weighted mean and variance is not currently '
- 'supported for fused batch norm.')
- if param_regularizers is not None:
- raise ValueError('Regularizers are not currently '
- 'supported for fused batch norm.')
- if renorm:
- raise ValueError('Renorm is not supported for fused batch norm.')
-
- # Only use _fused_batch_norm (1) if fused is set True or if it is
- # possible to use (currently it doesn't support batch weights,
- # renorm, and the case when rank is neither 2 nor 4),
- # and (2) if used with zero_debias_moving_mean, or an input shape of rank 2,
- # or non-default updates_collections (not implemented in
- # normalization_layers.BatchNormalization yet); otherwise use the fused
- # implementation in normalization_layers.BatchNormalization.
+ # This environment variable is only used during the testing period of fused
+ # batch norm and will be removed after that.
+ if fused is None:
+ fused = _FUSED_DEFAULT
+
+ # Only use _fused_batch_norm if all of the following three
+ # conditions are true:
+ # (1) fused is set True;
+ # (2) it is possible to use (currently it doesn't support batch weights,
+ # renorm, and the case when rank is neither 2 nor 4);
+ # (3) it is used with zero_debias_moving_mean, or an input shape of rank 2,
+ # or non-default updates_collections (not implemented in
+ # normalization_layers.BatchNormalization yet); otherwise use the fused
+ # implementation in normalization_layers.BatchNormalization.
inputs = ops.convert_to_tensor(inputs)
rank = inputs.get_shape().ndims
- feature_supported = batch_weights is None and not renorm and rank in [2, 4]
- possible_to_fuse = fused is None and feature_supported
- if (fused or possible_to_fuse) and (
+ possible_to_fuse = batch_weights is None and not renorm and rank in [2, 4]
+ if fused and possible_to_fuse and (
zero_debias_moving_mean or rank == 2 or
updates_collections is not ops.GraphKeys.UPDATE_OPS):
return _fused_batch_norm(
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 3b1af57d74..2b12990afd 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1766,13 +1766,6 @@ class BatchNormTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'undefined'):
_layers.batch_norm(inputs, data_format='NCHW')
- def testWeightedMomentsFused(self):
- with ops.Graph().as_default() as g, self.test_session(g):
- inputs = array_ops.placeholder(dtype=dtypes.float32, shape=(5, 3, 3, 7))
- batch_weights = array_ops.placeholder(dtype=dtypes.float32)
- with self.assertRaisesRegexp(ValueError, 'Weighted mean and variance'):
- _layers.batch_norm(inputs, batch_weights=batch_weights, fused=True)
-
def _testCreateOp(self, fused):
height, width = 3, 3
with self.test_session():