diff options
Diffstat (limited to 'tensorflow/python/keras/utils/np_utils.py')
-rw-r--r-- | tensorflow/python/keras/utils/np_utils.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/keras/utils/np_utils.py b/tensorflow/python/keras/utils/np_utils.py index c24e87308b..3763999bff 100644 --- a/tensorflow/python/keras/utils/np_utils.py +++ b/tensorflow/python/keras/utils/np_utils.py @@ -22,7 +22,7 @@ from tensorflow.python.util.tf_export import tf_export @tf_export('keras.utils.to_categorical') -def to_categorical(y, num_classes=None): +def to_categorical(y, num_classes=None, dtype='float32'): """Converts a class vector (integers) to binary class matrix. E.g. for use with categorical_crossentropy. @@ -31,6 +31,7 @@ def to_categorical(y, num_classes=None): y: class vector to be converted into a matrix (integers from 0 to num_classes). num_classes: total number of classes. + dtype: The data type expected by the input. Default: `'float32'`. Returns: A binary matrix representation of the input. The classes axis is placed @@ -44,7 +45,7 @@ def to_categorical(y, num_classes=None): if not num_classes: num_classes = np.max(y) + 1 n = y.shape[0] - categorical = np.zeros((n, num_classes), dtype=np.float32) + categorical = np.zeros((n, num_classes), dtype=dtype) categorical[np.arange(n), y] = 1 output_shape = input_shape + (num_classes,) categorical = np.reshape(categorical, output_shape) |