Diffstat (limited to 'tensorflow/python/ops/nn_ops.py')
-rw-r--r--   tensorflow/python/ops/nn_ops.py   32
1 file changed, 24 insertions(+), 8 deletions(-)
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 486ef6efb9..73e51aab7d 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -466,7 +466,7 @@ def softmax_cross_entropy_with_logits(logits, labels, name=None):
   output of `softmax`, as it will produce incorrect results.
 
   `logits` and `labels` must have the same shape `[batch_size, num_classes]`
-  and the same dtype (either `float32` or `float64`).
+  and the same dtype (either `float16`, `float32`, or `float64`).
 
   Args:
     logits: Unscaled log probabilities.
@@ -481,11 +481,18 @@
   # could break users who call this with bad labels, but disregard the bad
   # results.
 
+  logits = ops.convert_to_tensor(logits)
+  precise_logits = math_ops.cast(logits, dtypes.float32) if (
+      logits.dtype == dtypes.float16) else logits
+
   # The second output tensor contains the gradients. We use it in
   # _CrossEntropyGrad() in nn_grad but not here.
   cost, unused_backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
-      logits, labels, name=name)
-  return cost
+      precise_logits, labels, name=name)
+  if logits.dtype == dtypes.float16:
+    return math_ops.cast(cost, dtypes.float16)
+  else:
+    return cost
 
 
 def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
@@ -536,6 +543,8 @@
                       "SparseSoftmaxCrossEntropyWithLogits"):
     labels = ops.convert_to_tensor(labels)
     logits = ops.convert_to_tensor(logits)
+    precise_logits = math_ops.cast(logits, dtypes.float32) if (
+        dtypes.as_dtype(logits.dtype) == dtypes.float16) else logits
 
     # Store label shape for result later.
     labels_static_shape = labels.get_shape()
@@ -552,20 +561,27 @@
     # Check if no reshapes are required.
     if logits.get_shape().ndims == 2:
       cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
-          logits, labels, name=name)
-      return cost
+          precise_logits, labels, name=name)
+      if logits.dtype == dtypes.float16:
+        return math_ops.cast(cost, dtypes.float16)
+      else:
+        return cost
+
     # Reshape logits to 2 dim, labels to 1 dim.
     num_classes = array_ops.gather(array_ops.shape(logits),
                                    array_ops.rank(logits) - 1)
-    logits = array_ops.reshape(logits, [-1, num_classes])
+    precise_logits = array_ops.reshape(precise_logits, [-1, num_classes])
     labels = array_ops.reshape(labels, [-1])
     # The second output tensor contains the gradients. We use it in
     # _CrossEntropyGrad() in nn_grad but not here.
     cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
-        logits, labels, name=name)
+        precise_logits, labels, name=name)
     cost = array_ops.reshape(cost, labels_shape)
     cost.set_shape(labels_static_shape)
-    return cost
+    if logits.dtype == dtypes.float16:
+      return math_ops.cast(cost, dtypes.float16)
+    else:
+      return cost
 
 
 @ops.RegisterShape("SparseSoftmaxCrossEntropyWithLogits")
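
The change applies one pattern in both functions: if the incoming logits are float16, upcast them to float32, run the cross-entropy kernel at full precision, then cast the per-example losses back to float16 before returning. The sketch below is a minimal standalone NumPy illustration of that pattern under stated assumptions, not TensorFlow's implementation; the function name softmax_cross_entropy_fp16 is invented for this example.

import numpy as np

def softmax_cross_entropy_fp16(logits, labels):
    # Mirror the patch: compute in float32 whenever inputs are float16.
    orig_dtype = logits.dtype
    precise_logits = (logits.astype(np.float32)
                      if orig_dtype == np.float16 else logits)

    # Numerically stable log-softmax, evaluated in float32.
    shifted = precise_logits - precise_logits.max(axis=-1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    cost = -(labels.astype(np.float32) * log_softmax).sum(axis=-1)

    # Cast the per-example losses back to the caller's dtype,
    # as the patch does with math_ops.cast(cost, dtypes.float16).
    return cost.astype(np.float16) if orig_dtype == np.float16 else cost

logits = np.array([[2.0, 1.0, 0.1]], dtype=np.float16)
labels = np.array([[1.0, 0.0, 0.0]], dtype=np.float16)
print(softmax_cross_entropy_fp16(logits, labels))  # float16 loss values

Doing the exp/log reduction in float32 avoids the overflow and rounding error that half precision would introduce in the intermediate softmax, while callers that pass float16 tensors still receive a float16 result.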