aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/array_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/array_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index aae6d0a36e..7ec4624310 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -1162,6 +1162,27 @@ class InvertPermutationTest(test_util.TensorFlowTestCase):
self.assertAllEqual(y.eval(), [2, 4, 3, 0, 1])
+class UnravelIndexTest(test_util.TensorFlowTestCase):
+
+ def testUnravelIndex(self):
+ with self.test_session():
+ for dtype in [dtypes.int32, dtypes.int64]:
+ indices_1 = constant_op.constant(1621, dtype=dtype)
+ dims_1 = constant_op.constant([6, 7, 8, 9], dtype=dtype)
+ out_1 = array_ops.unravel_index(indices_1, dims_1)
+ self.assertAllEqual(out_1.eval(), [3, 1, 4, 1])
+
+ indices_2 = constant_op.constant([1621], dtype=dtype)
+ dims_2 = constant_op.constant([6, 7, 8, 9], dtype=dtype)
+ out_2 = array_ops.unravel_index(indices_2, dims_2)
+ self.assertAllEqual(out_2.eval(), [[3], [1], [4], [1]])
+
+ indices_3 = constant_op.constant([22, 41, 37], dtype=dtype)
+ dims_3 = constant_op.constant([7, 6], dtype=dtype)
+ out_3 = array_ops.unravel_index(indices_3, dims_3)
+ self.assertAllEqual(out_3.eval(), [[3, 6, 6], [4, 5, 1]])
+
+
class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
def testSimple(self):