aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/listdiff_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/listdiff_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/listdiff_op_test.py21
1 files changed, 10 insertions, 11 deletions
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py
index 026bbbead1..918ebf05ec 100644
--- a/tensorflow/python/kernel_tests/listdiff_op_test.py
+++ b/tensorflow/python/kernel_tests/listdiff_op_test.py
@@ -34,17 +34,16 @@ class ListDiffTest(tf.test.TestCase):
x = [tf.compat.as_bytes(str(a)) for a in x]
y = [tf.compat.as_bytes(str(a)) for a in y]
out = [tf.compat.as_bytes(str(a)) for a in out]
-
- with self.test_session() as sess:
- x_tensor = tf.convert_to_tensor(x, dtype=dtype)
- y_tensor = tf.convert_to_tensor(y, dtype=dtype)
- out_tensor, idx_tensor = tf.listdiff(x_tensor, y_tensor)
- tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
-
- self.assertAllEqual(tf_out, out)
- self.assertAllEqual(tf_idx, idx)
- self.assertEqual(1, out_tensor.get_shape().ndims)
- self.assertEqual(1, idx_tensor.get_shape().ndims)
+ for diff_func in [tf.listdiff, tf.setdiff1d]:
+ with self.test_session() as sess:
+ x_tensor = tf.convert_to_tensor(x, dtype=dtype)
+ y_tensor = tf.convert_to_tensor(y, dtype=dtype)
+ out_tensor, idx_tensor = diff_func(x_tensor, y_tensor)
+ tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+ self.assertAllEqual(tf_out, out)
+ self.assertAllEqual(tf_idx, idx)
+ self.assertEqual(1, out_tensor.get_shape().ndims)
+ self.assertEqual(1, idx_tensor.get_shape().ndims)
def testBasic1(self):
x = [1, 2, 3, 4]