aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yong Tang <yong.tang.github@outlook.com>2017-12-01 11:11:11 +0000
committerGravatar Yong Tang <yong.tang.github@outlook.com>2018-01-29 17:10:43 +0000
commitb1b5c2e74d9a3e54d2e84279db94027060e20609 (patch)
tree8ac1845b10740849a5227cee12e77af671521dfb
parentbabc9bae71d31e35b8d66a715fcd527ee1ee645a (diff)
Add test cases for int64 support of unravel_index
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py25
1 files changed, 15 insertions, 10 deletions
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 2716c4a51f..68b7c3a98a 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -1118,16 +1118,21 @@ class UnravelIndexTest(test_util.TensorFlowTestCase):
def testUnravelIndex(self):
with self.test_session():
- out_1 = array_ops.unravel_index(1621, [6, 7, 8, 9])
- self.assertAllEqual(out_1.eval(), [3, 1, 4, 1])
- out_2 = array_ops.unravel_index([1621], [6, 7, 8, 9])
- self.assertAllEqual(out_2.eval(), [[3],
- [1],
- [4],
- [1]])
- out_3 = array_ops.unravel_index([22, 41, 37], [7, 6])
- self.assertAllEqual(out_3.eval(), [[3, 6, 6],
- [4, 5, 1]])
+ for dtype in [dtypes.int32, dtypes.int64]:
+ indices_1 = constant_op.constant(1621, dtype=dtype)
+ dims_1 = constant_op.constant([6, 7, 8, 9], dtype=dtype)
+ out_1 = array_ops.unravel_index(indices_1, dims_1)
+ self.assertAllEqual(out_1.eval(), [3, 1, 4, 1])
+
+ indices_2 = constant_op.constant([1621], dtype=dtype)
+ dims_2 = constant_op.constant([6, 7, 8, 9], dtype=dtype)
+ out_2 = array_ops.unravel_index(indices_2, dims_2)
+ self.assertAllEqual(out_2.eval(), [[3], [1], [4], [1]])
+
+ indices_3 = constant_op.constant([22, 41, 37], dtype=dtype)
+ dims_3 = constant_op.constant([7, 6], dtype=dtype)
+ out_3 = array_ops.unravel_index(indices_3, dims_3)
+ self.assertAllEqual(out_3.eval(), [[3, 6, 6], [4, 5, 1]])
class GuaranteeConstOpTest(test_util.TensorFlowTestCase):