aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/unique_op_test.py
diff options
context:
space:
mode:
authorGravatar Andrew Dai <adai@google.com>2016-01-15 18:07:08 -0800
committerGravatar Manjunath Kudlur <keveman@gmail.com>2016-01-15 18:25:03 -0800
commite75bd059e5f4c76e2e0fd7d5220962e99cd1d960 (patch)
tree03f0e0576b542055a27623b94ba63127e0ba839f /tensorflow/python/kernel_tests/unique_op_test.py
parent0cc8848a52ac4131db0b3884bc52c06f0039969c (diff)
Add the UniqueWithCounts op. This is the same as the Unique op but has an additional output that returns the number of duplicate values.
Change: 112301474
Diffstat (limited to 'tensorflow/python/kernel_tests/unique_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/unique_op_test.py19
1 files changed, 18 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/unique_op_test.py b/tensorflow/python/kernel_tests/unique_op_test.py
index e576565b5b..c0b1a3ebbb 100644
--- a/tensorflow/python/kernel_tests/unique_op_test.py
+++ b/tensorflow/python/kernel_tests/unique_op_test.py
@@ -27,7 +27,7 @@ import tensorflow as tf
class UniqueTest(tf.test.TestCase):
def testInt32(self):
- x = list(np.random.randint(2, high=10, size=7000))
+ x = np.random.randint(2, high=10, size=7000)
with self.test_session() as sess:
y, idx = tf.unique(x)
tf_y, tf_idx = sess.run([y, idx])
@@ -37,5 +37,22 @@ class UniqueTest(tf.test.TestCase):
for i in range(len(x)):
self.assertEqual(x[i], tf_y[tf_idx[i]])
+
+class UniqueWithCountsTest(tf.test.TestCase):
+
+ def testInt32(self):
+ x = np.random.randint(2, high=10, size=7000)
+ with self.test_session() as sess:
+ y, idx, count = tf.unique_with_counts(x)
+ tf_y, tf_idx, tf_count = sess.run([y, idx, count])
+
+ self.assertEqual(len(x), len(tf_idx))
+ self.assertEqual(len(tf_y), len(np.unique(x)))
+ for i in range(len(x)):
+ self.assertEqual(x[i], tf_y[tf_idx[i]])
+ for value, count in zip(tf_y, tf_count):
+ self.assertEqual(count, np.sum(x == value))
+
+
if __name__ == "__main__":
tf.test.main()