diff options
author | 2016-01-15 18:07:08 -0800 | |
---|---|---|
committer | 2016-01-15 18:25:03 -0800 | |
commit | e75bd059e5f4c76e2e0fd7d5220962e99cd1d960 (patch) | |
tree | 03f0e0576b542055a27623b94ba63127e0ba839f /tensorflow/python/kernel_tests/unique_op_test.py | |
parent | 0cc8848a52ac4131db0b3884bc52c06f0039969c (diff) |
Add the UniqueWithCounts op. This is the same as the Unique op but has an additional output that returns the number of duplicate values.
Change: 112301474
Diffstat (limited to 'tensorflow/python/kernel_tests/unique_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/unique_op_test.py | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/unique_op_test.py b/tensorflow/python/kernel_tests/unique_op_test.py index e576565b5b..c0b1a3ebbb 100644 --- a/tensorflow/python/kernel_tests/unique_op_test.py +++ b/tensorflow/python/kernel_tests/unique_op_test.py @@ -27,7 +27,7 @@ import tensorflow as tf class UniqueTest(tf.test.TestCase): def testInt32(self): - x = list(np.random.randint(2, high=10, size=7000)) + x = np.random.randint(2, high=10, size=7000) with self.test_session() as sess: y, idx = tf.unique(x) tf_y, tf_idx = sess.run([y, idx]) @@ -37,5 +37,22 @@ class UniqueTest(tf.test.TestCase): for i in range(len(x)): self.assertEqual(x[i], tf_y[tf_idx[i]]) + +class UniqueWithCountsTest(tf.test.TestCase): + + def testInt32(self): + x = np.random.randint(2, high=10, size=7000) + with self.test_session() as sess: + y, idx, count = tf.unique_with_counts(x) + tf_y, tf_idx, tf_count = sess.run([y, idx, count]) + + self.assertEqual(len(x), len(tf_idx)) + self.assertEqual(len(tf_y), len(np.unique(x))) + for i in range(len(x)): + self.assertEqual(x[i], tf_y[tf_idx[i]]) + for value, count in zip(tf_y, tf_count): + self.assertEqual(count, np.sum(x == value)) + + if __name__ == "__main__": tf.test.main() |