aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/xent_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/xent_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/xent_op_test.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index 43be08f8a1..c6c7c4e26c 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -240,6 +240,16 @@ class XentTest(test.TestCase):
self._testXentWrapper(features, labels, dim=-1, use_gpu=False)
self._testXentWrapper(features, labels, dim=-1, use_gpu=True)
+ def testZeroDimension(self):
+ features = np.zeros([0, 2, 4]).astype(np.float32)
+ labels = np.zeros([0, 2, 4]).astype(np.float32)
+ np_loss, _ = self._npXent(features, labels)
+ with self.test_session(use_gpu=True) as sess:
+ loss = nn_ops.softmax_cross_entropy_with_logits(
+ labels=labels, logits=features)
+ tf_loss = sess.run(loss)
+ self.assertAllEqual(np_loss, tf_loss)
+
if __name__ == "__main__":
test.main()