author    Francois Chollet <fchollet@google.com>    2016-12-12 16:58:08 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2016-12-12 17:05:21 -0800
commit    5d43d5531f8f1d6ff75b055df2096a4b2a2ae755 (patch)
tree      d52b0a2ac2696e83073fd3b632c56948eccd94bd
parent    db00a72915eee8d0271ae69748926fd7ede014fe (diff)
Avoid both conditional branches being computed in Dropout and BatchNormalization.
Change: 141828986
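
This fixes a subtle graph-construction issue: previously, Dropout and BatchNormalization built the tensors for both the training and inference branches up front and then selected between them, so the unused branch's ops were still created and, because they lived outside the `cond` context, still computed. The fix wraps each branch in a closure and dispatches through a new `utils.smart_cond` helper. Below is a minimal, framework-free Python sketch of the semantics; the name mirrors the new helper, but this is illustrative only, not the TensorFlow implementation:

def smart_cond(pred, fn1, fn2):
    # With a static Python bool, exactly one branch callable is invoked;
    # the other branch's ops are never created at all.
    if isinstance(pred, bool):
        return fn1() if pred else fn2()
    # The real helper falls back to tf.cond for dynamic predicates.
    raise NotImplementedError('dynamic predicates require tf.cond')

calls = []

def train_branch():
    calls.append('train')
    return 'dropped inputs'

def test_branch():
    calls.append('test')
    return 'inputs'

assert smart_cond(False, train_branch, test_branch) == 'inputs'
assert calls == ['test']  # the training branch was never computed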
-rw-r--r--  tensorflow/python/BUILD                                                               |   8
-rw-r--r--  tensorflow/python/layers/convolutional.py                                             |   2
-rw-r--r--  tensorflow/python/layers/core.py                                                      |  24
-rw-r--r--  tensorflow/python/layers/normalization.py                                             | 141
-rw-r--r--  tensorflow/python/layers/pooling.py                                                   |   2
-rw-r--r--  tensorflow/python/layers/utils.py (renamed from tensorflow/python/layers/conv_utils.py)           |  62
-rw-r--r--  tensorflow/python/layers/utils_test.py (renamed from tensorflow/python/layers/conv_utils_test.py) |  39
7 files changed, 159 insertions(+), 119 deletions(-)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 55330bee0d..a3e4848a11 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2334,12 +2334,12 @@ py_library(
srcs = [
"layers/__init__.py",
"layers/base.py",
- "layers/conv_utils.py",
"layers/convolutional.py",
"layers/core.py",
"layers/layers.py",
"layers/normalization.py",
"layers/pooling.py",
+ "layers/utils.py",
],
srcs_version = "PY2AND3",
deps = [
@@ -2414,12 +2414,12 @@ py_test(
)
py_test(
- name = "layers_conv_utils_test",
+ name = "layers_utils_test",
size = "small",
srcs = [
- "layers/conv_utils_test.py",
+ "layers/utils_test.py",
],
- main = "layers/conv_utils_test.py",
+ main = "layers/utils_test.py",
srcs_version = "PY2AND3",
deps = [
":client_testlib",
diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py
index 6caa3badfc..eb9f7eec98 100644
--- a/tensorflow/python/layers/convolutional.py
+++ b/tensorflow/python/layers/convolutional.py
@@ -34,7 +34,7 @@ from tensorflow.python.ops import standard_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.layers import base
-from tensorflow.python.layers import conv_utils as utils
+from tensorflow.python.layers import utils
class _Conv(base._Layer): # pylint: disable=protected-access
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index 00db1f1714..0a71ae5ea5 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -28,15 +28,14 @@ import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import standard_ops
from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.layers import base
+from tensorflow.python.layers import utils
class Dense(base._Layer): # pylint: disable=protected-access
@@ -247,20 +246,13 @@ class Dropout(base._Layer): # pylint: disable=protected-access
self.seed = seed
def call(self, inputs, training=False):
- if isinstance(training, bool):
- training_bool = training
- else:
- training_bool = tensor_util.constant_value(training)
- if training_bool is False:
- return array_ops.identity(inputs)
- dropped_inputs = nn.dropout(inputs, 1 - self.rate,
- noise_shape=self.noise_shape,
- seed=self.seed)
- if training_bool is True:
- return dropped_inputs
- return control_flow_ops.cond(training,
- lambda: dropped_inputs,
- lambda: inputs)
+ def dropped_inputs():
+ return nn.dropout(inputs, 1 - self.rate,
+ noise_shape=self.noise_shape,
+ seed=self.seed)
+ return utils.smart_cond(training,
+ dropped_inputs,
+ lambda: array_ops.identity(inputs))
def dropout(inputs,
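
As a sketch of the new Dropout.call control flow, here is a NumPy analogue (not TensorFlow; the rate/seed handling only approximates nn.dropout): the dropout computation lives inside a closure, so a static training=False never draws the random mask at all.

import numpy as np

def dropout_call(inputs, rate=0.5, training=False, seed=None):
    def dropped_inputs():
        # Only executed on the training branch.
        rng = np.random.RandomState(seed)
        keep_prob = 1.0 - rate
        mask = rng.binomial(1, keep_prob, size=inputs.shape)
        return inputs * mask / keep_prob
    # Static bool: pick exactly one branch, mirroring utils.smart_cond.
    return dropped_inputs() if training else inputs

x = np.ones((2, 3))
np.testing.assert_array_equal(dropout_call(x, training=False), x)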
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index ea43510f38..1daf765ab9 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import init_ops
@@ -36,8 +35,10 @@ from tensorflow.python.ops import standard_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import moving_averages
from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import variables
from tensorflow.python.layers import base
+from tensorflow.python.layers import utils
class BatchNormalization(base._Layer): # pylint: disable=protected-access
@@ -162,18 +163,21 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access
# Determines whether broadcasting is needed.
needs_broadcasting = (sorted(reduction_axes) != range(ndim)[:-1])
- # Determine boolean training boolean value. May be False, True, None.
- # If None, it is assumed that `training` is a variable to be used in `cond`.
- if isinstance(training, bool):
- training_bool = training
- else:
- try:
- training_bool = tensor_util.constant_value(training)
- except TypeError:
- training_bool = None
+ # Determine a boolean value for `training`: could be True, False, or None.
+ training_value = utils.constant_value(training)
- # Obtain current current batch mean, variance, if necessary.
- if training_bool is not False:
+ if needs_broadcasting:
+ # In this case we must explicitly broadcast all parameters.
+ if self.center:
+ broadcast_beta = array_ops.reshape(self.beta, broadcast_shape)
+ else:
+ broadcast_beta = None
+ if self.scale:
+ broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape)
+ else:
+ broadcast_gamma = None
+
+ if training_value is not False:
# Use a copy of moving_mean as a shift to compute more reliable moments.
shift = math_ops.add(self.moving_mean, 0)
if needs_broadcasting:
@@ -185,72 +189,59 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access
else:
mean, variance = nn.moments(inputs, reduction_axes, shift=shift)
- # Prepare updates if necessary.
- if training_bool is not False and not self.updates:
- mean_update = moving_averages.assign_moving_average(
- self.moving_mean, mean, self.momentum, zero_debias=False)
- variance_update = moving_averages.assign_moving_average(
- self.moving_variance, variance, self.momentum, zero_debias=False)
- # In the future this should be refactored into a self.add_update
- # methods in order to allow for instance-based BN layer sharing
- # across unrelated input streams (e.g. like in Keras).
- self.updates.append(mean_update)
- self.updates.append(variance_update)
-
- # Normalize batch.
- if needs_broadcasting:
- # In this case we must explictly broadcast all parameters.
- broadcast_moving_mean = array_ops.reshape(self.moving_mean,
- broadcast_shape)
- broadcast_moving_variance = array_ops.reshape(self.moving_variance,
- broadcast_shape)
- if self.center:
- broadcast_beta = array_ops.reshape(self.beta, broadcast_shape)
+ # Prepare updates if necessary.
+ if not self.updates:
+ mean_update = moving_averages.assign_moving_average(
+ self.moving_mean, mean, self.momentum, zero_debias=False)
+ variance_update = moving_averages.assign_moving_average(
+ self.moving_variance, variance, self.momentum, zero_debias=False)
+ # In the future this should be refactored into a self.add_update
+ # method in order to allow for instance-based BN layer sharing
+ # across unrelated input streams (e.g. like in Keras).
+ self.updates.append(mean_update)
+ self.updates.append(variance_update)
+
+ # Normalize batch. We do this inside separate functions for training
+ # and inference so as to avoid evaluating both branches.
+ def normalize_in_test():
+ if needs_broadcasting:
+ broadcast_moving_mean = array_ops.reshape(self.moving_mean,
+ broadcast_shape)
+ broadcast_moving_variance = array_ops.reshape(self.moving_variance,
+ broadcast_shape)
+ return nn.batch_normalization(inputs,
+ broadcast_moving_mean,
+ broadcast_moving_variance,
+ broadcast_beta,
+ broadcast_gamma,
+ self.epsilon)
else:
- broadcast_beta = None
- if self.scale:
- broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape)
+ return nn.batch_normalization(inputs,
+ self.moving_mean,
+ self.moving_variance,
+ self.beta if self.center else None,
+ self.gamma if self.scale else None,
+ self.epsilon)
+
+ def normalize_in_training():
+ if needs_broadcasting:
+ return nn.batch_normalization(inputs,
+ broadcast_mean,
+ broadcast_variance,
+ broadcast_beta,
+ broadcast_gamma,
+ self.epsilon)
else:
- broadcast_gamma = None
+ return nn.batch_normalization(inputs,
+ mean,
+ variance,
+ self.beta if self.center else None,
+ self.gamma if self.scale else None,
+ self.epsilon)
- if training_bool is not False:
- normed_inputs_training = nn.batch_normalization(inputs,
- broadcast_mean,
- broadcast_variance,
- broadcast_beta,
- broadcast_gamma,
- self.epsilon)
- normed_inputs = nn.batch_normalization(inputs,
- broadcast_moving_mean,
- broadcast_moving_variance,
- broadcast_beta,
- broadcast_gamma,
- self.epsilon)
- else:
- # No need for broadcasting.
- if training_bool is not False:
- normed_inputs_training = nn.batch_normalization(
- inputs,
- mean,
- variance,
- self.beta if self.center else None,
- self.gamma if self.scale else None,
- self.epsilon)
- normed_inputs = nn.batch_normalization(inputs,
- self.moving_mean,
- self.moving_variance,
- self.beta if self.center else None,
- self.gamma if self.scale else None,
- self.epsilon)
-
- # Return the proper output depending on the boolean training phase.
- if training_bool is True:
- return normed_inputs_training
- if training_bool is False:
- return normed_inputs
- return control_flow_ops.cond(training,
- lambda: normed_inputs_training,
- lambda: normed_inputs)
+ return utils.smart_cond(training,
+ normalize_in_training,
+ normalize_in_test)
def batch_normalization(inputs,
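
A NumPy analogue of the refactored BatchNormalization.call (beta/gamma and the moving-average updates are omitted for brevity): each normalization path is a closure, and with a static training bool only the selected path's computation ever runs.

import numpy as np

def batch_norm_call(x, moving_mean, moving_variance, training,
                    epsilon=1e-3):
    def normalize_in_training():
        # Batch statistics are only computed on this branch.
        mean, variance = x.mean(axis=0), x.var(axis=0)
        return (x - mean) / np.sqrt(variance + epsilon)

    def normalize_in_test():
        # Inference uses the stored moving averages instead.
        return (x - moving_mean) / np.sqrt(moving_variance + epsilon)

    return normalize_in_training() if training else normalize_in_test()

x = np.random.randn(8, 3)
out = batch_norm_call(x, np.zeros(3), np.ones(3), training=True)
assert np.allclose(out.mean(axis=0), 0.0, atol=1e-6)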
diff --git a/tensorflow/python/layers/pooling.py b/tensorflow/python/layers/pooling.py
index f6e8ce8a28..2601c61c47 100644
--- a/tensorflow/python/layers/pooling.py
+++ b/tensorflow/python/layers/pooling.py
@@ -33,7 +33,7 @@ from tensorflow.python.ops import standard_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.layers import base
-from tensorflow.python.layers import conv_utils as utils
+from tensorflow.python.layers import utils
class _Pooling1D(base._Layer): # pylint: disable=protected-access
diff --git a/tensorflow/python/layers/conv_utils.py b/tensorflow/python/layers/utils.py
index ad6e7d3f32..650e0586c3 100644
--- a/tensorflow/python/layers/conv_utils.py
+++ b/tensorflow/python/layers/utils.py
@@ -24,6 +24,11 @@ import six
from six.moves import xrange # pylint: disable=redefined-builtin
import numpy as np
+from tensorflow.python.ops import variables
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+
def convert_data_format(data_format, ndim):
if data_format == 'channels_last':
@@ -103,3 +108,60 @@ def normalize_padding(value):
raise ValueError('The `padding` argument must be one of "valid", "same". '
'Received: ' + str(padding))
return padding
+
+
+def smart_cond(pred, fn1, fn2, name=None):
+ """Return either `fn1()` or `fn2()` based on the boolean predicate `pred`.
+
+ If `pred` is a bool or has a constant value, we return either `fn1()`
+ or `fn2()` directly; otherwise we fall back to `tf.cond` to select the
+ branch dynamically.
+
+ Arguments:
+ pred: A scalar determining whether to return the result of `fn1` or `fn2`.
+ fn1: The callable to be performed if pred is true.
+ fn2: The callable to be performed if pred is false.
+ name: Optional name prefix when using `tf.cond`.
+
+ Returns:
+ Tensors returned by the call to either `fn1` or `fn2`.
+
+ Raises:
+ TypeError: if `fn1` or `fn2` is not callable.
+ """
+ if not callable(fn1):
+ raise TypeError('`fn1` must be callable.')
+ if not callable(fn2):
+ raise TypeError('`fn2` must be callable.')
+
+ pred_value = constant_value(pred)
+ if pred_value is not None:
+ if pred_value:
+ return fn1()
+ else:
+ return fn2()
+ else:
+ return control_flow_ops.cond(pred, fn1, fn2, name)
+
+
+def constant_value(pred):
+ """Return the bool value for `pred`, or None if `pred` had a dynamic value.
+
+ Arguments:
+ pred: A scalar, either a Python bool or a TensorFlow boolean variable
+ or tensor.
+
+ Returns:
+ True or False if `pred` has a constant boolean value, None otherwise.
+
+ Raises:
+ TypeError: if `pred` is not a `Variable`, `Tensor`, or bool.
+ """
+ if isinstance(pred, bool):
+ pred_value = pred
+ elif isinstance(pred, variables.Variable):
+ pred_value = None
+ elif isinstance(pred, ops.Tensor):
+ pred_value = tensor_util.constant_value(pred)
+ else:
+ raise TypeError('`pred` must be a Tensor, a Variable, or a Python bool.')
+ return pred_value
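
Hypothetical usage of the two new helpers, assuming the TF 1.x graph API contemporary with this commit (tf.placeholder, graph sessions):

import tensorflow as tf
from tensorflow.python.layers import utils

# A Python bool resolves statically: smart_cond calls one branch
# directly and creates no cond ops.
assert utils.constant_value(True) is True

# A constant tensor folds at graph-construction time; tensor_util
# returns a NumPy scalar here, hence == rather than `is`.
assert utils.constant_value(tf.constant(False)) == False

# A placeholder has no static value: constant_value returns None and
# smart_cond falls back to control_flow_ops.cond, whose branch
# functions now create their ops inside the cond so only the taken
# branch executes at run time.
training = tf.placeholder(tf.bool, shape=())
assert utils.constant_value(training) is None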
diff --git a/tensorflow/python/layers/conv_utils_test.py b/tensorflow/python/layers/utils_test.py
index 431be414c2..d95ea04d42 100644
--- a/tensorflow/python/layers/conv_utils_test.py
+++ b/tensorflow/python/layers/utils_test.py
@@ -18,53 +18,48 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.layers import conv_utils
+from tensorflow.python.layers import utils
from tensorflow.python.platform import test
class ConvUtilsTest(test.TestCase):
def testConvertDataFormat(self):
- self.assertEqual(
- conv_utils.convert_data_format('channels_first', 4), 'NCHW')
- self.assertEqual(conv_utils.convert_data_format('channels_first', 3), 'NCW')
- self.assertEqual(conv_utils.convert_data_format('channels_last', 4), 'NHWC')
- self.assertEqual(conv_utils.convert_data_format('channels_last', 3), 'NWC')
- self.assertEqual(
- conv_utils.convert_data_format('channels_last', 5), 'NDHWC')
+ self.assertEqual(utils.convert_data_format('channels_first', 4), 'NCHW')
+ self.assertEqual(utils.convert_data_format('channels_first', 3), 'NCW')
+ self.assertEqual(utils.convert_data_format('channels_last', 4), 'NHWC')
+ self.assertEqual(utils.convert_data_format('channels_last', 3), 'NWC')
+ self.assertEqual(utils.convert_data_format('channels_last', 5), 'NDHWC')
with self.assertRaises(ValueError):
- conv_utils.convert_data_format('invalid', 2)
+ utils.convert_data_format('invalid', 2)
def testNormalizeTuple(self):
+ self.assertEqual(utils.normalize_tuple(2, n=3, name='strides'), (2, 2, 2))
self.assertEqual(
- conv_utils.normalize_tuple(
- 2, n=3, name='strides'), (2, 2, 2))
- self.assertEqual(
- conv_utils.normalize_tuple(
- (2, 1, 2), n=3, name='strides'), (2, 1, 2))
+ utils.normalize_tuple((2, 1, 2), n=3, name='strides'), (2, 1, 2))
with self.assertRaises(ValueError):
- conv_utils.normalize_tuple((2, 1), n=3, name='strides')
+ utils.normalize_tuple((2, 1), n=3, name='strides')
with self.assertRaises(ValueError):
- conv_utils.normalize_tuple(None, n=3, name='strides')
+ utils.normalize_tuple(None, n=3, name='strides')
def testNormalizeDataFormat(self):
self.assertEqual(
- conv_utils.normalize_data_format('Channels_Last'), 'channels_last')
+ utils.normalize_data_format('Channels_Last'), 'channels_last')
self.assertEqual(
- conv_utils.normalize_data_format('CHANNELS_FIRST'), 'channels_first')
+ utils.normalize_data_format('CHANNELS_FIRST'), 'channels_first')
with self.assertRaises(ValueError):
- conv_utils.normalize_data_format('invalid')
+ utils.normalize_data_format('invalid')
def testNormalizePadding(self):
- self.assertEqual(conv_utils.normalize_padding('SAME'), 'same')
- self.assertEqual(conv_utils.normalize_padding('VALID'), 'valid')
+ self.assertEqual(utils.normalize_padding('SAME'), 'same')
+ self.assertEqual(utils.normalize_padding('VALID'), 'valid')
with self.assertRaises(ValueError):
- conv_utils.normalize_padding('invalid')
+ utils.normalize_padding('invalid')
if __name__ == '__main__':
  test.main()