diff options
author | 2016-12-12 16:58:08 -0800 | |
---|---|---|
committer | 2016-12-12 17:05:21 -0800 | |
commit | 5d43d5531f8f1d6ff75b055df2096a4b2a2ae755 (patch) | |
tree | d52b0a2ac2696e83073fd3b632c56948eccd94bd | |
parent | db00a72915eee8d0271ae69748926fd7ede014fe (diff) |
Avoid both conditional branches being computed in Dropout and BatchNormalization.
Change: 141828986
-rw-r--r-- | tensorflow/python/BUILD | 8 | ||||
-rw-r--r-- | tensorflow/python/layers/convolutional.py | 2 | ||||
-rw-r--r-- | tensorflow/python/layers/core.py | 24 | ||||
-rw-r--r-- | tensorflow/python/layers/normalization.py | 141 | ||||
-rw-r--r-- | tensorflow/python/layers/pooling.py | 2 | ||||
-rw-r--r-- | tensorflow/python/layers/utils.py (renamed from tensorflow/python/layers/conv_utils.py) | 62 | ||||
-rw-r--r-- | tensorflow/python/layers/utils_test.py (renamed from tensorflow/python/layers/conv_utils_test.py) | 39 |
7 files changed, 159 insertions, 119 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 55330bee0d..a3e4848a11 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2334,12 +2334,12 @@ py_library( srcs = [ "layers/__init__.py", "layers/base.py", - "layers/conv_utils.py", "layers/convolutional.py", "layers/core.py", "layers/layers.py", "layers/normalization.py", "layers/pooling.py", + "layers/utils.py", ], srcs_version = "PY2AND3", deps = [ @@ -2414,12 +2414,12 @@ py_test( ) py_test( - name = "layers_conv_utils_test", + name = "layers_utils_test", size = "small", srcs = [ - "layers/conv_utils_test.py", + "layers/utils_test.py", ], - main = "layers/conv_utils_test.py", + main = "layers/utils_test.py", srcs_version = "PY2AND3", deps = [ ":client_testlib", diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index 6caa3badfc..eb9f7eec98 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -34,7 +34,7 @@ from tensorflow.python.ops import standard_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.layers import base -from tensorflow.python.layers import conv_utils as utils +from tensorflow.python.layers import utils class _Conv(base._Layer): # pylint: disable=protected-access diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py index 00db1f1714..0a71ae5ea5 100644 --- a/tensorflow/python/layers/core.py +++ b/tensorflow/python/layers/core.py @@ -28,15 +28,14 @@ import numpy as np from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import nn from tensorflow.python.ops import standard_ops from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.ops import control_flow_ops from 
tensorflow.python.layers import base +from tensorflow.python.layers import utils class Dense(base._Layer): # pylint: disable=protected-access @@ -247,20 +246,13 @@ class Dropout(base._Layer): # pylint: disable=protected-access self.seed = seed def call(self, inputs, training=False): - if isinstance(training, bool): - training_bool = training - else: - training_bool = tensor_util.constant_value(training) - if training_bool is False: - return array_ops.identity(inputs) - dropped_inputs = nn.dropout(inputs, 1 - self.rate, - noise_shape=self.noise_shape, - seed=self.seed) - if training_bool is True: - return dropped_inputs - return control_flow_ops.cond(training, - lambda: dropped_inputs, - lambda: inputs) + def dropped_inputs(): + return nn.dropout(inputs, 1 - self.rate, + noise_shape=self.noise_shape, + seed=self.seed) + return utils.smart_cond(training, + dropped_inputs, + lambda: array_ops.identity(inputs)) def dropout(inputs, diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index ea43510f38..1daf765ab9 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn from tensorflow.python.ops import math_ops from tensorflow.python.ops import init_ops @@ -36,8 +35,10 @@ from tensorflow.python.ops import standard_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import moving_averages from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import variables from tensorflow.python.layers import base +from tensorflow.python.layers import utils class BatchNormalization(base._Layer): # pylint: disable=protected-access @@ -162,18 
+163,21 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access # Determines whether broadcasting is needed. needs_broadcasting = (sorted(reduction_axes) != range(ndim)[:-1]) - # Determine boolean training boolean value. May be False, True, None. - # If None, it is assumed that `training` is a variable to be used in `cond`. - if isinstance(training, bool): - training_bool = training - else: - try: - training_bool = tensor_util.constant_value(training) - except TypeError: - training_bool = None + # Determine a boolean value for `training`: could be True, False, or None. + training_value = utils.constant_value(training) - # Obtain current current batch mean, variance, if necessary. - if training_bool is not False: + if needs_broadcasting: + # In this case we must explicitly broadcast all parameters. + if self.center: + broadcast_beta = array_ops.reshape(self.beta, broadcast_shape) + else: + broadcast_beta = None + if self.scale: + broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape) + else: + broadcast_gamma = None + + if training_value is not False: # Use a copy of moving_mean as a shift to compute more reliable moments. shift = math_ops.add(self.moving_mean, 0) if needs_broadcasting: @@ -185,72 +189,59 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access else: mean, variance = nn.moments(inputs, reduction_axes, shift=shift) - # Prepare updates if necessary. - if training_bool is not False and not self.updates: - mean_update = moving_averages.assign_moving_average( - self.moving_mean, mean, self.momentum, zero_debias=False) - variance_update = moving_averages.assign_moving_average( - self.moving_variance, variance, self.momentum, zero_debias=False) - # In the future this should be refactored into a self.add_update - # methods in order to allow for instance-based BN layer sharing - # across unrelated input streams (e.g. like in Keras). 
- self.updates.append(mean_update) - self.updates.append(variance_update) - - # Normalize batch. - if needs_broadcasting: - # In this case we must explictly broadcast all parameters. - broadcast_moving_mean = array_ops.reshape(self.moving_mean, - broadcast_shape) - broadcast_moving_variance = array_ops.reshape(self.moving_variance, - broadcast_shape) - if self.center: - broadcast_beta = array_ops.reshape(self.beta, broadcast_shape) + # Prepare updates if necessary. + if not self.updates: + mean_update = moving_averages.assign_moving_average( + self.moving_mean, mean, self.momentum, zero_debias=False) + variance_update = moving_averages.assign_moving_average( + self.moving_variance, variance, self.momentum, zero_debias=False) + # In the future this should be refactored into a self.add_update + # method in order to allow for instance-based BN layer sharing + # across unrelated input streams (e.g. like in Keras). + self.updates.append(mean_update) + self.updates.append(variance_update) + + # Normalize batch. We do this inside separate functions for training + # and inference so as to avoid evaluating both branches. 
+ def normalize_in_test(): + if needs_broadcasting: + broadcast_moving_mean = array_ops.reshape(self.moving_mean, + broadcast_shape) + broadcast_moving_variance = array_ops.reshape(self.moving_variance, + broadcast_shape) + return nn.batch_normalization(inputs, + broadcast_moving_mean, + broadcast_moving_variance, + broadcast_beta, + broadcast_gamma, + self.epsilon) else: - broadcast_beta = None - if self.scale: - broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape) + return nn.batch_normalization(inputs, + self.moving_mean, + self.moving_variance, + self.beta if self.center else None, + self.gamma if self.scale else None, + self.epsilon) + + def normalize_in_training(): + if needs_broadcasting: + return nn.batch_normalization(inputs, + broadcast_mean, + broadcast_variance, + broadcast_beta, + broadcast_gamma, + self.epsilon) else: - broadcast_gamma = None + return nn.batch_normalization(inputs, + mean, + variance, + self.beta if self.center else None, + self.gamma if self.scale else None, + self.epsilon) - if training_bool is not False: - normed_inputs_training = nn.batch_normalization(inputs, - broadcast_mean, - broadcast_variance, - broadcast_beta, - broadcast_gamma, - self.epsilon) - normed_inputs = nn.batch_normalization(inputs, - broadcast_moving_mean, - broadcast_moving_variance, - broadcast_beta, - broadcast_gamma, - self.epsilon) - else: - # No need for broadcasting. - if training_bool is not False: - normed_inputs_training = nn.batch_normalization( - inputs, - mean, - variance, - self.beta if self.center else None, - self.gamma if self.scale else None, - self.epsilon) - normed_inputs = nn.batch_normalization(inputs, - self.moving_mean, - self.moving_variance, - self.beta if self.center else None, - self.gamma if self.scale else None, - self.epsilon) - - # Return the proper output depending on the boolean training phase. 
- if training_bool is True: - return normed_inputs_training - if training_bool is False: - return normed_inputs - return control_flow_ops.cond(training, - lambda: normed_inputs_training, - lambda: normed_inputs) + return utils.smart_cond(training, + normalize_in_training, + normalize_in_test) def batch_normalization(inputs, diff --git a/tensorflow/python/layers/pooling.py b/tensorflow/python/layers/pooling.py index f6e8ce8a28..2601c61c47 100644 --- a/tensorflow/python/layers/pooling.py +++ b/tensorflow/python/layers/pooling.py @@ -33,7 +33,7 @@ from tensorflow.python.ops import standard_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.layers import base -from tensorflow.python.layers import conv_utils as utils +from tensorflow.python.layers import utils class _Pooling1D(base._Layer): # pylint: disable=protected-access diff --git a/tensorflow/python/layers/conv_utils.py b/tensorflow/python/layers/utils.py index ad6e7d3f32..650e0586c3 100644 --- a/tensorflow/python/layers/conv_utils.py +++ b/tensorflow/python/layers/utils.py @@ -24,6 +24,11 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin import numpy as np +from tensorflow.python.ops import variables +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util + def convert_data_format(data_format, ndim): if data_format == 'channels_last': @@ -103,3 +108,60 @@ def normalize_padding(value): raise ValueError('The `padding` argument must be one of "valid", "same". ' 'Received: ' + str(padding)) return padding + + +def smart_cond(pred, fn1, fn2, name=None): + """Return either `fn1()` or `fn2()` based on the boolean predicate `pred`. + + If `pred` is a bool or has a constant value, we return either `fn1()` + or `fn2()`, otherwise we use `tf.cond` to dynamically route to both. + + Arguments: + pred: A scalar determining whether to return the result of `fn1` or `fn2`. 
+ fn1: The callable to be performed if pred is true. + fn2: The callable to be performed if pred is false. + name: Optional name prefix when using `tf.cond`. + + Returns: + Tensors returned by the call to either `fn1` or `fn2`. + + Raises: + TypeError: if fn1 or fn2 is not callable. + """ + if not callable(fn1): + raise TypeError('`fn1` must be callable.') + if not callable(fn2): + raise TypeError('`fn2` must be callable.') + + pred_value = constant_value(pred) + if pred_value is not None: + if pred_value: + return fn1() + else: + return fn2() + else: + return control_flow_ops.cond(pred, fn1, fn2, name) + + +def constant_value(pred): + """Return the bool value for `pred`, or None if `pred` had a dynamic value. + + Arguments: + pred: A scalar, either a Python bool or a TensorFlow boolean variable + or tensor. + + Returns: + True or False if `pred` has a constant boolean value, None otherwise. + + Raises: + TypeError: if pred is not a Variable, Tensor or bool. + """ + if isinstance(pred, bool): + pred_value = pred + elif isinstance(pred, variables.Variable): + pred_value = None + elif isinstance(pred, ops.Tensor): + pred_value = tensor_util.constant_value(pred) + else: + raise TypeError('`pred` must be a Tensor, a Variable, or a Python bool.') + return pred_value diff --git a/tensorflow/python/layers/conv_utils_test.py b/tensorflow/python/layers/utils_test.py index 431be414c2..d95ea04d42 100644 --- a/tensorflow/python/layers/conv_utils_test.py +++ b/tensorflow/python/layers/utils_test.py @@ -18,53 +18,48 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.layers import conv_utils +from tensorflow.python.layers import utils from tensorflow.python.platform import test class ConvUtilsTest(test.TestCase): def testConvertDataFormat(self): - self.assertEqual( - conv_utils.convert_data_format('channels_first', 4), 'NCHW') - self.assertEqual(conv_utils.convert_data_format('channels_first', 3), 
'NCW') - self.assertEqual(conv_utils.convert_data_format('channels_last', 4), 'NHWC') - self.assertEqual(conv_utils.convert_data_format('channels_last', 3), 'NWC') - self.assertEqual( - conv_utils.convert_data_format('channels_last', 5), 'NDHWC') + self.assertEqual(utils.convert_data_format('channels_first', 4), 'NCHW') + self.assertEqual(utils.convert_data_format('channels_first', 3), 'NCW') + self.assertEqual(utils.convert_data_format('channels_last', 4), 'NHWC') + self.assertEqual(utils.convert_data_format('channels_last', 3), 'NWC') + self.assertEqual(utils.convert_data_format('channels_last', 5), 'NDHWC') with self.assertRaises(ValueError): - conv_utils.convert_data_format('invalid', 2) + utils.convert_data_format('invalid', 2) def testNormalizeTuple(self): + self.assertEqual(utils.normalize_tuple(2, n=3, name='strides'), (2, 2, 2)) self.assertEqual( - conv_utils.normalize_tuple( - 2, n=3, name='strides'), (2, 2, 2)) - self.assertEqual( - conv_utils.normalize_tuple( - (2, 1, 2), n=3, name='strides'), (2, 1, 2)) + utils.normalize_tuple((2, 1, 2), n=3, name='strides'), (2, 1, 2)) with self.assertRaises(ValueError): - conv_utils.normalize_tuple((2, 1), n=3, name='strides') + utils.normalize_tuple((2, 1), n=3, name='strides') with self.assertRaises(ValueError): - conv_utils.normalize_tuple(None, n=3, name='strides') + utils.normalize_tuple(None, n=3, name='strides') def testNormalizeDataFormat(self): self.assertEqual( - conv_utils.normalize_data_format('Channels_Last'), 'channels_last') + utils.normalize_data_format('Channels_Last'), 'channels_last') self.assertEqual( - conv_utils.normalize_data_format('CHANNELS_FIRST'), 'channels_first') + utils.normalize_data_format('CHANNELS_FIRST'), 'channels_first') with self.assertRaises(ValueError): - conv_utils.normalize_data_format('invalid') + utils.normalize_data_format('invalid') def testNormalizePadding(self): - self.assertEqual(conv_utils.normalize_padding('SAME'), 'same') - 
self.assertEqual(conv_utils.normalize_padding('VALID'), 'valid') + self.assertEqual(utils.normalize_padding('SAME'), 'same') + self.assertEqual(utils.normalize_padding('VALID'), 'valid') with self.assertRaises(ValueError): - conv_utils.normalize_padding('invalid') + utils.normalize_padding('invalid') if __name__ == '__main__': |