aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/utils/np_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/utils/np_utils.py')
-rw-r--r--tensorflow/python/keras/utils/np_utils.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/keras/utils/np_utils.py b/tensorflow/python/keras/utils/np_utils.py
index c24e87308b..3763999bff 100644
--- a/tensorflow/python/keras/utils/np_utils.py
+++ b/tensorflow/python/keras/utils/np_utils.py
@@ -22,7 +22,7 @@ from tensorflow.python.util.tf_export import tf_export
@tf_export('keras.utils.to_categorical')
-def to_categorical(y, num_classes=None):
+def to_categorical(y, num_classes=None, dtype='float32'):
"""Converts a class vector (integers) to binary class matrix.
E.g. for use with categorical_crossentropy.
@@ -31,6 +31,7 @@ def to_categorical(y, num_classes=None):
y: class vector to be converted into a matrix
(integers from 0 to num_classes).
num_classes: total number of classes.
+ dtype: The data type expected by the input. Default: `'float32'`.
Returns:
A binary matrix representation of the input. The classes axis is placed
@@ -44,7 +45,7 @@ def to_categorical(y, num_classes=None):
if not num_classes:
num_classes = np.max(y) + 1
n = y.shape[0]
- categorical = np.zeros((n, num_classes), dtype=np.float32)
+ categorical = np.zeros((n, num_classes), dtype=dtype)
categorical[np.arange(n), y] = 1
output_shape = input_shape + (num_classes,)
categorical = np.reshape(categorical, output_shape)