diff options
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r-- | tensorflow/python/keras/backend.py | 81 |
1 file changed, 68 insertions, 13 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 63e776a06b..13f52fbae7 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -2223,7 +2223,7 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3): @tf_export('keras.backend.batch_normalization') -def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3): +def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3): """Applies batch normalization on x given mean, var, beta and gamma. I.e. returns: @@ -2235,11 +2235,49 @@ def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3): var: Variance of batch. beta: Tensor with which to center the input. gamma: Tensor by which to scale the input. + axis: Integer, the axis that should be normalized. + (typically the features axis). epsilon: Fuzz factor. Returns: A tensor. """ + if ndim(x) == 4: + # The CPU implementation of `fused_batch_norm` only supports NHWC + if axis == 1 or axis == -3: + tf_data_format = 'NCHW' + elif axis == 3 or axis == -1: + tf_data_format = 'NHWC' + else: + tf_data_format = None + + if (tf_data_format == 'NHWC' or + tf_data_format == 'NCHW' and _has_nchw_support()): + # The mean / var / beta / gamma tensors may be broadcasted + # so they may have extra axes of size 1, which should be squeezed. 
+ if ndim(mean) > 1: + mean = array_ops.reshape(mean, [-1]) + if ndim(var) > 1: + var = array_ops.reshape(var, [-1]) + if beta is None: + beta = zeros_like(mean) + elif ndim(beta) > 1: + beta = array_ops.reshape(beta, [-1]) + if gamma is None: + gamma = ones_like(mean) + elif ndim(gamma) > 1: + gamma = array_ops.reshape(gamma, [-1]) + y, _, _ = nn.fused_batch_norm( + x, + gamma, + beta, + epsilon=epsilon, + mean=mean, + variance=var, + data_format=tf_data_format, + is_training=False + ) + return y return nn.batch_normalization(x, mean, var, beta, gamma, epsilon) @@ -2880,7 +2918,7 @@ class Function(object): if session_kwargs: raise ValueError('Some keys in session_kwargs are not supported at this ' - 'time: %s', session_kwargs.keys()) + 'time: %s', (session_kwargs.keys(),)) self._callable_fn = None self._feed_arrays = None @@ -3798,19 +3836,23 @@ def _preprocess_conv1d_input(x, data_format): return x, tf_data_format -def _preprocess_conv2d_input(x, data_format): +def _preprocess_conv2d_input(x, data_format, force_transpose=False): """Transpose and cast the input before the conv2d. Arguments: x: input tensor. data_format: string, `"channels_last"` or `"channels_first"`. + force_transpose: Boolean. If True, the input will always be transposed + from NCHW to NHWC if `data_format` is `"channels_first"`. + If False, the transposition only occurs on CPU (GPU ops are + assumed to support NCHW). Returns: A tensor. """ tf_data_format = 'NHWC' if data_format == 'channels_first': - if not _has_nchw_support(): + if not _has_nchw_support() or force_transpose: x = array_ops.transpose(x, (0, 2, 3, 1)) # NCHW -> NHWC else: tf_data_format = 'NCHW' @@ -3958,7 +4000,8 @@ def conv2d_transpose(x, output_shape, strides=(1, 1), padding='valid', - data_format=None): + data_format=None, + dilation_rate=(1, 1)): """2D deconvolution (i.e. transposed convolution). @@ -3972,6 +4015,7 @@ def conv2d_transpose(x, data_format: string, `"channels_last"` or `"channels_first"`. 
Whether to use Theano or TensorFlow/CNTK data format for inputs/kernels/outputs. + dilation_rate: Tuple of 2 integers. Returns: A tensor, result of transposed 2D convolution. @@ -3987,7 +4031,13 @@ def conv2d_transpose(x, if isinstance(output_shape, (tuple, list)): output_shape = array_ops.stack(output_shape) - x, tf_data_format = _preprocess_conv2d_input(x, data_format) + # `atrous_conv2d_transpose` only supports NHWC format, even on GPU. + if data_format == 'channels_first' and dilation_rate != (1, 1): + force_transpose = True + else: + force_transpose = False + + x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose) if data_format == 'channels_first' and tf_data_format == 'NHWC': output_shape = (output_shape[0], output_shape[2], output_shape[3], @@ -4002,13 +4052,18 @@ def conv2d_transpose(x, else: strides = (1, 1) + strides - x = nn.conv2d_transpose( - x, - kernel, - output_shape, - strides, - padding=padding, - data_format=tf_data_format) + if dilation_rate == (1, 1): + x = nn.conv2d_transpose(x, kernel, output_shape, strides, + padding=padding, + data_format=tf_data_format) + else: + assert dilation_rate[0] == dilation_rate[1] + x = nn.atrous_conv2d_transpose( + x, + kernel, + output_shape, + rate=dilation_rate[0], + padding=padding) if data_format == 'channels_first' and tf_data_format == 'NHWC': x = array_ops.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW return x |