author    Benoit Steiner <bsteiner@google.com>    2016-07-22 12:54:03 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2016-07-22 14:04:04 -0700
commit    075868639f79d77657a7d268fc4762d7a90c370b (patch)
tree      08fcb3e79bde8fa9a0006be04233f8067bd29b1e /tensorflow
parent    ef5d941c164b22a9be47e4f5bd7c90ba7c83e984 (diff)
Use 32-bit floats to compute cross entropies. 16-bit floats aren't accurate
enough to deal with more than a few labels.
Change: 128208074
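The precision issue is easy to reproduce outside TensorFlow. The following is a
minimal NumPy sketch (illustrative only, not part of this change) that computes
the same softmax cross entropy in float16 and in float32; with thousands of
classes the float16 result typically drifts away from the float32 one.

import numpy as np

def xent(logits, label, dtype):
    # Softmax cross entropy for a single example, computed entirely in `dtype`.
    x = logits.astype(dtype)
    x = x - x.max()                       # shift for numerical stability
    probs = np.exp(x) / np.exp(x).sum()   # softmax
    return -np.log(probs[label] + 1.0e-20)

rng = np.random.RandomState(0)
logits = rng.randn(10000).astype(np.float16)  # many classes -> tiny probabilities

print(xent(logits, 0, np.float16))  # float16 arithmetic throughout
print(xent(logits, 0, np.float32))  # same inputs, float32 arithmetic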
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/python/kernel_tests/sparse_xent_op_test.py | 17
-rw-r--r--  tensorflow/python/kernel_tests/xent_op_test.py        | 21
-rw-r--r--  tensorflow/python/ops/nn_ops.py                       | 32
-rw-r--r--  tensorflow/python/ops/nn_xent_test.py                 | 27
4 files changed, 56 insertions(+), 41 deletions(-)
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
index 93a9a2667e..dbb0bac3f1 100644
--- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -24,6 +24,7 @@ import time
import numpy as np
import tensorflow as tf
+from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import sparse_ops
@@ -31,7 +32,6 @@ from tensorflow.python.ops import sparse_ops
class SparseXentTest(tf.test.TestCase):
def _npXent(self, features, labels):
- is_higher_dim = len(features.shape) > 2
features = np.reshape(features, [-1, features.shape[-1]])
labels = np.reshape(labels, [-1])
batch_dim = 0
@@ -44,15 +44,13 @@ class SparseXentTest(tf.test.TestCase):
labels_mat[np.arange(batch_size), labels] = 1.0
bp = (probs - labels_mat)
l = -np.sum(labels_mat * np.log(probs + 1.0e-20), axis=1)
- return l, bp, is_higher_dim
+ return l, bp
def _testXent(self, np_features, np_labels, use_gpu=False):
- np_loss, np_backprop, is_higher_dim = self._npXent(np_features, np_labels)
+ np_loss, np_backprop = self._npXent(np_features, np_labels)
with self.test_session(use_gpu=use_gpu) as sess:
- loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ loss, backprop = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
np_features, np_labels)
- backprop = (loss.op.inputs[0].op.outputs[1] if is_higher_dim
- else loss.op.outputs[1])
tf_loss, tf_backprop = sess.run([loss, backprop])
self.assertAllCloseAccordingToType(np_loss, tf_loss)
self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
@@ -64,10 +62,9 @@ class SparseXentTest(tf.test.TestCase):
def _testSingleClass(self, use_gpu=False):
for label_dtype in np.int32, np.int64:
with self.test_session(use_gpu=use_gpu) as sess:
- loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ loss, backprop = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
np.array([[1.], [-1.], [0.]]).astype(np.float32),
np.array([0, 0, 0]).astype(label_dtype))
- backprop = loss.op.outputs[1]
tf_loss, tf_backprop = sess.run([loss, backprop])
self.assertAllClose([0.0, 0.0, 0.0], tf_loss)
self.assertAllClose([[0.0], [0.0], [0.0]], tf_backprop)
@@ -101,7 +98,7 @@ class SparseXentTest(tf.test.TestCase):
# With a hard 1, the backprop is [0.032 - 1.0 = -0.968, 0.087, 0.237, 0.644]
# The loss for this batch is [1.0 * -log(0.25), 1.0 * -log(0.032)]
# = [1.3862, 3.4420]
- np_loss, np_backprop, _ = self._npXent(np.array(features), np.array(labels))
+ np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
self.assertAllClose(np.array([[0.25, 0.25, 0.25, -0.75],
[-0.968, 0.087, 0.237, 0.6439]]),
np_backprop,
@@ -169,7 +166,7 @@ class SparseXentTest(tf.test.TestCase):
self.assertLess(err, 5e-8)
def _testHighDim(self, use_gpu, features, labels):
- np_loss, np_backprop, _ = self._npXent(np.array(features), np.array(labels))
+ np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
# manually reshape loss
np_loss = np.reshape(np_loss, np.array(labels).shape)
with self.test_session(use_gpu=use_gpu) as sess:
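For reference, the call pattern the updated tests rely on is sketched below
(assuming the TF 0.x test environment used in this file): the generated op
returns both the per-example loss and the backprop tensor, so the tests no
longer need to reach into loss.op.outputs.

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gen_nn_ops

features = np.array([[1., 2., 3., 4.]], dtype=np.float32)
labels = np.array([3], dtype=np.int64)
with tf.Session() as sess:
    loss, backprop = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
        features, labels)
    loss_val, backprop_val = sess.run([loss, backprop])
    print(loss_val, backprop_val)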
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index 9e7c563547..70b3bdcbb4 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -21,6 +21,8 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
+from tensorflow.python.ops import gen_nn_ops
+
class XentTest(tf.test.TestCase):
@@ -38,8 +40,8 @@ class XentTest(tf.test.TestCase):
def _testXent(self, np_features, np_labels, use_gpu=False):
np_loss, np_backprop = self._npXent(np_features, np_labels)
with self.test_session(use_gpu=use_gpu) as sess:
- loss = tf.nn.softmax_cross_entropy_with_logits(np_features, np_labels)
- backprop = loss.op.outputs[1]
+ loss, backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
+ np_features, np_labels)
tf_loss, tf_backprop = sess.run([loss, backprop])
self.assertAllCloseAccordingToType(np_loss, tf_loss)
self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
@@ -51,10 +53,9 @@ class XentTest(tf.test.TestCase):
def _testSingleClass(self, use_gpu=False):
for dtype in np.float16, np.float32:
with self.test_session(use_gpu=use_gpu) as sess:
- loss = tf.nn.softmax_cross_entropy_with_logits(
+ loss, backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
np.array([[1.], [-1.], [0.]]).astype(dtype),
np.array([[-1.], [0.], [1.]]).astype(dtype))
- backprop = loss.op.outputs[1]
tf_loss, tf_backprop = sess.run([loss, backprop])
self.assertAllClose([0.0, 0.0, 0.0], tf_loss)
self.assertAllClose([[2.0], [1.0], [0.0]], tf_backprop)
@@ -69,9 +70,9 @@ class XentTest(tf.test.TestCase):
[[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]).astype(dtype)
np_labels = np.array(
[[[0., 0., 0., 1.]], [[0., .5, .5, 0.]]]).astype(dtype)
- self.assertRaisesRegexp(
- ValueError, "must have rank 2",
- tf.nn.softmax_cross_entropy_with_logits, np_features, np_labels)
+ self.assertRaisesRegexp(ValueError, "must have rank 2",
+ gen_nn_ops._softmax_cross_entropy_with_logits,
+ np_features, np_labels)
def testNpXent(self):
# We create 2 batches of logits for testing.
@@ -110,14 +111,14 @@ class XentTest(tf.test.TestCase):
def testShapeMismatch(self):
with self.test_session():
with self.assertRaises(ValueError):
- tf.nn.softmax_cross_entropy_with_logits(
+ gen_nn_ops._softmax_cross_entropy_with_logits(
[[0., 1.], [2., 3.]], [[0., 1., 0.], [1., 0., 0.]])
def testNotMatrix(self):
with self.test_session():
with self.assertRaises(ValueError):
- tf.nn.softmax_cross_entropy_with_logits([0., 1., 2., 3.],
- [0., 1., 0., 1.])
+ gen_nn_ops._softmax_cross_entropy_with_logits([0., 1., 2., 3.],
+ [0., 1., 0., 1.])
def testHalf(self):
self._testAll(
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 486ef6efb9..73e51aab7d 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -466,7 +466,7 @@ def softmax_cross_entropy_with_logits(logits, labels, name=None):
output of `softmax`, as it will produce incorrect results.
`logits` and `labels` must have the same shape `[batch_size, num_classes]`
- and the same dtype (either `float32` or `float64`).
+ and the same dtype (either `float16`, `float32`, or `float64`).
Args:
logits: Unscaled log probabilities.
@@ -481,11 +481,18 @@ def softmax_cross_entropy_with_logits(logits, labels, name=None):
# could break users who call this with bad labels, but disregard the bad
# results.
+ logits = ops.convert_to_tensor(logits)
+ precise_logits = math_ops.cast(logits, dtypes.float32) if (
+ logits.dtype == dtypes.float16) else logits
+
# The second output tensor contains the gradients. We use it in
# _CrossEntropyGrad() in nn_grad but not here.
cost, unused_backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
- logits, labels, name=name)
- return cost
+ precise_logits, labels, name=name)
+ if logits.dtype == dtypes.float16:
+ return math_ops.cast(cost, dtypes.float16)
+ else:
+ return cost
def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
@@ -536,6 +543,8 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
"SparseSoftmaxCrossEntropyWithLogits"):
labels = ops.convert_to_tensor(labels)
logits = ops.convert_to_tensor(logits)
+ precise_logits = math_ops.cast(logits, dtypes.float32) if (
+ dtypes.as_dtype(logits.dtype) == dtypes.float16) else logits
# Store label shape for result later.
labels_static_shape = labels.get_shape()
@@ -552,20 +561,27 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
# Check if no reshapes are required.
if logits.get_shape().ndims == 2:
cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
- logits, labels, name=name)
- return cost
+ precise_logits, labels, name=name)
+ if logits.dtype == dtypes.float16:
+ return math_ops.cast(cost, dtypes.float16)
+ else:
+ return cost
+
# Reshape logits to 2 dim, labels to 1 dim.
num_classes = array_ops.gather(array_ops.shape(logits),
array_ops.rank(logits) - 1)
- logits = array_ops.reshape(logits, [-1, num_classes])
+ precise_logits = array_ops.reshape(precise_logits, [-1, num_classes])
labels = array_ops.reshape(labels, [-1])
# The second output tensor contains the gradients. We use it in
# _CrossEntropyGrad() in nn_grad but not here.
cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
- logits, labels, name=name)
+ precise_logits, labels, name=name)
cost = array_ops.reshape(cost, labels_shape)
cost.set_shape(labels_static_shape)
- return cost
+ if logits.dtype == dtypes.float16:
+ return math_ops.cast(cost, dtypes.float16)
+ else:
+ return cost
@ops.RegisterShape("SparseSoftmaxCrossEntropyWithLogits")
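Outside the diff, the new control flow in nn_ops.py amounts to: upcast float16
logits to float32, run the cross-entropy computation at full precision, and
cast the resulting loss back to float16. A standalone NumPy sketch of that
pattern (helper name and details assumed here, not the TensorFlow code itself):

import numpy as np

def softmax_xent_with_upcast(logits, labels):
    # Do the math in float32 even when the caller passes float16 logits,
    # then return the loss in the caller's dtype.
    orig_dtype = logits.dtype
    precise = logits.astype(np.float32) if orig_dtype == np.float16 else logits
    shifted = precise - precise.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    cost = -(labels.astype(np.float32) * log_probs).sum(axis=-1)
    return cost.astype(orig_dtype) if orig_dtype == np.float16 else cost

logits = np.random.randn(2, 1000).astype(np.float16)
labels = np.eye(1000, dtype=np.float32)[[3, 7]]   # one-hot targets
print(softmax_xent_with_upcast(logits, labels))   # float16 losses computed in float32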
diff --git a/tensorflow/python/ops/nn_xent_test.py b/tensorflow/python/ops/nn_xent_test.py
index c09d6141bb..10c9f8ca87 100644
--- a/tensorflow/python/ops/nn_xent_test.py
+++ b/tensorflow/python/ops/nn_xent_test.py
@@ -56,22 +56,23 @@ class SigmoidCrossEntropyWithLogitsTest(tf.test.TestCase):
def testLogisticOutput(self):
for use_gpu in [True, False]:
- with self.test_session(use_gpu=use_gpu):
- logits, targets, losses = self._Inputs(dtype=tf.float32)
- loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets)
- np_loss = np.array(losses).astype(np.float32)
- tf_loss = loss.eval()
- self.assertAllClose(np_loss, tf_loss, atol=0.001)
+ for dtype in [tf.float32, tf.float16]:
+ with self.test_session(use_gpu=use_gpu):
+ logits, targets, losses = self._Inputs(dtype=dtype)
+ loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets)
+ np_loss = np.array(losses).astype(np.float32)
+ tf_loss = loss.eval()
+ self.assertAllClose(np_loss, tf_loss, atol=0.001)
def testLogisticOutputMultiDim(self):
for use_gpu in [True, False]:
- with self.test_session(use_gpu=use_gpu):
- logits, targets, losses = self._Inputs(dtype=tf.float32,
- sizes=[2, 2, 2])
- loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets)
- np_loss = np.array(losses).astype(np.float32)
- tf_loss = loss.eval()
- self.assertAllClose(np_loss, tf_loss, atol=0.001)
+ for dtype in [tf.float32, tf.float16]:
+ with self.test_session(use_gpu=use_gpu):
+ logits, targets, losses = self._Inputs(dtype=dtype, sizes=[2, 2, 2])
+ loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets)
+ np_loss = np.array(losses).astype(np.float32)
+ tf_loss = loss.eval()
+ self.assertAllClose(np_loss, tf_loss, atol=0.001)
def testGradient(self):
sizes = [4, 2]