author    Francois Chollet <fchollet@google.com>          2018-10-08 10:43:01 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-10-08 10:49:56 -0700
commit    8ef3e7c8c053cb6dad530e13c478bbd406ea2c95 (patch)
tree      74f36c8bd9293854ce0ee1f8a9bac04a863bfe99 /tensorflow/python
parent    153decedefc8da1fbd0717f4223b4b053e7aa517 (diff)
Part 1/3 of the feature sync to the Keras 2.2.4 API.
PiperOrigin-RevId: 216211279
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/keras/activations.py                |   5
-rw-r--r--  tensorflow/python/keras/activations_test.py           |  10
-rw-r--r--  tensorflow/python/keras/backend.py                    |  81
-rw-r--r--  tensorflow/python/keras/backend_test.py               |  44
-rw-r--r--  tensorflow/python/keras/callbacks.py                  |   4
-rw-r--r--  tensorflow/python/keras/engine/network.py             |   9
-rw-r--r--  tensorflow/python/keras/layers/convolutional.py       | 177
-rw-r--r--  tensorflow/python/keras/layers/convolutional_test.py  |  31
-rw-r--r--  tensorflow/python/keras/layers/pooling.py             | 185
-rw-r--r--  tensorflow/python/keras/layers/pooling_test.py        |  30
-rw-r--r--  tensorflow/python/keras/layers/wrappers.py            |   3
-rw-r--r--  tensorflow/python/keras/testing_utils.py              |   5
-rw-r--r--  tensorflow/python/keras/utils/conv_utils.py           |  45
-rw-r--r--  tensorflow/python/keras/utils/multi_gpu_utils.py      |  17
-rw-r--r--  tensorflow/python/keras/utils/multi_gpu_utils_test.py |  26
-rw-r--r--  tensorflow/python/keras/utils/np_utils.py             |   5
16 files changed, 539 insertions(+), 138 deletions(-)
diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py
index 99645de736..d69791ce8d 100644
--- a/tensorflow/python/keras/activations.py
+++ b/tensorflow/python/keras/activations.py
@@ -160,6 +160,11 @@ def sigmoid(x):
return nn.sigmoid(x)
+@tf_export('keras.activations.exponential')
+def exponential(x):
+ return math_ops.exp(x)
+
+
@tf_export('keras.activations.hard_sigmoid')
def hard_sigmoid(x):
"""Hard sigmoid activation function.
diff --git a/tensorflow/python/keras/activations_test.py b/tensorflow/python/keras/activations_test.py
index dd0bbcff39..ad238cb0a9 100644
--- a/tensorflow/python/keras/activations_test.py
+++ b/tensorflow/python/keras/activations_test.py
@@ -169,6 +169,16 @@ class KerasActivationsTest(test.TestCase):
expected = np.tanh(test_values)
self.assertAllClose(result, expected, rtol=1e-05)
+ def test_exponential(self):
+ with self.cached_session():
+ test_values = np.random.random((2, 5))
+ x = keras.backend.placeholder(ndim=2)
+ exp = keras.activations.exponential(x)
+ f = keras.backend.function([x], [exp])
+ result = f([test_values])[0]
+ expected = np.exp(test_values)
+ self.assertAllClose(result, expected, rtol=1e-05)
+
def test_linear(self):
x = np.random.random((10, 5))
self.assertAllClose(x, keras.activations.linear(x))
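[Editor's note] For reference, a minimal usage sketch of the new `exponential` activation, mirroring the test above (graph mode, TF 1.x-era API):

    import numpy as np
    from tensorflow.python import keras

    x = keras.backend.placeholder(ndim=2)
    f = keras.backend.function([x], [keras.activations.exponential(x)])
    values = np.random.random((2, 5))
    # element-wise exp, identical to np.exp up to float tolerance
    assert np.allclose(f([values])[0], np.exp(values), rtol=1e-5)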
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 63e776a06b..13f52fbae7 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -2223,7 +2223,7 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
@tf_export('keras.backend.batch_normalization')
-def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
+def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
"""Applies batch normalization on x given mean, var, beta and gamma.
I.e. returns:
@@ -2235,11 +2235,49 @@ def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
var: Variance of batch.
beta: Tensor with which to center the input.
gamma: Tensor by which to scale the input.
+ axis: Integer, the axis that should be normalized
+ (typically the features axis).
epsilon: Fuzz factor.
Returns:
A tensor.
"""
+ if ndim(x) == 4:
+ # The CPU implementation of `fused_batch_norm` only supports NHWC
+ if axis == 1 or axis == -3:
+ tf_data_format = 'NCHW'
+ elif axis == 3 or axis == -1:
+ tf_data_format = 'NHWC'
+ else:
+ tf_data_format = None
+
+ if (tf_data_format == 'NHWC' or
+ tf_data_format == 'NCHW' and _has_nchw_support()):
+ # The mean / var / beta / gamma tensors may be broadcasted
+ # so they may have extra axes of size 1, which should be squeezed.
+ if ndim(mean) > 1:
+ mean = array_ops.reshape(mean, [-1])
+ if ndim(var) > 1:
+ var = array_ops.reshape(var, [-1])
+ if beta is None:
+ beta = zeros_like(mean)
+ elif ndim(beta) > 1:
+ beta = array_ops.reshape(beta, [-1])
+ if gamma is None:
+ gamma = ones_like(mean)
+ elif ndim(gamma) > 1:
+ gamma = array_ops.reshape(gamma, [-1])
+ y, _, _ = nn.fused_batch_norm(
+ x,
+ gamma,
+ beta,
+ epsilon=epsilon,
+ mean=mean,
+ variance=var,
+ data_format=tf_data_format,
+ is_training=False
+ )
+ return y
return nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
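[Editor's note] A sketch of the new `axis` argument on a 4D NCHW input, mirroring the test added below; with `axis=1` the fused NCHW kernel is used where supported, otherwise the generic `nn.batch_normalization` path is taken:

    import numpy as np
    from tensorflow.python import keras
    from tensorflow.python.ops import nn

    x = keras.backend.variable(np.random.random((10, 3, 5, 5)))  # NCHW
    gamma = keras.backend.variable(np.random.random((3,)))
    beta = keras.backend.variable(np.random.random((3,)))
    # moments over batch and spatial axes leave the channels axis (size 3)
    mean, var = nn.moments(x, (0, 2, 3))
    y = keras.backend.batch_normalization(x, mean, var, beta, gamma, axis=1)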
@@ -2880,7 +2918,7 @@ class Function(object):
if session_kwargs:
raise ValueError('Some keys in session_kwargs are not supported at this '
- 'time: %s', session_kwargs.keys())
+ 'time: %s', (session_kwargs.keys(),))
self._callable_fn = None
self._feed_arrays = None
@@ -3798,19 +3836,23 @@ def _preprocess_conv1d_input(x, data_format):
return x, tf_data_format
-def _preprocess_conv2d_input(x, data_format):
+def _preprocess_conv2d_input(x, data_format, force_transpose=False):
"""Transpose and cast the input before the conv2d.
Arguments:
x: input tensor.
data_format: string, `"channels_last"` or `"channels_first"`.
+ force_transpose: Boolean. If True, the input will always be transposed
+ from NCHW to NHWC if `data_format` is `"channels_first"`.
+ If False, the transposition only occurs on CPU (GPU ops are
+ assumed to support NCHW).
Returns:
A tensor.
"""
tf_data_format = 'NHWC'
if data_format == 'channels_first':
- if not _has_nchw_support():
+ if not _has_nchw_support() or force_transpose:
x = array_ops.transpose(x, (0, 2, 3, 1)) # NCHW -> NHWC
else:
tf_data_format = 'NCHW'
@@ -3958,7 +4000,8 @@ def conv2d_transpose(x,
output_shape,
strides=(1, 1),
padding='valid',
- data_format=None):
+ data_format=None,
+ dilation_rate=(1, 1)):
"""2D deconvolution (i.e.
transposed convolution).
@@ -3972,6 +4015,7 @@ def conv2d_transpose(x,
data_format: string, `"channels_last"` or `"channels_first"`.
Whether to use Theano or TensorFlow/CNTK data format
for inputs/kernels/outputs.
+ dilation_rate: Tuple of 2 integers.
Returns:
A tensor, result of transposed 2D convolution.
@@ -3987,7 +4031,13 @@ def conv2d_transpose(x,
if isinstance(output_shape, (tuple, list)):
output_shape = array_ops.stack(output_shape)
- x, tf_data_format = _preprocess_conv2d_input(x, data_format)
+ # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
+ if data_format == 'channels_first' and dilation_rate != (1, 1):
+ force_transpose = True
+ else:
+ force_transpose = False
+
+ x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
if data_format == 'channels_first' and tf_data_format == 'NHWC':
output_shape = (output_shape[0], output_shape[2], output_shape[3],
@@ -4002,13 +4052,18 @@ def conv2d_transpose(x,
else:
strides = (1, 1) + strides
- x = nn.conv2d_transpose(
- x,
- kernel,
- output_shape,
- strides,
- padding=padding,
- data_format=tf_data_format)
+ if dilation_rate == (1, 1):
+ x = nn.conv2d_transpose(x, kernel, output_shape, strides,
+ padding=padding,
+ data_format=tf_data_format)
+ else:
+ assert dilation_rate[0] == dilation_rate[1]
+ x = nn.atrous_conv2d_transpose(
+ x,
+ kernel,
+ output_shape,
+ rate=dilation_rate[0],
+ padding=padding)
if data_format == 'channels_first' and tf_data_format == 'NHWC':
x = array_ops.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
return x
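[Editor's note] A minimal sketch of the new dilated path (assuming channels_last input); any `dilation_rate` other than `(1, 1)` dispatches to `nn.atrous_conv2d_transpose`:

    import numpy as np
    from tensorflow.python import keras

    x = keras.backend.variable(np.ones((1, 4, 4, 3), dtype='float32'))
    # kernel shape: (height, width, out_channels, in_channels)
    kernel = keras.backend.variable(np.ones((3, 3, 1, 3), dtype='float32'))
    y = keras.backend.conv2d_transpose(
        x, kernel, output_shape=(1, 4, 4, 1), strides=(1, 1),
        padding='same', data_format='channels_last', dilation_rate=(2, 2))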
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index ab71589940..0834448699 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -26,6 +26,7 @@ from tensorflow.python import keras
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import nn
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
@@ -1381,6 +1382,36 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
self.assertEqual(mean.get_shape().as_list(), [3,])
self.assertEqual(var.get_shape().as_list(), [3,])
+ def test_batch_normalization(self):
+ g_val = np.random.random((3,))
+ b_val = np.random.random((3,))
+ gamma = keras.backend.variable(g_val)
+ beta = keras.backend.variable(b_val)
+
+ # 3D NHC case
+ val = np.random.random((10, 5, 3))
+ x = keras.backend.variable(val)
+ mean, var = nn.moments(x, (0, 1), None, None, False)
+ normed = keras.backend.batch_normalization(
+ x, mean, var, beta, gamma, axis=-1, epsilon=1e-3)
+ self.assertEqual(normed.shape.as_list(), [10, 5, 3])
+
+ # 4D NHWC case
+ val = np.random.random((10, 5, 5, 3))
+ x = keras.backend.variable(val)
+ mean, var = nn.moments(x, (0, 1, 2), None, None, False)
+ normed = keras.backend.batch_normalization(
+ x, mean, var, beta, gamma, axis=-1, epsilon=1e-3)
+ self.assertEqual(normed.shape.as_list(), [10, 5, 5, 3])
+
+ # 4D NCHW case
+ val = np.random.random((10, 3, 5, 5))
+ x = keras.backend.variable(val)
+ mean, var = nn.moments(x, (0, 2, 3), None, None, False)
+ normed = keras.backend.batch_normalization(
+ x, mean, var, beta, gamma, axis=1, epsilon=1e-3)
+ self.assertEqual(normed.shape.as_list(), [10, 3, 5, 5])
+
class TestCTC(test.TestCase):
@@ -1506,12 +1537,13 @@ class TestRandomOps(test.TestCase):
self.assertAllClose(np.min(y), -2., atol=0.1)
def test_string_input(self):
- seq = keras.Sequential([
- keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string),
- keras.layers.Lambda(lambda x: x[0])
- ])
- preds = seq.predict([['tensorflow eager']])
- self.assertEqual(preds.shape, (1,))
+ with self.cached_session():
+ seq = keras.Sequential([
+ keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string),
+ keras.layers.Lambda(lambda x: x[0])
+ ])
+ preds = seq.predict([['tensorflow eager']])
+ self.assertEqual(preds.shape, (1,))
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 6dfbbf3694..3d6000f223 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -781,6 +781,10 @@ class LearningRateScheduler(Callback):
print('\nEpoch %05d: LearningRateScheduler reducing learning '
'rate to %s.' % (epoch + 1, lr))
+ def on_epoch_end(self, epoch, logs=None):
+ logs = logs or {}
+ logs['lr'] = K.get_value(self.model.optimizer.lr)
+
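[Editor's note] With this change the scheduled rate is recorded in `logs` (and hence in the `History` object); a minimal sketch, assuming a small toy model:

    import numpy as np
    from tensorflow.python import keras

    model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer='sgd', loss='mse')
    scheduler = keras.callbacks.LearningRateScheduler(
        lambda epoch: 0.1 * (0.5 ** epoch))
    history = model.fit(np.random.random((8, 4)), np.random.random((8, 1)),
                        epochs=2, callbacks=[scheduler], verbose=0)
    print(history.history['lr'])  # one recorded learning rate per epoch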
@tf_export('keras.callbacks.TensorBoard')
class TensorBoard(Callback):
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 918488bd7a..5969fea2b2 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -1641,10 +1641,11 @@ class Network(base_layer.Layer):
ValueError: if `summary()` is called before the model is built.
"""
if not self.built:
- raise ValueError('This model has never been called, thus its weights '
- 'have not yet been created, so no summary can be '
- 'displayed. Build the model first '
- '(e.g. by calling it on some data).')
+ raise ValueError('This model has not yet been built. '
+ 'Build the model first by calling `build()` or calling '
+ '`fit()` with some data, or specify '
+ 'an `input_shape` argument in the first layer(s) for '
+ 'automatic build.')
layer_utils.print_summary(self,
line_length=line_length,
positions=positions,
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index d00def07bb..8f5872385c 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -645,6 +645,14 @@ class Conv2DTranspose(Conv2D):
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
padding: one of `"valid"` or `"same"` (case-insensitive).
+ output_padding: An integer or tuple/list of 2 integers,
+ specifying the amount of padding along the height and width
+ of the output tensor.
+ Can be a single integer to specify the same value for all
+ spatial dimensions.
+ The amount of output padding along a given dimension must be
+ lower than the stride along that same dimension.
+ If set to `None` (default), the output shape is inferred.
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
@@ -700,7 +708,9 @@ class Conv2DTranspose(Conv2D):
kernel_size,
strides=(1, 1),
padding='valid',
+ output_padding=None,
data_format=None,
+ dilation_rate=(1, 1),
activation=None,
use_bias=True,
kernel_initializer='glorot_uniform',
@@ -717,6 +727,7 @@ class Conv2DTranspose(Conv2D):
strides=strides,
padding=padding,
data_format=data_format,
+ dilation_rate=dilation_rate,
activation=activations.get(activation),
use_bias=use_bias,
kernel_initializer=initializers.get(kernel_initializer),
@@ -728,6 +739,16 @@ class Conv2DTranspose(Conv2D):
bias_constraint=constraints.get(bias_constraint),
**kwargs)
+ self.output_padding = output_padding
+ if self.output_padding is not None:
+ self.output_padding = conv_utils.normalize_tuple(
+ self.output_padding, 2, 'output_padding')
+ for stride, out_pad in zip(self.strides, self.output_padding):
+ if out_pad >= stride:
+ raise ValueError('Stride ' + str(self.strides) + ' must be '
+ 'greater than output padding ' +
+ str(self.output_padding))
+
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if len(input_shape) != 4:
@@ -769,51 +790,50 @@ class Conv2DTranspose(Conv2D):
inputs_shape = array_ops.shape(inputs)
batch_size = inputs_shape[0]
if self.data_format == 'channels_first':
- c_axis, h_axis, w_axis = 1, 2, 3
+ h_axis, w_axis = 2, 3
else:
- c_axis, h_axis, w_axis = 3, 1, 2
+ h_axis, w_axis = 1, 2
height, width = inputs_shape[h_axis], inputs_shape[w_axis]
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.strides
+ if self.output_padding is None:
+ out_pad_h = out_pad_w = None
+ else:
+ out_pad_h, out_pad_w = self.output_padding
+
# Infer the dynamic output shape:
out_height = conv_utils.deconv_output_length(height,
kernel_h,
- self.padding,
- stride_h)
+ padding=self.padding,
+ output_padding=out_pad_h,
+ stride=stride_h,
+ dilation=self.dilation_rate[0])
out_width = conv_utils.deconv_output_length(width,
kernel_w,
- self.padding,
- stride_w)
+ padding=self.padding,
+ output_padding=out_pad_w,
+ stride=stride_w,
+ dilation=self.dilation_rate[1])
if self.data_format == 'channels_first':
output_shape = (batch_size, self.filters, out_height, out_width)
- strides = (1, 1, stride_h, stride_w)
else:
output_shape = (batch_size, out_height, out_width, self.filters)
- strides = (1, stride_h, stride_w, 1)
output_shape_tensor = array_ops.stack(output_shape)
- outputs = nn.conv2d_transpose(
+ outputs = backend.conv2d_transpose(
inputs,
self.kernel,
output_shape_tensor,
- strides,
- padding=self.padding.upper(),
- data_format=conv_utils.convert_data_format(self.data_format, ndim=4))
+ strides=self.strides,
+ padding=self.padding,
+ data_format=self.data_format,
+ dilation_rate=self.dilation_rate)
if not context.executing_eagerly():
# Infer the static output shape:
- out_shape = inputs.get_shape().as_list()
- out_shape[c_axis] = self.filters
- out_shape[h_axis] = conv_utils.deconv_output_length(out_shape[h_axis],
- kernel_h,
- self.padding,
- stride_h)
- out_shape[w_axis] = conv_utils.deconv_output_length(out_shape[w_axis],
- kernel_w,
- self.padding,
- stride_w)
+ out_shape = self.compute_output_shape(inputs.shape)
outputs.set_shape(out_shape)
if self.use_bias:
@@ -837,13 +857,33 @@ class Conv2DTranspose(Conv2D):
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.strides
+ if self.output_padding is None:
+ out_pad_h = out_pad_w = None
+ else:
+ out_pad_h, out_pad_w = self.output_padding
+
output_shape[c_axis] = self.filters
output_shape[h_axis] = conv_utils.deconv_output_length(
- output_shape[h_axis], kernel_h, self.padding, stride_h)
+ output_shape[h_axis],
+ kernel_h,
+ padding=self.padding,
+ output_padding=out_pad_h,
+ stride=stride_h,
+ dilation=self.dilation_rate[0])
output_shape[w_axis] = conv_utils.deconv_output_length(
- output_shape[w_axis], kernel_w, self.padding, stride_w)
+ output_shape[w_axis],
+ kernel_w,
+ padding=self.padding,
+ output_padding=out_pad_w,
+ stride=stride_w,
+ dilation=self.dilation_rate[1])
return tensor_shape.TensorShape(output_shape)
+ def get_config(self):
+ config = super(Conv2DTranspose, self).get_config()
+ config['output_padding'] = self.output_padding
+ return config
+
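[Editor's note] A sketch of how `output_padding` pins the otherwise ambiguous output size of a strided transposed convolution (see `deconv_output_length` further down: with `padding='valid'`, length = (input - 1) * stride + kernel + output_padding):

    import numpy as np
    from tensorflow.python import keras

    layer = keras.layers.Conv2DTranspose(
        filters=2, kernel_size=3, strides=(2, 2), padding='valid',
        output_padding=(1, 1))  # must be smaller than the stride
    y = layer(keras.backend.variable(np.zeros((1, 4, 4, 3), dtype='float32')))
    assert y.shape.as_list() == [1, 10, 10, 2]  # (4 - 1) * 2 + 3 + 1 = 10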
@tf_export('keras.layers.Conv3DTranspose',
'keras.layers.Convolution3DTranspose')
@@ -878,6 +918,14 @@ class Conv3DTranspose(Conv3D):
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
padding: one of `"valid"` or `"same"` (case-insensitive).
+ output_padding: An integer or tuple/list of 3 integers,
+ specifying the amount of padding along the depth, height, and
+ width.
+ Can be a single integer to specify the same value for all
+ spatial dimensions.
+ The amount of output padding along a given dimension must be
+ lower than the stride along that same dimension.
+ If set to `None` (default), the output shape is inferred.
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
@@ -943,6 +991,7 @@ class Conv3DTranspose(Conv3D):
kernel_size,
strides=(1, 1, 1),
padding='valid',
+ output_padding=None,
data_format=None,
activation=None,
use_bias=True,
@@ -971,6 +1020,16 @@ class Conv3DTranspose(Conv3D):
bias_constraint=constraints.get(bias_constraint),
**kwargs)
+ self.output_padding = output_padding
+ if self.output_padding is not None:
+ self.output_padding = conv_utils.normalize_tuple(
+ self.output_padding, 3, 'output_padding')
+ for stride, out_pad in zip(self.strides, self.output_padding):
+ if out_pad >= stride:
+ raise ValueError('Stride ' + str(self.strides) + ' must be '
+ 'greater than output padding ' +
+ str(self.output_padding))
+
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if len(input_shape) != 5:
@@ -1012,11 +1071,9 @@ class Conv3DTranspose(Conv3D):
inputs_shape = array_ops.shape(inputs)
batch_size = inputs_shape[0]
if self.data_format == 'channels_first':
- c_axis, d_axis, h_axis, w_axis = 1, 2, 3, 4
+ d_axis, h_axis, w_axis = 2, 3, 4
else:
- c_axis, d_axis, h_axis, w_axis = 4, 1, 2, 3
-
- self.input_spec = InputSpec(ndim=5, axes={c_axis: inputs_shape[c_axis]})
+ d_axis, h_axis, w_axis = 1, 2, 3
depth = inputs_shape[d_axis]
height = inputs_shape[h_axis]
@@ -1025,19 +1082,27 @@ class Conv3DTranspose(Conv3D):
kernel_d, kernel_h, kernel_w = self.kernel_size
stride_d, stride_h, stride_w = self.strides
+ if self.output_padding is None:
+ out_pad_d = out_pad_h = out_pad_w = None
+ else:
+ out_pad_d, out_pad_h, out_pad_w = self.output_padding
+
# Infer the dynamic output shape:
out_depth = conv_utils.deconv_output_length(depth,
kernel_d,
- self.padding,
- stride_d)
+ padding=self.padding,
+ output_padding=out_pad_d,
+ stride=stride_d)
out_height = conv_utils.deconv_output_length(height,
kernel_h,
- self.padding,
- stride_h)
+ padding=self.padding,
+ output_padding=out_pad_h,
+ stride=stride_h)
out_width = conv_utils.deconv_output_length(width,
kernel_w,
- self.padding,
- stride_w)
+ padding=self.padding,
+ output_padding=out_pad_w,
+ stride=stride_w)
if self.data_format == 'channels_first':
output_shape = (batch_size, self.filters, out_depth, out_height,
out_width)
@@ -1058,20 +1123,7 @@ class Conv3DTranspose(Conv3D):
if not context.executing_eagerly():
# Infer the static output shape:
- out_shape = inputs.get_shape().as_list()
- out_shape[c_axis] = self.filters
- out_shape[d_axis] = conv_utils.deconv_output_length(out_shape[d_axis],
- kernel_d,
- self.padding,
- stride_d)
- out_shape[h_axis] = conv_utils.deconv_output_length(out_shape[h_axis],
- kernel_h,
- self.padding,
- stride_h)
- out_shape[w_axis] = conv_utils.deconv_output_length(out_shape[w_axis],
- kernel_w,
- self.padding,
- stride_w)
+ out_shape = self.compute_output_shape(inputs.shape)
outputs.set_shape(out_shape)
if self.use_bias:
@@ -1109,15 +1161,38 @@ class Conv3DTranspose(Conv3D):
kernel_d, kernel_h, kernel_w = self.kernel_size
stride_d, stride_h, stride_w = self.strides
+ if self.output_padding is None:
+ out_pad_d = out_pad_h = out_pad_w = None
+ else:
+ out_pad_d, out_pad_h, out_pad_w = self.output_padding
+
output_shape[c_axis] = self.filters
output_shape[d_axis] = conv_utils.deconv_output_length(
- output_shape[d_axis], kernel_d, self.padding, stride_d)
+ output_shape[d_axis],
+ kernel_d,
+ padding=self.padding,
+ output_padding=out_pad_d,
+ stride=stride_d)
output_shape[h_axis] = conv_utils.deconv_output_length(
- output_shape[h_axis], kernel_h, self.padding, stride_h)
+ output_shape[h_axis],
+ kernel_h,
+ padding=self.padding,
+ output_padding=out_pad_h,
+ stride=stride_h)
output_shape[w_axis] = conv_utils.deconv_output_length(
- output_shape[w_axis], kernel_w, self.padding, stride_w)
+ output_shape[w_axis],
+ kernel_w,
+ padding=self.padding,
+ output_padding=out_pad_w,
+ stride=stride_w)
return tensor_shape.TensorShape(output_shape)
+ def get_config(self):
+ config = super(Conv3DTranspose, self).get_config()
+ config.pop('dilation_rate')
+ config['output_padding'] = self.output_padding
+ return config
+
class SeparableConv(Conv):
"""Abstract base layer for separable nD convolution.
diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py
index cad5e4c8bd..f88d632ab5 100644
--- a/tensorflow/python/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/layers/convolutional_test.py
@@ -204,6 +204,9 @@ class Conv2DTransposeTest(test.TestCase):
if test.is_gpu_available(cuda_only=True):
self._run_test(kwargs, 'data_format', ['channels_first'])
+ kwargs['strides'] = (2, 2)
+ self._run_test(kwargs, 'output_padding', [(1, 1)])
+
def test_conv2dtranspose_regularizers(self):
kwargs = {
'filters': 3,
@@ -239,6 +242,31 @@ class Conv2DTransposeTest(test.TestCase):
self.assertEqual(layer.kernel.constraint, k_constraint)
self.assertEqual(layer.bias.constraint, b_constraint)
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_conv2d_transpose_dilation(self):
+ testing_utils.layer_test(keras.layers.Conv2DTranspose,
+ kwargs={'filters': 2,
+ 'kernel_size': 3,
+ 'padding': 'same',
+ 'data_format': 'channels_last',
+ 'dilation_rate': (2, 2)},
+ input_shape=(2, 5, 6, 3))
+
+ input_data = np.arange(48).reshape((1, 4, 4, 3)).astype(np.float32)
+ expected_output = np.float32([[192, 228, 192, 228],
+ [336, 372, 336, 372],
+ [192, 228, 192, 228],
+ [336, 372, 336, 372]]).reshape((1, 4, 4, 1))
+ testing_utils.layer_test(keras.layers.Conv2DTranspose,
+ input_data=input_data,
+ kwargs={'filters': 1,
+ 'kernel_size': 3,
+ 'padding': 'same',
+ 'data_format': 'channels_last',
+ 'dilation_rate': (2, 2),
+ 'kernel_initializer': 'ones'},
+ expected_output=expected_output)
+
class Conv3DTransposeTest(test.TestCase):
@@ -270,6 +298,9 @@ class Conv3DTransposeTest(test.TestCase):
if test.is_gpu_available(cuda_only=True):
self._run_test(kwargs, 'data_format', ['channels_first'])
+ kwargs['strides'] = (2, 2, 2)
+ self._run_test(kwargs, 'output_padding', [(1, 1, 1)])
+
def test_conv3dtranspose_regularizers(self):
kwargs = {
'filters': 3,
diff --git a/tensorflow/python/keras/layers/pooling.py b/tensorflow/python/keras/layers/pooling.py
index 912e8bd619..72a9c1d629 100644
--- a/tensorflow/python/keras/layers/pooling.py
+++ b/tensorflow/python/keras/layers/pooling.py
@@ -18,12 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import tf_export
@@ -41,16 +44,18 @@ class Pooling1D(Layer):
strides of the pooling operation.
padding: A string. The padding method, either 'valid' or 'same'.
Case-insensitive.
- data_format: A string, one of `channels_last` (default) or `channels_first`.
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, length, channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, length)`.
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
name: A string, the name of the layer.
"""
def __init__(self, pool_function, pool_size, strides,
- padding='valid', data_format=None,
+ padding='valid', data_format='channels_last',
name=None, **kwargs):
super(Pooling1D, self).__init__(name=name, **kwargs)
if data_format is None:
@@ -65,45 +70,39 @@ class Pooling1D(Layer):
self.input_spec = InputSpec(ndim=3)
def call(self, inputs):
- # There is no TF op for 1D pooling, hence we make the inputs 4D.
- if self.data_format == 'channels_last':
- # input is NWC, make it NHWC
- inputs = array_ops.expand_dims(inputs, 1)
- # pool on the W dim
- pool_shape = (1, 1) + self.pool_size + (1,)
- strides = (1, 1) + self.strides + (1,)
- data_format = 'NHWC'
- else:
- # input is NCW, make it NCHW
- inputs = array_ops.expand_dims(inputs, 2)
- # pool on the W dim
- pool_shape = (1, 1, 1) + self.pool_size
- strides = (1, 1, 1) + self.strides
- data_format = 'NCHW'
-
+ pad_axis = 2 if self.data_format == 'channels_last' else 3
+ inputs = array_ops.expand_dims(inputs, pad_axis)
outputs = self.pool_function(
inputs,
- ksize=pool_shape,
- strides=strides,
- padding=self.padding.upper(),
- data_format=data_format)
-
- if self.data_format == 'channels_last':
- return array_ops.squeeze(outputs, 1)
- else:
- return array_ops.squeeze(outputs, 2)
+ self.pool_size + (1,),
+ strides=self.strides + (1,),
+ padding=self.padding,
+ data_format=self.data_format)
+ return array_ops.squeeze(outputs, pad_axis)
def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
- length = conv_utils.conv_output_length(input_shape[1], self.pool_size[0],
- self.padding, self.strides[0])
- return tensor_shape.TensorShape([input_shape[0], length, input_shape[2]])
+ if self.data_format == 'channels_first':
+ steps = input_shape[2]
+ features = input_shape[1]
+ else:
+ steps = input_shape[1]
+ features = input_shape[2]
+ length = conv_utils.conv_output_length(steps,
+ self.pool_size[0],
+ self.padding,
+ self.strides[0])
+ if self.data_format == 'channels_first':
+ return tensor_shape.TensorShape([input_shape[0], features, length])
+ else:
+ return tensor_shape.TensorShape([input_shape[0], length, features])
def get_config(self):
config = {
'strides': self.strides,
'pool_size': self.pool_size,
- 'padding': self.padding
+ 'padding': self.padding,
+ 'data_format': self.data_format,
}
base_config = super(Pooling1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
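[Editor's note] The reworked `call` expands the inputs to 4D and delegates to `backend.pool2d`, which is what makes `channels_first` work on 1D pooling layers; a sketch:

    import numpy as np
    from tensorflow.python import keras

    x = np.random.random((3, 2, 6)).astype('float32')  # (batch, features, steps)
    layer = keras.layers.MaxPooling1D(pool_size=2, data_format='channels_first')
    y = layer(keras.backend.variable(x))
    assert y.shape.as_list() == [3, 2, 3]  # steps halved, features untouched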
@@ -119,19 +118,36 @@ class MaxPooling1D(Pooling1D):
E.g. 2 will halve the input.
If None, it will default to `pool_size`.
padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
- 3D tensor with shape: `(batch_size, downsampled_steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, downsampled_steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, downsampled_steps)`
"""
def __init__(self, pool_size=2, strides=None,
- padding='valid', data_format=None, **kwargs):
+ padding='valid', data_format='channels_last', **kwargs):
super(MaxPooling1D, self).__init__(
- nn.max_pool,
+ functools.partial(backend.pool2d, pool_mode='max'),
pool_size=pool_size,
strides=strides,
padding=padding,
@@ -149,18 +165,35 @@ class AveragePooling1D(Pooling1D):
E.g. 2 will halve the input.
If None, it will default to `pool_size`.
padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
- 3D tensor with shape: `(batch_size, downsampled_steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, downsampled_steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, downsampled_steps)`
"""
def __init__(self, pool_size=2, strides=None,
- padding='valid', data_format=None, **kwargs):
+ padding='valid', data_format='channels_last', **kwargs):
super(AveragePooling1D, self).__init__(
- nn.avg_pool,
+ functools.partial(backend.pool2d, pool_mode='avg'),
pool_size=pool_size,
strides=strides,
padding=padding,
@@ -561,41 +594,96 @@ class GlobalPooling1D(Layer):
"""Abstract class for different global pooling 1D layers.
"""
- def __init__(self, **kwargs):
+ def __init__(self, data_format='channels_last', **kwargs):
super(GlobalPooling1D, self).__init__(**kwargs)
self.input_spec = InputSpec(ndim=3)
+ self.data_format = conv_utils.normalize_data_format(data_format)
def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
- return tensor_shape.TensorShape([input_shape[0], input_shape[2]])
+ if self.data_format == 'channels_first':
+ return tensor_shape.TensorShape([input_shape[0], input_shape[1]])
+ else:
+ return tensor_shape.TensorShape([input_shape[0], input_shape[2]])
def call(self, inputs):
raise NotImplementedError
+ def get_config(self):
+ config = {'data_format': self.data_format}
+ base_config = super(GlobalPooling1D, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
@tf_export('keras.layers.GlobalAveragePooling1D',
'keras.layers.GlobalAvgPool1D')
class GlobalAveragePooling1D(GlobalPooling1D):
"""Global average pooling operation for temporal data.
+ Arguments:
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
+
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
2D tensor with shape:
`(batch_size, features)`
"""
- def call(self, inputs):
- return backend.mean(inputs, axis=1)
+ def __init__(self, data_format='channels_last', **kwargs):
+ super(GlobalAveragePooling1D, self).__init__(data_format=data_format,
+ **kwargs)
+ self.supports_masking = True
+
+ def call(self, inputs, mask=None):
+ steps_axis = 1 if self.data_format == 'channels_last' else 2
+ if mask is not None:
+ mask = math_ops.cast(mask, backend.floatx())
+ input_shape = inputs.shape.as_list()
+ broadcast_shape = [-1, input_shape[steps_axis], 1]
+ mask = array_ops.reshape(mask, broadcast_shape)
+ inputs *= mask
+ return backend.sum(inputs, axis=steps_axis) / math_ops.reduce_sum(
+ mask, axis=steps_axis)
+ else:
+ return backend.mean(inputs, axis=steps_axis)
+
+ def compute_mask(self, inputs, mask=None):
+ return None
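[Editor's note] A sketch of the new masking support, mirroring the test added in pooling_test.py below: masked steps are excluded from the average.

    import numpy as np
    from tensorflow.python import keras

    model = keras.Sequential([
        keras.layers.Masking(mask_value=0., input_shape=(3, 4)),
        keras.layers.GlobalAveragePooling1D(),
    ])
    data = np.random.random((2, 3, 4))
    data[0, 1:, :] = 0.  # steps 1 and 2 of sample 0 are masked out
    out = model.predict(data)
    # the pooled value for sample 0 equals its only unmasked step
    assert np.allclose(out[0], data[0, 0, :], atol=1e-5)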
@tf_export('keras.layers.GlobalMaxPool1D', 'keras.layers.GlobalMaxPooling1D')
class GlobalMaxPooling1D(GlobalPooling1D):
"""Global max pooling operation for temporal data.
+ Arguments:
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
+
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
2D tensor with shape:
@@ -603,7 +691,8 @@ class GlobalMaxPooling1D(GlobalPooling1D):
"""
def call(self, inputs):
- return backend.max(inputs, axis=1)
+ steps_axis = 1 if self.data_format == 'channels_last' else 2
+ return backend.max(inputs, axis=steps_axis)
class GlobalPooling2D(Layer):
diff --git a/tensorflow/python/keras/layers/pooling_test.py b/tensorflow/python/keras/layers/pooling_test.py
index 2cd9939e66..936e73ecf9 100644
--- a/tensorflow/python/keras/layers/pooling_test.py
+++ b/tensorflow/python/keras/layers/pooling_test.py
@@ -18,11 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
+from tensorflow.python.training import rmsprop
class GlobalPoolingTest(test.TestCase):
@@ -31,8 +34,26 @@ class GlobalPoolingTest(test.TestCase):
def test_globalpooling_1d(self):
testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
input_shape=(3, 4, 5))
+ testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 4, 5))
testing_utils.layer_test(
keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5))
+ testing_utils.layer_test(keras.layers.pooling.GlobalAveragePooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 4, 5))
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_globalpooling_1d_masking_support(self):
+ model = keras.Sequential()
+ model.add(keras.layers.Masking(mask_value=0., input_shape=(3, 4)))
+ model.add(keras.layers.GlobalAveragePooling1D())
+ model.compile(loss='mae', optimizer=rmsprop.RMSPropOptimizer(0.001))
+
+ model_input = np.random.random((2, 3, 4))
+ model_input[0, 1:, :] = 0
+ output = model.predict(model_input)
+ self.assertAllClose(output[0], model_input[0, 0, :])
@tf_test_util.run_in_graph_and_eager_modes
def test_globalpooling_2d(self):
@@ -172,6 +193,10 @@ class Pooling1DTest(test.TestCase):
kwargs={'strides': stride,
'padding': padding},
input_shape=(3, 5, 4))
+ testing_utils.layer_test(
+ keras.layers.MaxPooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 2, 6))
@tf_test_util.run_in_graph_and_eager_modes
def test_averagepooling_1d(self):
@@ -183,6 +208,11 @@ class Pooling1DTest(test.TestCase):
'padding': padding},
input_shape=(3, 5, 4))
+ testing_utils.layer_test(
+ keras.layers.AveragePooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 2, 6))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py
index a1933c11b0..d19d0b5f8c 100644
--- a/tensorflow/python/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/layers/wrappers.py
@@ -587,6 +587,9 @@ class Bidirectional(Wrapper):
output = y * y_rev
elif self.merge_mode is None:
output = [y, y_rev]
+ else:
+ raise ValueError(
+ 'Unrecognized value for `merge_mode`: %s' % (self.merge_mode))
# Properly set learning phase
if (getattr(y, '_uses_learning_phase', False) or
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 501b50ba5f..2fae094a1e 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -166,8 +166,9 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
if expected_dim is not None:
if expected_dim != actual_dim:
raise AssertionError(
- 'When testing layer %s, for input %s, found output_shape='
- '%s but expected to find %s.\nFull kwargs: %s' %
+ 'When testing layer %s **after deserialization**, '
+ 'for input %s, found output_shape='
+ '%s but expected to find inferred shape %s.\nFull kwargs: %s' %
(layer_cls.__name__,
x,
actual_output_shape,
diff --git a/tensorflow/python/keras/utils/conv_utils.py b/tensorflow/python/keras/utils/conv_utils.py
index 8ebca1418d..f486e631e5 100644
--- a/tensorflow/python/keras/utils/conv_utils.py
+++ b/tensorflow/python/keras/utils/conv_utils.py
@@ -137,26 +137,49 @@ def conv_input_length(output_length, filter_size, padding, stride):
return (output_length - 1) * stride - 2 * pad + filter_size
-def deconv_output_length(input_length, filter_size, padding, stride):
+def deconv_output_length(input_length, filter_size, padding,
+ output_padding=None, stride=0, dilation=1):
"""Determines output length of a transposed convolution given input length.
Arguments:
- input_length: integer.
- filter_size: integer.
- padding: one of "same", "valid", "full".
- stride: integer.
+ input_length: Integer.
+ filter_size: Integer.
+ padding: one of `"same"`, `"valid"`, `"full"`.
+ output_padding: Integer, amount of padding along the output dimension.
+ Can be set to `None`, in which case the output length is inferred.
+ stride: Integer.
+ dilation: Integer.
Returns:
The output length (integer).
"""
+ assert padding in {'same', 'valid', 'full'}
if input_length is None:
return None
- input_length *= stride
- if padding == 'valid':
- input_length += max(filter_size - stride, 0)
- elif padding == 'full':
- input_length -= (stride + filter_size - 2)
- return input_length
+
+ # Get the dilated kernel size
+ filter_size = filter_size + (filter_size - 1) * (dilation - 1)
+
+ # Infer length if output padding is None, else compute the exact length
+ if output_padding is None:
+ if padding == 'valid':
+ length = input_length * stride + max(filter_size - stride, 0)
+ elif padding == 'full':
+ length = input_length * stride - (stride + filter_size - 2)
+ elif padding == 'same':
+ length = input_length * stride
+
+ else:
+ if padding == 'same':
+ pad = filter_size // 2
+ elif padding == 'valid':
+ pad = 0
+ elif padding == 'full':
+ pad = filter_size - 1
+
+ length = ((input_length - 1) * stride + filter_size - 2 * pad +
+ output_padding)
+ return length
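[Editor's note] Worked numbers for the new signature (dilation inflates the kernel to k + (k - 1) * (d - 1)):

    from tensorflow.python.keras.utils import conv_utils

    # output_padding=None infers the length: 'same' gives input * stride.
    assert conv_utils.deconv_output_length(4, 3, 'same', stride=2) == 8
    # Explicit output_padding, 'valid': (4 - 1) * 2 + 3 - 0 + 1 = 10.
    assert conv_utils.deconv_output_length(
        4, 3, 'valid', output_padding=1, stride=2) == 10
    # dilation=2: effective kernel 5, pad 5 // 2 = 2, so 3 + 5 - 4 + 0 = 4.
    assert conv_utils.deconv_output_length(
        4, 3, 'same', output_padding=0, stride=1, dilation=2) == 4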
def normalize_data_format(value):
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils.py b/tensorflow/python/keras/utils/multi_gpu_utils.py
index e1c49bc852..04b2ea8fe3 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils.py
@@ -244,9 +244,24 @@ def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False):
for o in range(len(outputs)):
all_outputs[o].append(outputs[o])
+ # Deduplicate output names to handle Siamese networks.
+ occurrences = {}
+ for n in model.output_names:
+ if n not in occurrences:
+ occurrences[n] = 1
+ else:
+ occurrences[n] += 1
+ conflict_counter = {n: 0 for n, count in occurrences.items() if count > 1}
+ output_names = []
+ for n in model.output_names:
+ if n in conflict_counter:
+ conflict_counter[n] += 1
+ n += '_%d' % conflict_counter[n]
+ output_names.append(n)
+
# Merge outputs under expected scope.
with ops.device('/cpu:0' if cpu_merge else '/gpu:%d' % target_gpu_ids[0]):
merged = []
- for name, outputs in zip(model.output_names, all_outputs):
+ for name, outputs in zip(output_names, all_outputs):
merged.append(concatenate(outputs, axis=0, name=name))
return Model(model.inputs, merged)
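[Editor's note] The deduplication logic above, shown in isolation (pure Python, no GPUs needed; `dedup` is a hypothetical helper name): only names that collide get a numeric suffix, matching the test expectation below.

    def dedup(names):
      occurrences = {}
      for n in names:
        occurrences[n] = occurrences.get(n, 0) + 1
      conflict_counter = {n: 0 for n, c in occurrences.items() if c > 1}
      result = []
      for n in names:
        if n in conflict_counter:
          conflict_counter[n] += 1
          n += '_%d' % conflict_counter[n]
        result.append(n)
      return result

    assert dedup(['add', 'nested', 'nested']) == ['add', 'nested_1', 'nested_2']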
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
index 3d0351a11f..1780ab6587 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils_test.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
@@ -198,5 +198,31 @@ class TestMultiGPUModel(test.TestCase):
parallel_model.compile(loss='mean_squared_error', optimizer='adam')
parallel_model.train_on_batch(x, y)
+ def test_multi_gpu_with_siamese_network(self):
+ gpus = 2
+
+ if not check_if_compatible_devices(gpus=gpus):
+ return
+
+ with self.cached_session():
+ input_shape = (3,)
+ nested_model = keras.models.Sequential([
+ keras.layers.Dense(32, input_shape=input_shape),
+ keras.layers.Dense(1)
+ ], name='nested')
+
+ input1 = keras.Input(input_shape)
+ input2 = keras.Input(input_shape)
+ score1 = nested_model(input1)
+ score2 = nested_model(input2)
+ score_sum = keras.layers.Add(name='add')([score1, score2])
+
+ siamese = keras.models.Model(inputs=[input1, input2],
+ outputs=[score_sum, score1, score2],
+ name='siamese')
+ parallel_siamese = keras.utils.multi_gpu_model(siamese, gpus)
+ self.assertEqual(parallel_siamese.output_names,
+ ['add', 'nested_1', 'nested_2'])
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/utils/np_utils.py b/tensorflow/python/keras/utils/np_utils.py
index c24e87308b..3763999bff 100644
--- a/tensorflow/python/keras/utils/np_utils.py
+++ b/tensorflow/python/keras/utils/np_utils.py
@@ -22,7 +22,7 @@ from tensorflow.python.util.tf_export import tf_export
@tf_export('keras.utils.to_categorical')
-def to_categorical(y, num_classes=None):
+def to_categorical(y, num_classes=None, dtype='float32'):
"""Converts a class vector (integers) to binary class matrix.
E.g. for use with categorical_crossentropy.
@@ -31,6 +31,7 @@ def to_categorical(y, num_classes=None):
y: class vector to be converted into a matrix
(integers from 0 to num_classes).
num_classes: total number of classes.
+ dtype: The data type expected by the input. Default: `'float32'`.
Returns:
A binary matrix representation of the input. The classes axis is placed
@@ -44,7 +45,7 @@ def to_categorical(y, num_classes=None):
if not num_classes:
num_classes = np.max(y) + 1
n = y.shape[0]
- categorical = np.zeros((n, num_classes), dtype=np.float32)
+ categorical = np.zeros((n, num_classes), dtype=dtype)
categorical[np.arange(n), y] = 1
output_shape = input_shape + (num_classes,)
categorical = np.reshape(categorical, output_shape)
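[Editor's note] A sketch of the new `dtype` argument:

    import numpy as np
    from tensorflow.python import keras

    one_hot = keras.utils.to_categorical([0, 2, 1], num_classes=3, dtype='int32')
    assert one_hot.dtype == np.int32
    assert one_hot.tolist() == [[1, 0, 0], [0, 0, 1], [0, 1, 0]]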