From b1b5c2e74d9a3e54d2e84279db94027060e20609 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Fri, 1 Dec 2017 11:11:11 +0000 Subject: Add test cases for int64 support of unravel_index Signed-off-by: Yong Tang --- tensorflow/python/kernel_tests/array_ops_test.py | 25 ++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 2716c4a51f..68b7c3a98a 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -1118,16 +1118,21 @@ class UnravelIndexTest(test_util.TensorFlowTestCase): def testUnravelIndex(self): with self.test_session(): - out_1 = array_ops.unravel_index(1621, [6, 7, 8, 9]) - self.assertAllEqual(out_1.eval(), [3, 1, 4, 1]) - out_2 = array_ops.unravel_index([1621], [6, 7, 8, 9]) - self.assertAllEqual(out_2.eval(), [[3], - [1], - [4], - [1]]) - out_3 = array_ops.unravel_index([22, 41, 37], [7, 6]) - self.assertAllEqual(out_3.eval(), [[3, 6, 6], - [4, 5, 1]]) + for dtype in [dtypes.int32, dtypes.int64]: + indices_1 = constant_op.constant(1621, dtype=dtype) + dims_1 = constant_op.constant([6, 7, 8, 9], dtype=dtype) + out_1 = array_ops.unravel_index(indices_1, dims_1) + self.assertAllEqual(out_1.eval(), [3, 1, 4, 1]) + + indices_2 = constant_op.constant([1621], dtype=dtype) + dims_2 = constant_op.constant([6, 7, 8, 9], dtype=dtype) + out_2 = array_ops.unravel_index(indices_2, dims_2) + self.assertAllEqual(out_2.eval(), [[3], [1], [4], [1]]) + + indices_3 = constant_op.constant([22, 41, 37], dtype=dtype) + dims_3 = constant_op.constant([7, 6], dtype=dtype) + out_3 = array_ops.unravel_index(indices_3, dims_3) + self.assertAllEqual(out_3.eval(), [[3, 6, 6], [4, 5, 1]]) class GuaranteeConstOpTest(test_util.TensorFlowTestCase): -- cgit v1.2.3