aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/unique_op_test.py
diff options
context:
space:
mode:
authorGravatar Yong Tang <yong.tang.github@outlook.com>2018-01-27 19:58:54 +0000
committerGravatar Yong Tang <yong.tang.github@outlook.com>2018-02-24 03:20:45 +0000
commita347f14c8aa14e81710c0cb33bf1a0bd23f3bcfd (patch)
tree998a2cfcb6011b35d34808aab07da42a6996dffe /tensorflow/python/kernel_tests/unique_op_test.py
parent812eac93168881c6472fc08b90bdc4a9695b3220 (diff)
Add test cases for unique_with_counts_v2
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/python/kernel_tests/unique_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/unique_op_test.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/unique_op_test.py b/tensorflow/python/kernel_tests/unique_op_test.py
index 6366d2e181..4498fd9fe9 100644
--- a/tensorflow/python/kernel_tests/unique_op_test.py
+++ b/tensorflow/python/kernel_tests/unique_op_test.py
@@ -133,6 +133,39 @@ class UniqueWithCountsTest(test.TestCase):
v = [1 if x[i] == value.decode('ascii') else 0 for i in range(7000)]
self.assertEqual(count, sum(v))
+ def testInt32Axis(self):
+ for dtype in [np.int32, np.int64]:
+ x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
+ with self.test_session() as sess:
+ y0, idx0, count0 = gen_array_ops._unique_with_counts_v2(
+ x, axis=np.array([0], dtype))
+ tf_y0, tf_idx0, tf_count0 = sess.run([y0, idx0, count0])
+ y1, idx1, count1 = gen_array_ops._unique_with_counts_v2(
+ x, axis=np.array([1], dtype))
+ tf_y1, tf_idx1, tf_count1 = sess.run([y1, idx1, count1])
+ self.assertAllEqual(tf_y0, np.array([[1, 0, 0], [2, 0, 0]]))
+ self.assertAllEqual(tf_idx0, np.array([0, 0, 1]))
+ self.assertAllEqual(tf_count0, np.array([2, 1]))
+ self.assertAllEqual(tf_y1, np.array([[1, 0], [1, 0], [2, 0]]))
+ self.assertAllEqual(tf_idx1, np.array([0, 1, 1]))
+ self.assertAllEqual(tf_count1, np.array([1, 2]))
+
+ def testInt32V2(self):
+ # This test is only temporary, once V2 is used
+ # by default, the axis will be wrapped to allow `axis=None`.
+ x = np.random.randint(2, high=10, size=7000)
+ with self.test_session() as sess:
+ y, idx, count = gen_array_ops._unique_with_counts_v2(
+ x, axis=np.array([], np.int32))
+ tf_y, tf_idx, tf_count = sess.run([y, idx, count])
+
+ self.assertEqual(len(x), len(tf_idx))
+ self.assertEqual(len(tf_y), len(np.unique(x)))
+ for i in range(len(x)):
+ self.assertEqual(x[i], tf_y[tf_idx[i]])
+ for value, count in zip(tf_y, tf_count):
+ self.assertEqual(count, np.sum(x == value))
+
if __name__ == '__main__':
test.main()