diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/listdiff_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/listdiff_op_test.py | 21 |
1 files changed, 10 insertions, 11 deletions
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py index 026bbbead1..918ebf05ec 100644 --- a/tensorflow/python/kernel_tests/listdiff_op_test.py +++ b/tensorflow/python/kernel_tests/listdiff_op_test.py @@ -34,17 +34,16 @@ class ListDiffTest(tf.test.TestCase): x = [tf.compat.as_bytes(str(a)) for a in x] y = [tf.compat.as_bytes(str(a)) for a in y] out = [tf.compat.as_bytes(str(a)) for a in out] - - with self.test_session() as sess: - x_tensor = tf.convert_to_tensor(x, dtype=dtype) - y_tensor = tf.convert_to_tensor(y, dtype=dtype) - out_tensor, idx_tensor = tf.listdiff(x_tensor, y_tensor) - tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) - - self.assertAllEqual(tf_out, out) - self.assertAllEqual(tf_idx, idx) - self.assertEqual(1, out_tensor.get_shape().ndims) - self.assertEqual(1, idx_tensor.get_shape().ndims) + for diff_func in [tf.listdiff, tf.setdiff1d]: + with self.test_session() as sess: + x_tensor = tf.convert_to_tensor(x, dtype=dtype) + y_tensor = tf.convert_to_tensor(y, dtype=dtype) + out_tensor, idx_tensor = diff_func(x_tensor, y_tensor) + tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) + self.assertAllEqual(tf_out, out) + self.assertAllEqual(tf_idx, idx) + self.assertEqual(1, out_tensor.get_shape().ndims) + self.assertEqual(1, idx_tensor.get_shape().ndims) def testBasic1(self): x = [1, 2, 3, 4] |