diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/array_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/array_ops_test.py | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index aae6d0a36e..7ec4624310 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -1162,6 +1162,27 @@ class InvertPermutationTest(test_util.TensorFlowTestCase): self.assertAllEqual(y.eval(), [2, 4, 3, 0, 1]) +class UnravelIndexTest(test_util.TensorFlowTestCase): + + def testUnravelIndex(self): + with self.test_session(): + for dtype in [dtypes.int32, dtypes.int64]: + indices_1 = constant_op.constant(1621, dtype=dtype) + dims_1 = constant_op.constant([6, 7, 8, 9], dtype=dtype) + out_1 = array_ops.unravel_index(indices_1, dims_1) + self.assertAllEqual(out_1.eval(), [3, 1, 4, 1]) + + indices_2 = constant_op.constant([1621], dtype=dtype) + dims_2 = constant_op.constant([6, 7, 8, 9], dtype=dtype) + out_2 = array_ops.unravel_index(indices_2, dims_2) + self.assertAllEqual(out_2.eval(), [[3], [1], [4], [1]]) + + indices_3 = constant_op.constant([22, 41, 37], dtype=dtype) + dims_3 = constant_op.constant([7, 6], dtype=dtype) + out_3 = array_ops.unravel_index(indices_3, dims_3) + self.assertAllEqual(out_3.eval(), [[3, 6, 6], [4, 5, 1]]) + + class GuaranteeConstOpTest(test_util.TensorFlowTestCase): def testSimple(self): |