diff options
author | Yong Tang <yong.tang.github@outlook.com> | 2017-10-19 20:34:53 -0700 |
---|---|---|
committer | Vijay Vasudevan <vrv@google.com> | 2017-10-19 20:34:53 -0700 |
commit | 513f7df42e4eadfcd241a3be695af6fd426b734e (patch) | |
tree | 032f0e5ff3e78c16a625fa37dff361ba0c59d663 /tensorflow/python/kernel_tests/listdiff_op_test.py | |
parent | 7a1ddf26aed9166af69a560e644abd3f0d4f8ecf (diff) |
Add `int64` out_idx` support for `listdiff`/`list_diff`/`setdiff1d` (#13839)
* Add `int64` out_idx` support for `listdiff`/`list_diff`/`setdiff1d`
This fix tries to add `int64` `out_idx` support for `listdiff`/`list_diff`/`setdiff1d`.
As was specified in docs (`tf.setdiff1d.__doc__`), it is possible to specify
`tf.int32` or `tf.int64` for the type of the output idx. However,
the `tf.int64` kernel has not been registered. As a consequence,
an error will be thrown out if `tf.int64` is used.
This fix adds `int64` out_idx` support for `listdiff`/`list_diff`/`setdiff1d`
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add template for signature matching of ListDiff kernel.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add test cases for `int64` out_idx support for `tf.listdiff`/`setdiff1d`
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add test case for int32 (missed in the last commit)
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/python/kernel_tests/listdiff_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/listdiff_op_test.py | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py index 4f053d2a21..ee86cf0b24 100644 --- a/tensorflow/python/kernel_tests/listdiff_op_test.py +++ b/tensorflow/python/kernel_tests/listdiff_op_test.py @@ -41,15 +41,17 @@ class ListDiffTest(test.TestCase): y = [compat.as_bytes(str(a)) for a in y] out = [compat.as_bytes(str(a)) for a in out] for diff_func in [array_ops.setdiff1d]: - with self.test_session() as sess: - x_tensor = ops.convert_to_tensor(x, dtype=dtype) - y_tensor = ops.convert_to_tensor(y, dtype=dtype) - out_tensor, idx_tensor = diff_func(x_tensor, y_tensor) - tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) - self.assertAllEqual(tf_out, out) - self.assertAllEqual(tf_idx, idx) - self.assertEqual(1, out_tensor.get_shape().ndims) - self.assertEqual(1, idx_tensor.get_shape().ndims) + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.test_session() as sess: + x_tensor = ops.convert_to_tensor(x, dtype=dtype) + y_tensor = ops.convert_to_tensor(y, dtype=dtype) + out_tensor, idx_tensor = diff_func(x_tensor, y_tensor, + index_dtype=index_dtype) + tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) + self.assertAllEqual(tf_out, out) + self.assertAllEqual(tf_idx, idx) + self.assertEqual(1, out_tensor.get_shape().ndims) + self.assertEqual(1, idx_tensor.get_shape().ndims) def testBasic1(self): x = [1, 2, 3, 4] |