aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/listdiff_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/listdiff_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/listdiff_op_test.py117
1 files changed, 117 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py
new file mode 100644
index 0000000000..b4607be1fb
--- /dev/null
+++ b/tensorflow/python/kernel_tests/listdiff_op_test.py
@@ -0,0 +1,117 @@
+"""Tests for tensorflow.kernels.listdiff_op."""
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class ListDiffTest(tf.test.TestCase):
+
+ def _testListDiff(self, x, y, out, idx, dtype=np.int32):
+ x = np.array(x, dtype=dtype)
+ y = np.array(y, dtype=dtype)
+ out = np.array(out, dtype=dtype)
+ idx = np.array(idx, dtype=dtype)
+
+ with self.test_session() as sess:
+ x_tensor = tf.convert_to_tensor(x)
+ y_tensor = tf.convert_to_tensor(y)
+ out_tensor, idx_tensor = tf.listdiff(x_tensor, y_tensor)
+ tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+
+ self.assertAllEqual(tf_out, out)
+ self.assertAllEqual(tf_idx, idx)
+ self.assertEqual(1, out_tensor.get_shape().ndims)
+ self.assertEqual(1, idx_tensor.get_shape().ndims)
+
+ def testBasic1(self):
+ x = [1, 2, 3, 4]
+ y = [1, 2]
+ out = [3, 4]
+ idx = [2, 3]
+ for t in [np.int32, np.int64, np.float, np.double]:
+ self._testListDiff(x, y, out, idx, dtype=t)
+
+ def testBasic2(self):
+ x = [1, 2, 3, 4]
+ y = [2]
+ out = [1, 3, 4]
+ idx = [0, 2, 3]
+ for t in [np.int32, np.int64, np.float, np.double]:
+ self._testListDiff(x, y, out, idx, dtype=t)
+
+ def testBasic3(self):
+ x = [1, 4, 3, 2]
+ y = [4, 2]
+ out = [1, 3]
+ idx = [0, 2]
+ for t in [np.int32, np.int64, np.float, np.double]:
+ self._testListDiff(x, y, out, idx, dtype=t)
+
+ def testDuplicates(self):
+ x = [1, 2, 4, 3, 2, 3, 3, 1]
+ y = [4, 2]
+ out = [1, 3, 3, 3, 1]
+ idx = [0, 3, 5, 6, 7]
+ for t in [np.int32, np.int64, np.float, np.double]:
+ self._testListDiff(x, y, out, idx, dtype=t)
+
+ def testRandom(self):
+ num_random_tests = 10
+ int_low = -7
+ int_high = 8
+ max_size = 50
+ for _ in xrange(num_random_tests):
+ x_size = np.random.randint(max_size + 1)
+ x = np.random.randint(int_low, int_high, size=x_size)
+ y_size = np.random.randint(max_size + 1)
+ y = np.random.randint(int_low, int_high, size=y_size)
+ out_idx = [(entry, pos) for pos, entry in enumerate(x) if entry not in y]
+ if out_idx:
+ out_idx = map(list, zip(*out_idx))
+ out = out_idx[0]
+ idx = out_idx[1]
+ else:
+ out = []
+ idx = []
+ for t in [np.int32, np.int64, np.float, np.double]:
+ self._testListDiff(x, y, out, idx, dtype=t)
+
+ def testInt32FullyOverlapping(self):
+ x = [1, 2, 3, 4]
+ y = [1, 2, 3, 4]
+ out = []
+ idx = []
+ self._testListDiff(x, y, out, idx)
+
+ def testInt32NonOverlapping(self):
+ x = [1, 2, 3, 4]
+ y = [5, 6]
+ out = x
+ idx = range(len(x))
+ self._testListDiff(x, y, out, idx)
+
+ def testInt32EmptyX(self):
+ x = []
+ y = [1, 2]
+ out = []
+ idx = []
+ self._testListDiff(x, y, out, idx)
+
+ def testInt32EmptyY(self):
+ x = [1, 2, 3, 4]
+ y = []
+ out = x
+ idx = range(len(x))
+ self._testListDiff(x, y, out, idx)
+
+ def testInt32EmptyXY(self):
+ x = []
+ y = []
+ out = []
+ idx = []
+ self._testListDiff(x, y, out, idx)
+
+if __name__ == "__main__":
+ tf.test.main()