diff options
author | Yong Tang <yong.tang.github@outlook.com> | 2017-10-19 20:34:53 -0700 |
---|---|---|
committer | Vijay Vasudevan <vrv@google.com> | 2017-10-19 20:34:53 -0700 |
commit | 513f7df42e4eadfcd241a3be695af6fd426b734e (patch) | |
tree | 032f0e5ff3e78c16a625fa37dff361ba0c59d663 | |
parent | 7a1ddf26aed9166af69a560e644abd3f0d4f8ecf (diff) |
Add `int64` out_idx` support for `listdiff`/`list_diff`/`setdiff1d` (#13839)
* Add `int64` out_idx` support for `listdiff`/`list_diff`/`setdiff1d`
This fix tries to add `int64` `out_idx` support for `listdiff`/`list_diff`/`setdiff1d`.
As was specified in docs (`tf.setdiff1d.__doc__`), it is possible to specify
`tf.int32` or `tf.int64` for the type of the output idx. However,
the `tf.int64` kernel has not been registered. As a consequence,
an error will be thrown out if `tf.int64` is used.
This fix adds `int64` out_idx` support for `listdiff`/`list_diff`/`setdiff1d`
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add template for signature matching of ListDiff kernel.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add test cases for `int64` out_idx support for `tf.listdiff`/`setdiff1d`
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add test case for int32 (missed in the last commit)
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
-rw-r--r-- | tensorflow/core/kernels/listdiff_op.cc | 16 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/listdiff_op_test.py | 20 |
2 files changed, 22 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/listdiff_op.cc b/tensorflow/core/kernels/listdiff_op.cc index d303bdd560..d28a2729d4 100644 --- a/tensorflow/core/kernels/listdiff_op.cc +++ b/tensorflow/core/kernels/listdiff_op.cc @@ -24,12 +24,13 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -template <typename T> +template <typename T, typename Tidx> class ListDiffOp : public OpKernel { public: explicit ListDiffOp(OpKernelConstruction* context) : OpKernel(context) { const DataType dt = DataTypeToEnum<T>::v(); - OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, DT_INT32})); + const DataType dtidx = DataTypeToEnum<Tidx>::v(); + OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, dtidx})); } void Compute(OpKernelContext* context) override { @@ -72,9 +73,9 @@ class ListDiffOp : public OpKernel { Tensor* indices = nullptr; OP_REQUIRES_OK(context, context->allocate_output(1, {out_size}, &indices)); - auto Tindices = indices->vec<int32>(); + auto Tindices = indices->vec<Tidx>(); - for (int i = 0, p = 0; i < static_cast<int32>(x_size); ++i) { + for (Tidx i = 0, p = 0; i < static_cast<Tidx>(x_size); ++i) { if (y_set.count(Tx(i)) == 0) { OP_REQUIRES(context, p < out_size, errors::InvalidArgument( @@ -95,7 +96,12 @@ class ListDiffOp : public OpKernel { .Device(DEVICE_CPU) \ .TypeConstraint<type>("T") \ .TypeConstraint<int32>("out_idx"), \ - ListDiffOp<type>) + ListDiffOp<type, int32>) \ + REGISTER_KERNEL_BUILDER(Name("ListDiff") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int64>("out_idx"), \ + ListDiffOp<type, int64>) TF_CALL_REAL_NUMBER_TYPES(REGISTER_LISTDIFF); REGISTER_LISTDIFF(string); diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py index 4f053d2a21..ee86cf0b24 100644 --- a/tensorflow/python/kernel_tests/listdiff_op_test.py +++ b/tensorflow/python/kernel_tests/listdiff_op_test.py @@ -41,15 +41,17 @@ class ListDiffTest(test.TestCase): y = [compat.as_bytes(str(a)) for a in y] out = [compat.as_bytes(str(a)) for a in out] for diff_func in [array_ops.setdiff1d]: - with self.test_session() as sess: - x_tensor = ops.convert_to_tensor(x, dtype=dtype) - y_tensor = ops.convert_to_tensor(y, dtype=dtype) - out_tensor, idx_tensor = diff_func(x_tensor, y_tensor) - tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) - self.assertAllEqual(tf_out, out) - self.assertAllEqual(tf_idx, idx) - self.assertEqual(1, out_tensor.get_shape().ndims) - self.assertEqual(1, idx_tensor.get_shape().ndims) + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.test_session() as sess: + x_tensor = ops.convert_to_tensor(x, dtype=dtype) + y_tensor = ops.convert_to_tensor(y, dtype=dtype) + out_tensor, idx_tensor = diff_func(x_tensor, y_tensor, + index_dtype=index_dtype) + tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) + self.assertAllEqual(tf_out, out) + self.assertAllEqual(tf_idx, idx) + self.assertEqual(1, out_tensor.get_shape().ndims) + self.assertEqual(1, idx_tensor.get_shape().ndims) def testBasic1(self): x = [1, 2, 3, 4] |