From 5d43d5531f8f1d6ff75b055df2096a4b2a2ae755 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Mon, 12 Dec 2016 16:58:08 -0800 Subject: Avoid both conditional branches being computed in Dropout and BatchNormalization. Change: 141828986 --- tensorflow/python/BUILD | 8 +- tensorflow/python/layers/conv_utils.py | 105 ----------------- tensorflow/python/layers/conv_utils_test.py | 71 ------------ tensorflow/python/layers/convolutional.py | 2 +- tensorflow/python/layers/core.py | 24 ++-- tensorflow/python/layers/normalization.py | 141 +++++++++++------------ tensorflow/python/layers/pooling.py | 2 +- tensorflow/python/layers/utils.py | 167 ++++++++++++++++++++++++++++ tensorflow/python/layers/utils_test.py | 66 +++++++++++ 9 files changed, 313 insertions(+), 273 deletions(-) delete mode 100644 tensorflow/python/layers/conv_utils.py delete mode 100644 tensorflow/python/layers/conv_utils_test.py create mode 100644 tensorflow/python/layers/utils.py create mode 100644 tensorflow/python/layers/utils_test.py diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 55330bee0d..a3e4848a11 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2334,12 +2334,12 @@ py_library( srcs = [ "layers/__init__.py", "layers/base.py", - "layers/conv_utils.py", "layers/convolutional.py", "layers/core.py", "layers/layers.py", "layers/normalization.py", "layers/pooling.py", + "layers/utils.py", ], srcs_version = "PY2AND3", deps = [ @@ -2414,12 +2414,12 @@ py_test( ) py_test( - name = "layers_conv_utils_test", + name = "layers_utils_test", size = "small", srcs = [ - "layers/conv_utils_test.py", + "layers/utils_test.py", ], - main = "layers/conv_utils_test.py", + main = "layers/utils_test.py", srcs_version = "PY2AND3", deps = [ ":client_testlib", diff --git a/tensorflow/python/layers/conv_utils.py b/tensorflow/python/layers/conv_utils.py deleted file mode 100644 index ad6e7d3f32..0000000000 --- a/tensorflow/python/layers/conv_utils.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================= - -# pylint: disable=unused-import,g-bad-import-order -"""Contains layer utilies for input validation and format conversion. 
-""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import six -from six.moves import xrange # pylint: disable=redefined-builtin -import numpy as np - - -def convert_data_format(data_format, ndim): - if data_format == 'channels_last': - if ndim == 3: - return 'NWC' - elif ndim == 4: - return 'NHWC' - elif ndim == 5: - return 'NDHWC' - else: - raise ValueError('Input rank not supported:', ndim) - elif data_format == 'channels_first': - if ndim == 3: - return 'NCW' - elif ndim == 4: - return 'NCHW' - elif ndim == 5: - raise ValueError('Data format "channels_first" not supported for ' - 'inputs with rank 5.') - else: - raise ValueError('Input rank not supported:', ndim) - else: - raise ValueError('Invalid data_format:', data_format) - - -def normalize_tuple(value, n, name): - """Transforms a single integer or iterable of integers into an integer tuple. - - Arguments: - value: The value to validate and convert. Could an int, or any iterable - of ints. - n: The size of the tuple to be returned. - name: The name of the argument being validated, e.g. "strides" or - "kernel_size". This is only used to format error messages. - - Returns: - A tuple of n integers. - - Raises: - ValueError: If something else than an int/long or iterable thereof was - passed. - """ - if isinstance(value, int): - return (value,) * n - else: - try: - value_tuple = tuple(value) - except TypeError: - raise ValueError('The `' + name + '` argument must be a tuple of ' + - str(n) + ' integers. Received: ' + str(value)) - if len(value_tuple) != n: - raise ValueError('The `' + name + '` argument must be a tuple of ' + - str(n) + ' integers. Received: ' + str(value)) - for single_value in value_tuple: - try: - int(single_value) - except ValueError: - raise ValueError('The `' + name + '` argument must be a tuple of ' + - str(n) + ' integers. Received: ' + str(value) + ' ' - 'including element ' + str(single_value) + ' of type' + - ' ' + str(type(single_value))) - return value_tuple - - -def normalize_data_format(value): - data_format = value.lower() - if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('The `data_format` argument must be one of ' - '"channels_first", "channels_last". Received: ' + - str(value)) - return data_format - - -def normalize_padding(value): - padding = value.lower() - if padding not in {'valid', 'same'}: - raise ValueError('The `padding` argument must be one of "valid", "same". ' - 'Received: ' + str(padding)) - return padding diff --git a/tensorflow/python/layers/conv_utils_test.py b/tensorflow/python/layers/conv_utils_test.py deleted file mode 100644 index 431be414c2..0000000000 --- a/tensorflow/python/layers/conv_utils_test.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for tf.layers.core.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.layers import conv_utils -from tensorflow.python.platform import test - - -class ConvUtilsTest(test.TestCase): - - def testConvertDataFormat(self): - self.assertEqual( - conv_utils.convert_data_format('channels_first', 4), 'NCHW') - self.assertEqual(conv_utils.convert_data_format('channels_first', 3), 'NCW') - self.assertEqual(conv_utils.convert_data_format('channels_last', 4), 'NHWC') - self.assertEqual(conv_utils.convert_data_format('channels_last', 3), 'NWC') - self.assertEqual( - conv_utils.convert_data_format('channels_last', 5), 'NDHWC') - - with self.assertRaises(ValueError): - conv_utils.convert_data_format('invalid', 2) - - def testNormalizeTuple(self): - self.assertEqual( - conv_utils.normalize_tuple( - 2, n=3, name='strides'), (2, 2, 2)) - self.assertEqual( - conv_utils.normalize_tuple( - (2, 1, 2), n=3, name='strides'), (2, 1, 2)) - - with self.assertRaises(ValueError): - conv_utils.normalize_tuple((2, 1), n=3, name='strides') - - with self.assertRaises(ValueError): - conv_utils.normalize_tuple(None, n=3, name='strides') - - def testNormalizeDataFormat(self): - self.assertEqual( - conv_utils.normalize_data_format('Channels_Last'), 'channels_last') - self.assertEqual( - conv_utils.normalize_data_format('CHANNELS_FIRST'), 'channels_first') - - with self.assertRaises(ValueError): - conv_utils.normalize_data_format('invalid') - - def testNormalizePadding(self): - self.assertEqual(conv_utils.normalize_padding('SAME'), 'same') - self.assertEqual(conv_utils.normalize_padding('VALID'), 'valid') - - with self.assertRaises(ValueError): - conv_utils.normalize_padding('invalid') - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index 6caa3badfc..eb9f7eec98 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -34,7 +34,7 @@ from tensorflow.python.ops import standard_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.layers import base -from tensorflow.python.layers import conv_utils as utils +from tensorflow.python.layers import utils class _Conv(base._Layer): # pylint: disable=protected-access diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py index 00db1f1714..0a71ae5ea5 100644 --- a/tensorflow/python/layers/core.py +++ b/tensorflow/python/layers/core.py @@ -28,15 +28,14 @@ import numpy as np from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import nn from tensorflow.python.ops import standard_ops from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.ops import control_flow_ops from tensorflow.python.layers import base +from tensorflow.python.layers import utils class Dense(base._Layer): # pylint: disable=protected-access @@ -247,20 +246,13 @@ class Dropout(base._Layer): # pylint: disable=protected-access self.seed = seed def call(self, inputs, training=False): - if isinstance(training, bool): - training_bool = training - else: - training_bool = tensor_util.constant_value(training) - if 
training_bool is False: - return array_ops.identity(inputs) - dropped_inputs = nn.dropout(inputs, 1 - self.rate, - noise_shape=self.noise_shape, - seed=self.seed) - if training_bool is True: - return dropped_inputs - return control_flow_ops.cond(training, - lambda: dropped_inputs, - lambda: inputs) + def dropped_inputs(): + return nn.dropout(inputs, 1 - self.rate, + noise_shape=self.noise_shape, + seed=self.seed) + return utils.smart_cond(training, + dropped_inputs, + lambda: array_ops.identity(inputs)) def dropout(inputs, diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index ea43510f38..1daf765ab9 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn from tensorflow.python.ops import math_ops from tensorflow.python.ops import init_ops @@ -36,8 +35,10 @@ from tensorflow.python.ops import standard_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import moving_averages from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import variables from tensorflow.python.layers import base +from tensorflow.python.layers import utils class BatchNormalization(base._Layer): # pylint: disable=protected-access @@ -162,18 +163,21 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access # Determines whether broadcasting is needed. needs_broadcasting = (sorted(reduction_axes) != range(ndim)[:-1]) - # Determine boolean training boolean value. May be False, True, None. - # If None, it is assumed that `training` is a variable to be used in `cond`. - if isinstance(training, bool): - training_bool = training - else: - try: - training_bool = tensor_util.constant_value(training) - except TypeError: - training_bool = None + # Determine a boolean value for `training`: could be True, False, or None. + training_value = utils.constant_value(training) - # Obtain current current batch mean, variance, if necessary. - if training_bool is not False: + if needs_broadcasting: + # In this case we must explictly broadcast all parameters. + if self.center: + broadcast_beta = array_ops.reshape(self.beta, broadcast_shape) + else: + broadcast_beta = None + if self.scale: + broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape) + else: + broadcast_gamma = None + + if training_value is not False: # Use a copy of moving_mean as a shift to compute more reliable moments. shift = math_ops.add(self.moving_mean, 0) if needs_broadcasting: @@ -185,72 +189,59 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access else: mean, variance = nn.moments(inputs, reduction_axes, shift=shift) - # Prepare updates if necessary. - if training_bool is not False and not self.updates: - mean_update = moving_averages.assign_moving_average( - self.moving_mean, mean, self.momentum, zero_debias=False) - variance_update = moving_averages.assign_moving_average( - self.moving_variance, variance, self.momentum, zero_debias=False) - # In the future this should be refactored into a self.add_update - # methods in order to allow for instance-based BN layer sharing - # across unrelated input streams (e.g. like in Keras). 
- self.updates.append(mean_update) - self.updates.append(variance_update) - - # Normalize batch. - if needs_broadcasting: - # In this case we must explictly broadcast all parameters. - broadcast_moving_mean = array_ops.reshape(self.moving_mean, - broadcast_shape) - broadcast_moving_variance = array_ops.reshape(self.moving_variance, - broadcast_shape) - if self.center: - broadcast_beta = array_ops.reshape(self.beta, broadcast_shape) + # Prepare updates if necessary. + if not self.updates: + mean_update = moving_averages.assign_moving_average( + self.moving_mean, mean, self.momentum, zero_debias=False) + variance_update = moving_averages.assign_moving_average( + self.moving_variance, variance, self.momentum, zero_debias=False) + # In the future this should be refactored into a self.add_update + # methods in order to allow for instance-based BN layer sharing + # across unrelated input streams (e.g. like in Keras). + self.updates.append(mean_update) + self.updates.append(variance_update) + + # Normalize batch. We do this inside separate functions for training + # and inference so as to avoid evaluating both branches. + def normalize_in_test(): + if needs_broadcasting: + broadcast_moving_mean = array_ops.reshape(self.moving_mean, + broadcast_shape) + broadcast_moving_variance = array_ops.reshape(self.moving_variance, + broadcast_shape) + return nn.batch_normalization(inputs, + broadcast_moving_mean, + broadcast_moving_variance, + broadcast_beta, + broadcast_gamma, + self.epsilon) else: - broadcast_beta = None - if self.scale: - broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape) + return nn.batch_normalization(inputs, + self.moving_mean, + self.moving_variance, + self.beta if self.center else None, + self.gamma if self.scale else None, + self.epsilon) + + def normalize_in_training(): + if needs_broadcasting: + return nn.batch_normalization(inputs, + broadcast_mean, + broadcast_variance, + broadcast_beta, + broadcast_gamma, + self.epsilon) else: - broadcast_gamma = None + return nn.batch_normalization(inputs, + mean, + variance, + self.beta if self.center else None, + self.gamma if self.scale else None, + self.epsilon) - if training_bool is not False: - normed_inputs_training = nn.batch_normalization(inputs, - broadcast_mean, - broadcast_variance, - broadcast_beta, - broadcast_gamma, - self.epsilon) - normed_inputs = nn.batch_normalization(inputs, - broadcast_moving_mean, - broadcast_moving_variance, - broadcast_beta, - broadcast_gamma, - self.epsilon) - else: - # No need for broadcasting. - if training_bool is not False: - normed_inputs_training = nn.batch_normalization( - inputs, - mean, - variance, - self.beta if self.center else None, - self.gamma if self.scale else None, - self.epsilon) - normed_inputs = nn.batch_normalization(inputs, - self.moving_mean, - self.moving_variance, - self.beta if self.center else None, - self.gamma if self.scale else None, - self.epsilon) - - # Return the proper output depending on the boolean training phase. 
- if training_bool is True: - return normed_inputs_training - if training_bool is False: - return normed_inputs - return control_flow_ops.cond(training, - lambda: normed_inputs_training, - lambda: normed_inputs) + return utils.smart_cond(training, + normalize_in_training, + normalize_in_test) def batch_normalization(inputs, diff --git a/tensorflow/python/layers/pooling.py b/tensorflow/python/layers/pooling.py index f6e8ce8a28..2601c61c47 100644 --- a/tensorflow/python/layers/pooling.py +++ b/tensorflow/python/layers/pooling.py @@ -33,7 +33,7 @@ from tensorflow.python.ops import standard_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.layers import base -from tensorflow.python.layers import conv_utils as utils +from tensorflow.python.layers import utils class _Pooling1D(base._Layer): # pylint: disable=protected-access diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py new file mode 100644 index 0000000000..650e0586c3 --- /dev/null +++ b/tensorflow/python/layers/utils.py @@ -0,0 +1,167 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# pylint: disable=unused-import,g-bad-import-order +"""Contains layer utilies for input validation and format conversion. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six +from six.moves import xrange # pylint: disable=redefined-builtin +import numpy as np + +from tensorflow.python.ops import variables +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util + + +def convert_data_format(data_format, ndim): + if data_format == 'channels_last': + if ndim == 3: + return 'NWC' + elif ndim == 4: + return 'NHWC' + elif ndim == 5: + return 'NDHWC' + else: + raise ValueError('Input rank not supported:', ndim) + elif data_format == 'channels_first': + if ndim == 3: + return 'NCW' + elif ndim == 4: + return 'NCHW' + elif ndim == 5: + raise ValueError('Data format "channels_first" not supported for ' + 'inputs with rank 5.') + else: + raise ValueError('Input rank not supported:', ndim) + else: + raise ValueError('Invalid data_format:', data_format) + + +def normalize_tuple(value, n, name): + """Transforms a single integer or iterable of integers into an integer tuple. + + Arguments: + value: The value to validate and convert. Could an int, or any iterable + of ints. + n: The size of the tuple to be returned. + name: The name of the argument being validated, e.g. "strides" or + "kernel_size". This is only used to format error messages. + + Returns: + A tuple of n integers. + + Raises: + ValueError: If something else than an int/long or iterable thereof was + passed. 
+ """ + if isinstance(value, int): + return (value,) * n + else: + try: + value_tuple = tuple(value) + except TypeError: + raise ValueError('The `' + name + '` argument must be a tuple of ' + + str(n) + ' integers. Received: ' + str(value)) + if len(value_tuple) != n: + raise ValueError('The `' + name + '` argument must be a tuple of ' + + str(n) + ' integers. Received: ' + str(value)) + for single_value in value_tuple: + try: + int(single_value) + except ValueError: + raise ValueError('The `' + name + '` argument must be a tuple of ' + + str(n) + ' integers. Received: ' + str(value) + ' ' + 'including element ' + str(single_value) + ' of type' + + ' ' + str(type(single_value))) + return value_tuple + + +def normalize_data_format(value): + data_format = value.lower() + if data_format not in {'channels_first', 'channels_last'}: + raise ValueError('The `data_format` argument must be one of ' + '"channels_first", "channels_last". Received: ' + + str(value)) + return data_format + + +def normalize_padding(value): + padding = value.lower() + if padding not in {'valid', 'same'}: + raise ValueError('The `padding` argument must be one of "valid", "same". ' + 'Received: ' + str(padding)) + return padding + + +def smart_cond(pred, fn1, fn2, name=None): + """Return either `fn1()` or `fn2()` based on the boolean predicate `pred`. + + If `pred` is a bool or has a constant value, we return either `fn1()` + or `fn2()`, otherwise we use `tf.cond` to dynamically route to both. + + Arguments: + pred: A scalar determining whether to return the result of `fn1` or `fn2`. + fn1: The callable to be performed if pred is true. + fn2: The callable to be performed if pred is false. + name: Optional name prefix when using `tf.cond`. + + Returns: + Tensors returned by the call to either `fn1` or `fn2`. + + Raises: + TypeError is fn1 or fn2 is not callable. + """ + if not callable(fn1): + raise TypeError('`fn1` must be callable.') + if not callable(fn2): + raise TypeError('`fn2` must be callable.') + + pred_value = constant_value(pred) + if pred_value is not None: + if pred_value: + return fn1() + else: + return fn2() + else: + return control_flow_ops.cond(pred, fn1, fn2, name) + + +def constant_value(pred): + """Return the bool value for `pred`, or None if `pred` had a dynamic value. + + Arguments: + pred: A scalar, either a Python bool or a TensorFlow boolean variable + or tensor. + + Returns: + True or False if `pred` has a constant boolean value, None otherwise. + + Raises: + TypeError is pred is not a Variable, Tensor or bool. + """ + if isinstance(pred, bool): + pred_value = pred + elif isinstance(pred, variables.Variable): + pred_value = None + elif isinstance(pred, ops.Tensor): + pred_value = tensor_util.constant_value(pred) + else: + raise TypeError('`pred` must be a Tensor, a Variable, or a Python bool.') + return pred_value diff --git a/tensorflow/python/layers/utils_test.py b/tensorflow/python/layers/utils_test.py new file mode 100644 index 0000000000..d95ea04d42 --- /dev/null +++ b/tensorflow/python/layers/utils_test.py @@ -0,0 +1,66 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.layers.core.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.layers import utils +from tensorflow.python.platform import test + + +class ConvUtilsTest(test.TestCase): + + def testConvertDataFormat(self): + self.assertEqual(utils.convert_data_format('channels_first', 4), 'NCHW') + self.assertEqual(utils.convert_data_format('channels_first', 3), 'NCW') + self.assertEqual(utils.convert_data_format('channels_last', 4), 'NHWC') + self.assertEqual(utils.convert_data_format('channels_last', 3), 'NWC') + self.assertEqual(utils.convert_data_format('channels_last', 5), 'NDHWC') + + with self.assertRaises(ValueError): + utils.convert_data_format('invalid', 2) + + def testNormalizeTuple(self): + self.assertEqual(utils.normalize_tuple(2, n=3, name='strides'), (2, 2, 2)) + self.assertEqual( + utils.normalize_tuple((2, 1, 2), n=3, name='strides'), (2, 1, 2)) + + with self.assertRaises(ValueError): + utils.normalize_tuple((2, 1), n=3, name='strides') + + with self.assertRaises(ValueError): + utils.normalize_tuple(None, n=3, name='strides') + + def testNormalizeDataFormat(self): + self.assertEqual( + utils.normalize_data_format('Channels_Last'), 'channels_last') + self.assertEqual( + utils.normalize_data_format('CHANNELS_FIRST'), 'channels_first') + + with self.assertRaises(ValueError): + utils.normalize_data_format('invalid') + + def testNormalizePadding(self): + self.assertEqual(utils.normalize_padding('SAME'), 'same') + self.assertEqual(utils.normalize_padding('VALID'), 'valid') + + with self.assertRaises(ValueError): + utils.normalize_padding('invalid') + + +if __name__ == '__main__': + test.main() -- cgit v1.2.3
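The core of this change is the new `utils.constant_value` / `utils.smart_cond` pair added above, which lets Dropout and BatchNormalization avoid constructing the conditional branch that can never run. As a first illustration, a minimal sketch of how `constant_value` classifies predicates; the example inputs below are illustrative, not taken from the patch:

    import tensorflow as tf
    from tensorflow.python.layers import utils

    # Python bools are returned unchanged.
    utils.constant_value(True)                               # -> True
    # Tensors whose value can be folded at graph-construction time resolve
    # to their boolean value.
    utils.constant_value(tf.constant(False))                 # -> False
    # Placeholders (and Variables) have no static value, so None is returned
    # and callers such as smart_cond fall back to tf.cond.
    utils.constant_value(tf.placeholder(tf.bool, shape=[]))  # -> None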
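`smart_cond` uses that classification to decide between invoking one branch directly and emitting a `tf.cond`. Sketched below with `tf.nn.dropout` standing in for the expensive branch, under the same assumptions as above:

    import tensorflow as tf
    from tensorflow.python.layers import utils

    x = tf.ones([4, 10])

    # Static predicate: only the selected callable is invoked, so the dropout
    # op is never added to the graph when training is statically False.
    y_infer = utils.smart_cond(False,
                               lambda: tf.nn.dropout(x, keep_prob=0.5),
                               lambda: tf.identity(x))

    # Dynamic predicate: smart_cond falls back to tf.cond, which still builds
    # both subgraphs but executes only the taken branch at run time.
    is_training = tf.placeholder(tf.bool, shape=[])
    y = utils.smart_cond(is_training,
                         lambda: tf.nn.dropout(x, keep_prob=0.5),
                         lambda: tf.identity(x))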
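At the layer level this is why `Dropout.call` now wraps the dropped computation in a local `dropped_inputs()` function instead of evaluating it eagerly: with a Python bool `training` argument only one branch is ever built, and with a tensor a single `tf.cond` is produced. A usage sketch via the functional wrapper in core.py; the `rate`/`training` keywords are assumed from the surrounding code rather than quoted from the patch:

    import tensorflow as tf
    from tensorflow.python.layers import core as core_layers

    x = tf.ones([4, 10])

    # Static False: the layer returns identity(inputs); no dropout op is built.
    y_infer = core_layers.dropout(x, rate=0.5, training=False)

    # Tensor-valued training flag: one tf.cond chooses the branch at run time.
    is_training = tf.placeholder(tf.bool, shape=[])
    y_train = core_layers.dropout(x, rate=0.5, training=is_training)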
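BatchNormalization follows the same pattern: the training-time branch (batch moments plus moving-average updates) and the inference-time branch (stored moving statistics) are now built inside the local `normalize_in_training()` / `normalize_in_test()` functions and dispatched through `smart_cond`. A hedged sketch of the intended usage, assuming the layer's usual constructor defaults and keyword arguments:

    import tensorflow as tf
    from tensorflow.python.layers import normalization as norm_layers

    x = tf.ones([4, 10])
    is_training = tf.placeholder(tf.bool, shape=[])

    bn = norm_layers.BatchNormalization(axis=-1)
    y = bn(x, training=is_training)  # a single tf.cond between the two branches

    # The moving-average assignments are collected on the layer and still need
    # to be run explicitly as part of each training step.
    update_ops = bn.updates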