diff options
Diffstat (limited to 'tensorflow/core/kernels/listdiff_op.cc')
-rw-r--r-- | tensorflow/core/kernels/listdiff_op.cc | 16 |
1 files changed, 11 insertions, 5 deletions
diff --git a/tensorflow/core/kernels/listdiff_op.cc b/tensorflow/core/kernels/listdiff_op.cc index d303bdd560..d28a2729d4 100644 --- a/tensorflow/core/kernels/listdiff_op.cc +++ b/tensorflow/core/kernels/listdiff_op.cc @@ -24,12 +24,13 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -template <typename T> +template <typename T, typename Tidx> class ListDiffOp : public OpKernel { public: explicit ListDiffOp(OpKernelConstruction* context) : OpKernel(context) { const DataType dt = DataTypeToEnum<T>::v(); - OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, DT_INT32})); + const DataType dtidx = DataTypeToEnum<Tidx>::v(); + OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, dtidx})); } void Compute(OpKernelContext* context) override { @@ -72,9 +73,9 @@ class ListDiffOp : public OpKernel { Tensor* indices = nullptr; OP_REQUIRES_OK(context, context->allocate_output(1, {out_size}, &indices)); - auto Tindices = indices->vec<int32>(); + auto Tindices = indices->vec<Tidx>(); - for (int i = 0, p = 0; i < static_cast<int32>(x_size); ++i) { + for (Tidx i = 0, p = 0; i < static_cast<Tidx>(x_size); ++i) { if (y_set.count(Tx(i)) == 0) { OP_REQUIRES(context, p < out_size, errors::InvalidArgument( @@ -95,7 +96,12 @@ class ListDiffOp : public OpKernel { .Device(DEVICE_CPU) \ .TypeConstraint<type>("T") \ .TypeConstraint<int32>("out_idx"), \ - ListDiffOp<type>) + ListDiffOp<type, int32>) \ + REGISTER_KERNEL_BUILDER(Name("ListDiff") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int64>("out_idx"), \ + ListDiffOp<type, int64>) TF_CALL_REAL_NUMBER_TYPES(REGISTER_LISTDIFF); REGISTER_LISTDIFF(string); |