aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/listdiff_op_test.py
diff options
context:
space:
mode:
authorGravatar Yong Tang <yong.tang.github@outlook.com>2017-10-19 20:34:53 -0700
committerGravatar Vijay Vasudevan <vrv@google.com>2017-10-19 20:34:53 -0700
commit513f7df42e4eadfcd241a3be695af6fd426b734e (patch)
tree032f0e5ff3e78c16a625fa37dff361ba0c59d663 /tensorflow/python/kernel_tests/listdiff_op_test.py
parent7a1ddf26aed9166af69a560e644abd3f0d4f8ecf (diff)
Add `int64` out_idx` support for `listdiff`/`list_diff`/`setdiff1d` (#13839)
* Add `int64` out_idx` support for `listdiff`/`list_diff`/`setdiff1d` This fix tries to add `int64` `out_idx` support for `listdiff`/`list_diff`/`setdiff1d`. As was specified in docs (`tf.setdiff1d.__doc__`), it is possible to specify `tf.int32` or `tf.int64` for the type of the output idx. However, the `tf.int64` kernel has not been registered. As a consequence, an error will be thrown out if `tf.int64` is used. This fix adds `int64` out_idx` support for `listdiff`/`list_diff`/`setdiff1d` Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add template for signature matching of ListDiff kernel. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test cases for `int64` out_idx support for `tf.listdiff`/`setdiff1d` Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test case for int32 (missed in the last commit) Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/python/kernel_tests/listdiff_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/listdiff_op_test.py20
1 files changed, 11 insertions, 9 deletions
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py
index 4f053d2a21..ee86cf0b24 100644
--- a/tensorflow/python/kernel_tests/listdiff_op_test.py
+++ b/tensorflow/python/kernel_tests/listdiff_op_test.py
@@ -41,15 +41,17 @@ class ListDiffTest(test.TestCase):
y = [compat.as_bytes(str(a)) for a in y]
out = [compat.as_bytes(str(a)) for a in out]
for diff_func in [array_ops.setdiff1d]:
- with self.test_session() as sess:
- x_tensor = ops.convert_to_tensor(x, dtype=dtype)
- y_tensor = ops.convert_to_tensor(y, dtype=dtype)
- out_tensor, idx_tensor = diff_func(x_tensor, y_tensor)
- tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
- self.assertAllEqual(tf_out, out)
- self.assertAllEqual(tf_idx, idx)
- self.assertEqual(1, out_tensor.get_shape().ndims)
- self.assertEqual(1, idx_tensor.get_shape().ndims)
+ for index_dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session() as sess:
+ x_tensor = ops.convert_to_tensor(x, dtype=dtype)
+ y_tensor = ops.convert_to_tensor(y, dtype=dtype)
+ out_tensor, idx_tensor = diff_func(x_tensor, y_tensor,
+ index_dtype=index_dtype)
+ tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+ self.assertAllEqual(tf_out, out)
+ self.assertAllEqual(tf_idx, idx)
+ self.assertEqual(1, out_tensor.get_shape().ndims)
+ self.assertEqual(1, idx_tensor.get_shape().ndims)
def testBasic1(self):
x = [1, 2, 3, 4]