author     2016-07-22 12:54:03 -0800
committer  2016-07-22 14:04:04 -0700
commit     075868639f79d77657a7d268fc4762d7a90c370b (patch)
tree       08fcb3e79bde8fa9a0006be04233f8067bd29b1e /tensorflow
parent     ef5d941c164b22a9be47e4f5bd7c90ba7c83e984 (diff)
Use 32-bit floats to compute cross entropies. 16-bit floats aren't
accurate enough to deal with more than a few labels.
Change: 128208074
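
A minimal NumPy sketch (illustration only, not part of the commit; the class
count and seed are arbitrary) of the accuracy problem the message describes:
accumulating the softmax normalizer over many classes in float16 drifts from
the float32 result, which is why the loss math is now done in float32.

import numpy as np

num_classes = 10000  # "more than a few labels"; arbitrary for illustration
logits = np.random.RandomState(0).randn(num_classes).astype(np.float16)
label = 42  # arbitrary hard label

def np_xent(logits, label, dtype):
    # Cross entropy via a max-shifted log-softmax, accumulated in `dtype`.
    x = logits.astype(dtype)
    shifted = x - x.max()
    log_norm = np.log(np.sum(np.exp(shifted), dtype=dtype))
    return float(log_norm - shifted[label])

print(np_xent(logits, label, np.float16))  # float16 accumulation drifts
print(np_xent(logits, label, np.float32))  # float32 reference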
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/python/kernel_tests/sparse_xent_op_test.py | 17
-rw-r--r--  tensorflow/python/kernel_tests/xent_op_test.py        | 21
-rw-r--r--  tensorflow/python/ops/nn_ops.py                       | 32
-rw-r--r--  tensorflow/python/ops/nn_xent_test.py                 | 27
4 files changed, 56 insertions, 41 deletions
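
The core pattern the diff below applies in nn_ops.py, shown first as a minimal
standalone sketch (the wrapper name `_with_float32_accumulation` is
hypothetical; the real change inlines this logic into the two public ops):

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops

def _with_float32_accumulation(compute, logits, *args, **kwargs):
    # Hypothetical helper: upcast float16 logits to float32, run the op at
    # full precision, then cast the per-example cost back to float16.
    logits = ops.convert_to_tensor(logits)
    precise_logits = (math_ops.cast(logits, dtypes.float32)
                      if logits.dtype == dtypes.float16 else logits)
    cost = compute(precise_logits, *args, **kwargs)
    if logits.dtype == dtypes.float16:
        return math_ops.cast(cost, dtypes.float16)
    return cost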
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
index 93a9a2667e..dbb0bac3f1 100644
--- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -24,6 +24,7 @@ import time
 import numpy as np
 import tensorflow as tf
 
+from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import sparse_ops
 
@@ -31,7 +32,6 @@ from tensorflow.python.ops import sparse_ops
 class SparseXentTest(tf.test.TestCase):
 
   def _npXent(self, features, labels):
-    is_higher_dim = len(features.shape) > 2
     features = np.reshape(features, [-1, features.shape[-1]])
     labels = np.reshape(labels, [-1])
     batch_dim = 0
@@ -44,15 +44,13 @@ class SparseXentTest(tf.test.TestCase):
     labels_mat[np.arange(batch_size), labels] = 1.0
     bp = (probs - labels_mat)
     l = -np.sum(labels_mat * np.log(probs + 1.0e-20), axis=1)
-    return l, bp, is_higher_dim
+    return l, bp
 
   def _testXent(self, np_features, np_labels, use_gpu=False):
-    np_loss, np_backprop, is_higher_dim = self._npXent(np_features, np_labels)
+    np_loss, np_backprop = self._npXent(np_features, np_labels)
     with self.test_session(use_gpu=use_gpu) as sess:
-      loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+      loss, backprop = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
           np_features, np_labels)
-      backprop = (loss.op.inputs[0].op.outputs[1] if is_higher_dim
-                  else loss.op.outputs[1])
       tf_loss, tf_backprop = sess.run([loss, backprop])
     self.assertAllCloseAccordingToType(np_loss, tf_loss)
     self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
@@ -64,10 +62,9 @@ class SparseXentTest(tf.test.TestCase):
   def _testSingleClass(self, use_gpu=False):
     for label_dtype in np.int32, np.int64:
       with self.test_session(use_gpu=use_gpu) as sess:
-        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        loss, backprop = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
             np.array([[1.], [-1.], [0.]]).astype(np.float32),
             np.array([0, 0, 0]).astype(label_dtype))
-        backprop = loss.op.outputs[1]
         tf_loss, tf_backprop = sess.run([loss, backprop])
       self.assertAllClose([0.0, 0.0, 0.0], tf_loss)
       self.assertAllClose([[0.0], [0.0], [0.0]], tf_backprop)
@@ -101,7 +98,7 @@ class SparseXentTest(tf.test.TestCase):
     # With a hard 1, the backprop is [0.032 - 1.0 = -0.968, 0.087, 0.237, 0.644]
     # The loss for this batch is [1.0 * -log(0.25), 1.0 * -log(0.032)]
     # = [1.3862, 3.4420]
-    np_loss, np_backprop, _ = self._npXent(np.array(features), np.array(labels))
+    np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
     self.assertAllClose(np.array([[0.25, 0.25, 0.25, -0.75],
                                   [-0.968, 0.087, 0.237, 0.6439]]),
                         np_backprop,
@@ -169,7 +166,7 @@ class SparseXentTest(tf.test.TestCase):
     self.assertLess(err, 5e-8)
 
   def _testHighDim(self, use_gpu, features, labels):
-    np_loss, np_backprop, _ = self._npXent(np.array(features), np.array(labels))
+    np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
     # manually reshape loss
     np_loss = np.reshape(np_loss, np.array(labels).shape)
     with self.test_session(use_gpu=use_gpu) as sess:
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index 9e7c563547..70b3bdcbb4 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -21,6 +21,8 @@ from __future__ import print_function
 import numpy as np
 import tensorflow as tf
 
+from tensorflow.python.ops import gen_nn_ops
+
 
 class XentTest(tf.test.TestCase):
 
@@ -38,8 +40,8 @@ class XentTest(tf.test.TestCase):
   def _testXent(self, np_features, np_labels, use_gpu=False):
     np_loss, np_backprop = self._npXent(np_features, np_labels)
     with self.test_session(use_gpu=use_gpu) as sess:
-      loss = tf.nn.softmax_cross_entropy_with_logits(np_features, np_labels)
-      backprop = loss.op.outputs[1]
+      loss, backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
+          np_features, np_labels)
       tf_loss, tf_backprop = sess.run([loss, backprop])
     self.assertAllCloseAccordingToType(np_loss, tf_loss)
     self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
@@ -51,10 +53,9 @@ class XentTest(tf.test.TestCase):
   def _testSingleClass(self, use_gpu=False):
     for dtype in np.float16, np.float32:
       with self.test_session(use_gpu=use_gpu) as sess:
-        loss = tf.nn.softmax_cross_entropy_with_logits(
+        loss, backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
             np.array([[1.], [-1.], [0.]]).astype(dtype),
             np.array([[-1.], [0.], [1.]]).astype(dtype))
-        backprop = loss.op.outputs[1]
         tf_loss, tf_backprop = sess.run([loss, backprop])
       self.assertAllClose([0.0, 0.0, 0.0], tf_loss)
       self.assertAllClose([[2.0], [1.0], [0.0]], tf_backprop)
@@ -69,9 +70,9 @@
         [[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]).astype(dtype)
     np_labels = np.array(
         [[[0., 0., 0., 1.]], [[0., .5, .5, 0.]]]).astype(dtype)
-    self.assertRaisesRegexp(
-        ValueError, "must have rank 2",
-        tf.nn.softmax_cross_entropy_with_logits, np_features, np_labels)
+    self.assertRaisesRegexp(ValueError, "must have rank 2",
+                            gen_nn_ops._softmax_cross_entropy_with_logits,
+                            np_features, np_labels)
 
   def testNpXent(self):
     # We create 2 batches of logits for testing.
@@ -110,14 +111,14 @@ class XentTest(tf.test.TestCase):
   def testShapeMismatch(self):
     with self.test_session():
       with self.assertRaises(ValueError):
-        tf.nn.softmax_cross_entropy_with_logits(
+        gen_nn_ops._softmax_cross_entropy_with_logits(
             [[0., 1.], [2., 3.]], [[0., 1., 0.], [1., 0., 0.]])
 
   def testNotMatrix(self):
     with self.test_session():
       with self.assertRaises(ValueError):
-        tf.nn.softmax_cross_entropy_with_logits([0., 1., 2., 3.],
-                                                [0., 1., 0., 1.])
+        gen_nn_ops._softmax_cross_entropy_with_logits([0., 1., 2., 3.],
+                                                      [0., 1., 0., 1.])
 
   def testHalf(self):
     self._testAll(
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 486ef6efb9..73e51aab7d 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -466,7 +466,7 @@ def softmax_cross_entropy_with_logits(logits, labels, name=None):
   output of `softmax`, as it will produce incorrect results.
 
   `logits` and `labels` must have the same shape `[batch_size, num_classes]`
-  and the same dtype (either `float32` or `float64`).
+  and the same dtype (either `float16`, `float32`, or `float64`).
 
   Args:
     logits: Unscaled log probabilities.
@@ -481,11 +481,18 @@ def softmax_cross_entropy_with_logits(logits, labels, name=None):
   # could break users who call this with bad labels, but disregard the bad
   # results.
 
+  logits = ops.convert_to_tensor(logits)
+  precise_logits = math_ops.cast(logits, dtypes.float32) if (
+      logits.dtype == dtypes.float16) else logits
+
   # The second output tensor contains the gradients.  We use it in
   # _CrossEntropyGrad() in nn_grad but not here.
   cost, unused_backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
-      logits, labels, name=name)
-  return cost
+      precise_logits, labels, name=name)
+  if logits.dtype == dtypes.float16:
+    return math_ops.cast(cost, dtypes.float16)
+  else:
+    return cost
 
 
 def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
@@ -536,6 +543,8 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
                       "SparseSoftmaxCrossEntropyWithLogits"):
     labels = ops.convert_to_tensor(labels)
     logits = ops.convert_to_tensor(logits)
+    precise_logits = math_ops.cast(logits, dtypes.float32) if (
+        dtypes.as_dtype(logits.dtype) == dtypes.float16) else logits
 
     # Store label shape for result later.
     labels_static_shape = labels.get_shape()
@@ -552,20 +561,27 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
     # Check if no reshapes are required.
     if logits.get_shape().ndims == 2:
       cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
-          logits, labels, name=name)
-      return cost
+          precise_logits, labels, name=name)
+      if logits.dtype == dtypes.float16:
+        return math_ops.cast(cost, dtypes.float16)
+      else:
+        return cost
+
     # Reshape logits to 2 dim, labels to 1 dim.
     num_classes = array_ops.gather(array_ops.shape(logits),
                                    array_ops.rank(logits) - 1)
-    logits = array_ops.reshape(logits, [-1, num_classes])
+    precise_logits = array_ops.reshape(precise_logits, [-1, num_classes])
     labels = array_ops.reshape(labels, [-1])
     # The second output tensor contains the gradients.  We use it in
     # _CrossEntropyGrad() in nn_grad but not here.
     cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
-        logits, labels, name=name)
+        precise_logits, labels, name=name)
     cost = array_ops.reshape(cost, labels_shape)
     cost.set_shape(labels_static_shape)
-    return cost
+    if logits.dtype == dtypes.float16:
+      return math_ops.cast(cost, dtypes.float16)
+    else:
+      return cost
 
 
 @ops.RegisterShape("SparseSoftmaxCrossEntropyWithLogits")
diff --git a/tensorflow/python/ops/nn_xent_test.py b/tensorflow/python/ops/nn_xent_test.py
index c09d6141bb..10c9f8ca87 100644
--- a/tensorflow/python/ops/nn_xent_test.py
+++ b/tensorflow/python/ops/nn_xent_test.py
@@ -56,22 +56,23 @@ class SigmoidCrossEntropyWithLogitsTest(tf.test.TestCase):
 
   def testLogisticOutput(self):
     for use_gpu in [True, False]:
-      with self.test_session(use_gpu=use_gpu):
-        logits, targets, losses = self._Inputs(dtype=tf.float32)
-        loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets)
-        np_loss = np.array(losses).astype(np.float32)
-        tf_loss = loss.eval()
-      self.assertAllClose(np_loss, tf_loss, atol=0.001)
+      for dtype in [tf.float32, tf.float16]:
+        with self.test_session(use_gpu=use_gpu):
+          logits, targets, losses = self._Inputs(dtype=dtype)
+          loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets)
+          np_loss = np.array(losses).astype(np.float32)
+          tf_loss = loss.eval()
+        self.assertAllClose(np_loss, tf_loss, atol=0.001)
 
   def testLogisticOutputMultiDim(self):
     for use_gpu in [True, False]:
-      with self.test_session(use_gpu=use_gpu):
-        logits, targets, losses = self._Inputs(dtype=tf.float32,
-                                               sizes=[2, 2, 2])
-        loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets)
-        np_loss = np.array(losses).astype(np.float32)
-        tf_loss = loss.eval()
-      self.assertAllClose(np_loss, tf_loss, atol=0.001)
+      for dtype in [tf.float32, tf.float16]:
+        with self.test_session(use_gpu=use_gpu):
+          logits, targets, losses = self._Inputs(dtype=dtype, sizes=[2, 2, 2])
+          loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets)
+          np_loss = np.array(losses).astype(np.float32)
+          tf_loss = loss.eval()
+        self.assertAllClose(np_loss, tf_loss, atol=0.001)
 
   def testGradient(self):
     sizes = [4, 2]
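
For completeness, a hedged usage sketch against the TF 0.9-era API this diff
targets (graph construction plus an explicit session, as in the tests above):
after the change, float16 logits are accepted by the public op, the loss is
computed internally in float32, and the result is cast back to float16.

import numpy as np
import tensorflow as tf

logits = np.random.randn(8, 1000).astype(np.float16)  # many classes
labels = np.random.randint(1000, size=8).astype(np.int64)

# Internally: cast to float32, compute the cross entropy, cast the loss back.
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels)

with tf.Session() as sess:
    print(sess.run(loss).dtype)  # float16, accumulated in float32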