aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/listdiff_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/listdiff_op.cc')
-rw-r--r--tensorflow/core/kernels/listdiff_op.cc16
1 files changed, 11 insertions, 5 deletions
diff --git a/tensorflow/core/kernels/listdiff_op.cc b/tensorflow/core/kernels/listdiff_op.cc
index d303bdd560..d28a2729d4 100644
--- a/tensorflow/core/kernels/listdiff_op.cc
+++ b/tensorflow/core/kernels/listdiff_op.cc
@@ -24,12 +24,13 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
-template <typename T>
+template <typename T, typename Tidx>
class ListDiffOp : public OpKernel {
public:
explicit ListDiffOp(OpKernelConstruction* context) : OpKernel(context) {
const DataType dt = DataTypeToEnum<T>::v();
- OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, DT_INT32}));
+ const DataType dtidx = DataTypeToEnum<Tidx>::v();
+ OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, dtidx}));
}
void Compute(OpKernelContext* context) override {
@@ -72,9 +73,9 @@ class ListDiffOp : public OpKernel {
Tensor* indices = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(1, {out_size}, &indices));
- auto Tindices = indices->vec<int32>();
+ auto Tindices = indices->vec<Tidx>();
- for (int i = 0, p = 0; i < static_cast<int32>(x_size); ++i) {
+ for (Tidx i = 0, p = 0; i < static_cast<Tidx>(x_size); ++i) {
if (y_set.count(Tx(i)) == 0) {
OP_REQUIRES(context, p < out_size,
errors::InvalidArgument(
@@ -95,7 +96,12 @@ class ListDiffOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("out_idx"), \
- ListDiffOp<type>)
+ ListDiffOp<type, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("ListDiff") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_idx"), \
+ ListDiffOp<type, int64>)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_LISTDIFF);
REGISTER_LISTDIFF(string);