diff options
author | 2016-10-31 10:28:47 -0800 | |
---|---|---|
committer | 2016-10-31 11:36:03 -0700 | |
commit | 9ea69fcc18738d3e367b77e1ee58e9aa451eb888 (patch) | |
tree | 6f5e1f31c37e64770587b983c5beef9168209688 /tensorflow/python/kernel_tests/listdiff_op_test.py | |
parent | 3acf1df982b482c3879a84ec2f31f390537ba56a (diff) |
NumPy parity work for 1.0
- tf.where now can optionally take 3 arguments to create a ternary select
behavior. tf.select will be eventually deprecated
- tf.setdiff1d added in preparation to deprecate tf.listdiff
- tf.pad now accepts lower or upper case constants (consatnts are now
case insensitive)
- added tf.multiply, tf.subtract, and tf.negative to match NumPy names
- tf.mul, tf.sub, tf.neg will be deprecated.
Change: 137726562
Diffstat (limited to 'tensorflow/python/kernel_tests/listdiff_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/listdiff_op_test.py | 21 |
1 files changed, 10 insertions, 11 deletions
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py index 026bbbead1..918ebf05ec 100644 --- a/tensorflow/python/kernel_tests/listdiff_op_test.py +++ b/tensorflow/python/kernel_tests/listdiff_op_test.py @@ -34,17 +34,16 @@ class ListDiffTest(tf.test.TestCase): x = [tf.compat.as_bytes(str(a)) for a in x] y = [tf.compat.as_bytes(str(a)) for a in y] out = [tf.compat.as_bytes(str(a)) for a in out] - - with self.test_session() as sess: - x_tensor = tf.convert_to_tensor(x, dtype=dtype) - y_tensor = tf.convert_to_tensor(y, dtype=dtype) - out_tensor, idx_tensor = tf.listdiff(x_tensor, y_tensor) - tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) - - self.assertAllEqual(tf_out, out) - self.assertAllEqual(tf_idx, idx) - self.assertEqual(1, out_tensor.get_shape().ndims) - self.assertEqual(1, idx_tensor.get_shape().ndims) + for diff_func in [tf.listdiff, tf.setdiff1d]: + with self.test_session() as sess: + x_tensor = tf.convert_to_tensor(x, dtype=dtype) + y_tensor = tf.convert_to_tensor(y, dtype=dtype) + out_tensor, idx_tensor = diff_func(x_tensor, y_tensor) + tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) + self.assertAllEqual(tf_out, out) + self.assertAllEqual(tf_idx, idx) + self.assertEqual(1, out_tensor.get_shape().ndims) + self.assertEqual(1, idx_tensor.get_shape().ndims) def testBasic1(self): x = [1, 2, 3, 4] |