diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/xent_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/xent_op_test.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py index 43be08f8a1..c6c7c4e26c 100644 --- a/tensorflow/python/kernel_tests/xent_op_test.py +++ b/tensorflow/python/kernel_tests/xent_op_test.py @@ -240,6 +240,16 @@ class XentTest(test.TestCase): self._testXentWrapper(features, labels, dim=-1, use_gpu=False) self._testXentWrapper(features, labels, dim=-1, use_gpu=True) + def testZeroDimension(self): + features = np.zeros([0, 2, 4]).astype(np.float32) + labels = np.zeros([0, 2, 4]).astype(np.float32) + np_loss, _ = self._npXent(features, labels) + with self.test_session(use_gpu=True) as sess: + loss = nn_ops.softmax_cross_entropy_with_logits( + labels=labels, logits=features) + tf_loss = sess.run(loss) + self.assertAllEqual(np_loss, tf_loss) + if __name__ == "__main__": test.main() |