Diffstat (limited to 'tensorflow/python/ops/nn_ops.py')
-rw-r--r--  tensorflow/python/ops/nn_ops.py  32
1 file changed, 24 insertions(+), 8 deletions(-)
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 486ef6efb9..73e51aab7d 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -466,7 +466,7 @@ def softmax_cross_entropy_with_logits(logits, labels, name=None):
output of `softmax`, as it will produce incorrect results.
`logits` and `labels` must have the same shape `[batch_size, num_classes]`
- and the same dtype (either `float32` or `float64`).
+ and the same dtype (either `float16`, `float32`, or `float64`).
Args:
logits: Unscaled log probabilities.
@@ -481,11 +481,18 @@ def softmax_cross_entropy_with_logits(logits, labels, name=None):
# could break users who call this with bad labels, but disregard the bad
# results.
+ logits = ops.convert_to_tensor(logits)
+ precise_logits = math_ops.cast(logits, dtypes.float32) if (
+ logits.dtype == dtypes.float16) else logits
+
# The second output tensor contains the gradients. We use it in
# _CrossEntropyGrad() in nn_grad but not here.
cost, unused_backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
- logits, labels, name=name)
- return cost
+ precise_logits, labels, name=name)
+ if logits.dtype == dtypes.float16:
+ return math_ops.cast(cost, dtypes.float16)
+ else:
+ return cost
def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
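The hunk above upcasts float16 logits to float32 before the fused cross-entropy kernel and casts the per-example cost back to float16. A minimal usage sketch against the public API of this revision (the tensor shapes, random values, and session boilerplate are illustrative assumptions, not part of the patch):

import numpy as np
import tensorflow as tf

# float16 logits and one-hot float16 labels; with this change the op
# upcasts them to float32 internally and returns a float16 cost.
logits = tf.constant(np.random.randn(8, 10), dtype=tf.float16)
labels = tf.constant(np.eye(10)[np.random.randint(10, size=8)],
                     dtype=tf.float16)

cost = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)

with tf.Session() as sess:
  print(sess.run(cost))  # shape [8], dtype float16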
@@ -536,6 +543,8 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
"SparseSoftmaxCrossEntropyWithLogits"):
labels = ops.convert_to_tensor(labels)
logits = ops.convert_to_tensor(logits)
+ precise_logits = math_ops.cast(logits, dtypes.float32) if (
+ dtypes.as_dtype(logits.dtype) == dtypes.float16) else logits
# Store label shape for result later.
labels_static_shape = labels.get_shape()
@@ -552,20 +561,27 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
# Check if no reshapes are required.
if logits.get_shape().ndims == 2:
cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
- logits, labels, name=name)
- return cost
+ precise_logits, labels, name=name)
+ if logits.dtype == dtypes.float16:
+ return math_ops.cast(cost, dtypes.float16)
+ else:
+ return cost
+
# Reshape logits to 2 dim, labels to 1 dim.
num_classes = array_ops.gather(array_ops.shape(logits),
array_ops.rank(logits) - 1)
- logits = array_ops.reshape(logits, [-1, num_classes])
+ precise_logits = array_ops.reshape(precise_logits, [-1, num_classes])
labels = array_ops.reshape(labels, [-1])
# The second output tensor contains the gradients. We use it in
# _CrossEntropyGrad() in nn_grad but not here.
cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
- logits, labels, name=name)
+ precise_logits, labels, name=name)
cost = array_ops.reshape(cost, labels_shape)
cost.set_shape(labels_static_shape)
- return cost
+ if logits.dtype == dtypes.float16:
+ return math_ops.cast(cost, dtypes.float16)
+ else:
+ return cost
@ops.RegisterShape("SparseSoftmaxCrossEntropyWithLogits")
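The last two hunks apply the same cast-and-restore pattern to the sparse variant, which takes integer class indices instead of one-hot rows. A hedged usage sketch under the same assumptions (shapes and values are illustrative only):

import numpy as np
import tensorflow as tf

# Sparse variant: float16 logits with integer labels. The float32 guard
# added above keeps the kernel in full precision; the returned per-example
# cost matches the float16 input dtype.
logits = tf.constant(np.random.randn(4, 5), dtype=tf.float16)
labels = tf.constant([0, 3, 1, 4], dtype=tf.int64)

cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                      labels=labels)

with tf.Session() as sess:
  print(sess.run(cost))  # shape [4], dtype float16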