Diffstat (limited to 'tensorflow/python/keras/backend.py')
 tensorflow/python/keras/backend.py (-rw-r--r--) | 81
 1 file changed, 68 insertions(+), 13 deletions(-)
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 63e776a06b..13f52fbae7 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -2223,7 +2223,7 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
@tf_export('keras.backend.batch_normalization')
-def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
+def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
"""Applies batch normalization on x given mean, var, beta and gamma.
I.e. returns:
@@ -2235,11 +2235,49 @@ def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
var: Variance of batch.
beta: Tensor with which to center the input.
gamma: Tensor by which to scale the input.
+ axis: Integer, the axis that should be normalized.
+ (typically the features axis).
epsilon: Fuzz factor.
Returns:
A tensor.
"""
+ if ndim(x) == 4:
+ # The CPU implementation of `fused_batch_norm` only supports NHWC.
+ if axis == 1 or axis == -3:
+ tf_data_format = 'NCHW'
+ elif axis == 3 or axis == -1:
+ tf_data_format = 'NHWC'
+ else:
+ tf_data_format = None
+
+ if (tf_data_format == 'NHWC' or
+ tf_data_format == 'NCHW' and _has_nchw_support()):
+ # The mean / var / beta / gamma tensors may be broadcasted
+ # so they may have extra axes of size 1, which should be squeezed.
+ if ndim(mean) > 1:
+ mean = array_ops.reshape(mean, [-1])
+ if ndim(var) > 1:
+ var = array_ops.reshape(var, [-1])
+ if beta is None:
+ beta = zeros_like(mean)
+ elif ndim(beta) > 1:
+ beta = array_ops.reshape(beta, [-1])
+ if gamma is None:
+ gamma = ones_like(mean)
+ elif ndim(gamma) > 1:
+ gamma = array_ops.reshape(gamma, [-1])
+ y, _, _ = nn.fused_batch_norm(
+ x,
+ gamma,
+ beta,
+ epsilon=epsilon,
+ mean=mean,
+ variance=var,
+ data_format=tf_data_format,
+ is_training=False
+ )
+ return y
return nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
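For illustration, a minimal usage sketch of the new `axis` argument (assuming TF 1.x graph mode; the shapes and values below are made up). With a 4D input and `axis=-1`, the call should take the `nn.fused_batch_norm` fast path with `data_format='NHWC'`:

    import numpy as np
    from tensorflow.python.keras import backend as K

    # 4D NHWC input: the features axis is last, so axis=-1 maps to
    # data_format='NHWC' and dispatches to nn.fused_batch_norm.
    x = K.constant(np.random.rand(2, 4, 4, 3), dtype='float32')
    mean = K.constant(np.zeros(3), dtype='float32')
    var = K.constant(np.ones(3), dtype='float32')
    beta = K.constant(np.zeros(3), dtype='float32')
    gamma = K.constant(np.ones(3), dtype='float32')

    y = K.batch_normalization(x, mean, var, beta, gamma, axis=-1)

With `axis=1` (channels_first) the same call would request `data_format='NCHW'`, which per the diff is only used when `_has_nchw_support()` is true; otherwise it falls back to the plain `nn.batch_normalization` below.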
@@ -2880,7 +2918,7 @@ class Function(object):
if session_kwargs:
raise ValueError('Some keys in session_kwargs are not supported at this '
- 'time: %s', session_kwargs.keys())
+ 'time: %s', (session_kwargs.keys(),))
self._callable_fn = None
self._feed_arrays = None
@@ -3798,19 +3836,23 @@ def _preprocess_conv1d_input(x, data_format):
return x, tf_data_format
-def _preprocess_conv2d_input(x, data_format):
+def _preprocess_conv2d_input(x, data_format, force_transpose=False):
"""Transpose and cast the input before the conv2d.
Arguments:
x: input tensor.
data_format: string, `"channels_last"` or `"channels_first"`.
+ force_transpose: Boolean. If True, the input will always be transposed
+ from NCHW to NHWC if `data_format` is `"channels_first"`.
+ If False, the transposition only occurs on CPU (GPU ops are
+ assumed to support NCHW).
Returns:
A tensor.
"""
tf_data_format = 'NHWC'
if data_format == 'channels_first':
- if not _has_nchw_support():
+ if not _has_nchw_support() or force_transpose:
x = array_ops.transpose(x, (0, 2, 3, 1)) # NCHW -> NHWC
else:
tf_data_format = 'NCHW'
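A hedged sketch of the two code paths of this private helper, as read from the diff (a CPU-only session is assumed for the first call's comment):

    from tensorflow.python.keras import backend as K
    from tensorflow.python.keras.backend import _preprocess_conv2d_input

    x = K.placeholder(shape=(None, 3, 32, 32))  # channels_first (NCHW)

    # Without force_transpose, the transpose happens only when NCHW is
    # unsupported (e.g. no GPU): fmt == 'NHWC' on CPU, 'NCHW' on GPU.
    x_t, fmt = _preprocess_conv2d_input(x, 'channels_first')

    # With force_transpose=True the input is always converted to NHWC,
    # which the atrous_conv2d_transpose path further down relies on.
    x_t, fmt = _preprocess_conv2d_input(x, 'channels_first',
                                        force_transpose=True)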
@@ -3958,7 +4000,8 @@ def conv2d_transpose(x,
output_shape,
strides=(1, 1),
padding='valid',
- data_format=None):
+ data_format=None,
+ dilation_rate=(1, 1)):
"""2D deconvolution (i.e.
transposed convolution).
@@ -3972,6 +4015,7 @@ def conv2d_transpose(x,
data_format: string, `"channels_last"` or `"channels_first"`.
Whether to use Theano or TensorFlow/CNTK data format
for inputs/kernels/outputs.
+ dilation_rate: Tuple of 2 integers, the dilation rate to use for
+ dilated (atrous) transposed convolution.
Returns:
A tensor, result of transposed 2D convolution.
@@ -3987,7 +4031,13 @@ def conv2d_transpose(x,
if isinstance(output_shape, (tuple, list)):
output_shape = array_ops.stack(output_shape)
- x, tf_data_format = _preprocess_conv2d_input(x, data_format)
+ # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
+ if data_format == 'channels_first' and dilation_rate != (1, 1):
+ force_transpose = True
+ else:
+ force_transpose = False
+
+ x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
if data_format == 'channels_first' and tf_data_format == 'NHWC':
output_shape = (output_shape[0], output_shape[2], output_shape[3],
@@ -4002,13 +4052,18 @@ def conv2d_transpose(x,
else:
strides = (1, 1) + strides
- x = nn.conv2d_transpose(
- x,
- kernel,
- output_shape,
- strides,
- padding=padding,
- data_format=tf_data_format)
+ if dilation_rate == (1, 1):
+ x = nn.conv2d_transpose(x, kernel, output_shape, strides,
+ padding=padding,
+ data_format=tf_data_format)
+ else:
+ assert dilation_rate[0] == dilation_rate[1]
+ x = nn.atrous_conv2d_transpose(
+ x,
+ kernel,
+ output_shape,
+ rate=dilation_rate[0],
+ padding=padding)
if data_format == 'channels_first' and tf_data_format == 'NHWC':
x = array_ops.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
return x
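Finally, a usage sketch of the dilated path (shapes are illustrative; the kernel layout is `(h, w, out_channels, in_channels)`, as `conv2d_transpose` expects):

    import numpy as np
    from tensorflow.python.keras import backend as K

    x = K.placeholder(shape=(2, 8, 8, 16))  # NHWC input
    kernel = K.variable(np.random.rand(3, 3, 32, 16), dtype='float32')

    # dilation_rate != (1, 1) routes to nn.atrous_conv2d_transpose, which
    # is NHWC-only (hence force_transpose above); with 'same' padding and
    # no striding, the spatial dims are preserved in output_shape.
    y = K.conv2d_transpose(x, kernel, output_shape=(2, 8, 8, 32),
                           strides=(1, 1), padding='same',
                           dilation_rate=(2, 2))

Note that `nn.atrous_conv2d_transpose` takes a single scalar `rate`, which is why the diff asserts `dilation_rate[0] == dilation_rate[1]` before calling it.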