diff options
-rw-r--r-- | tensorflow/core/kernels/listdiff_op.cc | 16 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/listdiff_op_test.py | 20 |
2 files changed, 22 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/listdiff_op.cc b/tensorflow/core/kernels/listdiff_op.cc index d303bdd560..d28a2729d4 100644 --- a/tensorflow/core/kernels/listdiff_op.cc +++ b/tensorflow/core/kernels/listdiff_op.cc @@ -24,12 +24,13 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -template <typename T> +template <typename T, typename Tidx> class ListDiffOp : public OpKernel { public: explicit ListDiffOp(OpKernelConstruction* context) : OpKernel(context) { const DataType dt = DataTypeToEnum<T>::v(); - OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, DT_INT32})); + const DataType dtidx = DataTypeToEnum<Tidx>::v(); + OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, dtidx})); } void Compute(OpKernelContext* context) override { @@ -72,9 +73,9 @@ class ListDiffOp : public OpKernel { Tensor* indices = nullptr; OP_REQUIRES_OK(context, context->allocate_output(1, {out_size}, &indices)); - auto Tindices = indices->vec<int32>(); + auto Tindices = indices->vec<Tidx>(); - for (int i = 0, p = 0; i < static_cast<int32>(x_size); ++i) { + for (Tidx i = 0, p = 0; i < static_cast<Tidx>(x_size); ++i) { if (y_set.count(Tx(i)) == 0) { OP_REQUIRES(context, p < out_size, errors::InvalidArgument( @@ -95,7 +96,12 @@ class ListDiffOp : public OpKernel { .Device(DEVICE_CPU) \ .TypeConstraint<type>("T") \ .TypeConstraint<int32>("out_idx"), \ - ListDiffOp<type>) + ListDiffOp<type, int32>) \ + REGISTER_KERNEL_BUILDER(Name("ListDiff") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int64>("out_idx"), \ + ListDiffOp<type, int64>) TF_CALL_REAL_NUMBER_TYPES(REGISTER_LISTDIFF); REGISTER_LISTDIFF(string); diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py index 4f053d2a21..ee86cf0b24 100644 --- a/tensorflow/python/kernel_tests/listdiff_op_test.py +++ b/tensorflow/python/kernel_tests/listdiff_op_test.py @@ -41,15 +41,17 @@ class ListDiffTest(test.TestCase): y = [compat.as_bytes(str(a)) for a in y] out = [compat.as_bytes(str(a)) for a in out] for diff_func in [array_ops.setdiff1d]: - with self.test_session() as sess: - x_tensor = ops.convert_to_tensor(x, dtype=dtype) - y_tensor = ops.convert_to_tensor(y, dtype=dtype) - out_tensor, idx_tensor = diff_func(x_tensor, y_tensor) - tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) - self.assertAllEqual(tf_out, out) - self.assertAllEqual(tf_idx, idx) - self.assertEqual(1, out_tensor.get_shape().ndims) - self.assertEqual(1, idx_tensor.get_shape().ndims) + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.test_session() as sess: + x_tensor = ops.convert_to_tensor(x, dtype=dtype) + y_tensor = ops.convert_to_tensor(y, dtype=dtype) + out_tensor, idx_tensor = diff_func(x_tensor, y_tensor, + index_dtype=index_dtype) + tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) + self.assertAllEqual(tf_out, out) + self.assertAllEqual(tf_idx, idx) + self.assertEqual(1, out_tensor.get_shape().ndims) + self.assertEqual(1, idx_tensor.get_shape().ndims) def testBasic1(self): x = [1, 2, 3, 4] |