aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/listdiff_op_test.py
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@google.com>2016-10-31 10:28:47 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-31 11:36:03 -0700
commit9ea69fcc18738d3e367b77e1ee58e9aa451eb888 (patch)
tree6f5e1f31c37e64770587b983c5beef9168209688 /tensorflow/python/kernel_tests/listdiff_op_test.py
parent3acf1df982b482c3879a84ec2f31f390537ba56a (diff)
NumPy parity work for 1.0
- tf.where now can optionally take 3 arguments to create a ternary select behavior. tf.select will be eventually deprecated - tf.setdiff1d added in preparation to deprecate tf.listdiff - tf.pad now accepts lower or upper case constants (consatnts are now case insensitive) - added tf.multiply, tf.subtract, and tf.negative to match NumPy names - tf.mul, tf.sub, tf.neg will be deprecated. Change: 137726562
Diffstat (limited to 'tensorflow/python/kernel_tests/listdiff_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/listdiff_op_test.py21
1 files changed, 10 insertions, 11 deletions
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py
index 026bbbead1..918ebf05ec 100644
--- a/tensorflow/python/kernel_tests/listdiff_op_test.py
+++ b/tensorflow/python/kernel_tests/listdiff_op_test.py
@@ -34,17 +34,16 @@ class ListDiffTest(tf.test.TestCase):
x = [tf.compat.as_bytes(str(a)) for a in x]
y = [tf.compat.as_bytes(str(a)) for a in y]
out = [tf.compat.as_bytes(str(a)) for a in out]
-
- with self.test_session() as sess:
- x_tensor = tf.convert_to_tensor(x, dtype=dtype)
- y_tensor = tf.convert_to_tensor(y, dtype=dtype)
- out_tensor, idx_tensor = tf.listdiff(x_tensor, y_tensor)
- tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
-
- self.assertAllEqual(tf_out, out)
- self.assertAllEqual(tf_idx, idx)
- self.assertEqual(1, out_tensor.get_shape().ndims)
- self.assertEqual(1, idx_tensor.get_shape().ndims)
+ for diff_func in [tf.listdiff, tf.setdiff1d]:
+ with self.test_session() as sess:
+ x_tensor = tf.convert_to_tensor(x, dtype=dtype)
+ y_tensor = tf.convert_to_tensor(y, dtype=dtype)
+ out_tensor, idx_tensor = diff_func(x_tensor, y_tensor)
+ tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+ self.assertAllEqual(tf_out, out)
+ self.assertAllEqual(tf_idx, idx)
+ self.assertEqual(1, out_tensor.get_shape().ndims)
+ self.assertEqual(1, idx_tensor.get_shape().ndims)
def testBasic1(self):
x = [1, 2, 3, 4]