aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/listdiff_op.cc16
-rw-r--r--tensorflow/python/kernel_tests/listdiff_op_test.py20
2 files changed, 22 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/listdiff_op.cc b/tensorflow/core/kernels/listdiff_op.cc
index d303bdd560..d28a2729d4 100644
--- a/tensorflow/core/kernels/listdiff_op.cc
+++ b/tensorflow/core/kernels/listdiff_op.cc
@@ -24,12 +24,13 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
-template <typename T>
+template <typename T, typename Tidx>
class ListDiffOp : public OpKernel {
public:
explicit ListDiffOp(OpKernelConstruction* context) : OpKernel(context) {
const DataType dt = DataTypeToEnum<T>::v();
- OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, DT_INT32}));
+ const DataType dtidx = DataTypeToEnum<Tidx>::v();
+ OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, dtidx}));
}
void Compute(OpKernelContext* context) override {
@@ -72,9 +73,9 @@ class ListDiffOp : public OpKernel {
Tensor* indices = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(1, {out_size}, &indices));
- auto Tindices = indices->vec<int32>();
+ auto Tindices = indices->vec<Tidx>();
- for (int i = 0, p = 0; i < static_cast<int32>(x_size); ++i) {
+ for (Tidx i = 0, p = 0; i < static_cast<Tidx>(x_size); ++i) {
if (y_set.count(Tx(i)) == 0) {
OP_REQUIRES(context, p < out_size,
errors::InvalidArgument(
@@ -95,7 +96,12 @@ class ListDiffOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("out_idx"), \
- ListDiffOp<type>)
+ ListDiffOp<type, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("ListDiff") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_idx"), \
+ ListDiffOp<type, int64>)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_LISTDIFF);
REGISTER_LISTDIFF(string);
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py
index 4f053d2a21..ee86cf0b24 100644
--- a/tensorflow/python/kernel_tests/listdiff_op_test.py
+++ b/tensorflow/python/kernel_tests/listdiff_op_test.py
@@ -41,15 +41,17 @@ class ListDiffTest(test.TestCase):
y = [compat.as_bytes(str(a)) for a in y]
out = [compat.as_bytes(str(a)) for a in out]
for diff_func in [array_ops.setdiff1d]:
- with self.test_session() as sess:
- x_tensor = ops.convert_to_tensor(x, dtype=dtype)
- y_tensor = ops.convert_to_tensor(y, dtype=dtype)
- out_tensor, idx_tensor = diff_func(x_tensor, y_tensor)
- tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
- self.assertAllEqual(tf_out, out)
- self.assertAllEqual(tf_idx, idx)
- self.assertEqual(1, out_tensor.get_shape().ndims)
- self.assertEqual(1, idx_tensor.get_shape().ndims)
+ for index_dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session() as sess:
+ x_tensor = ops.convert_to_tensor(x, dtype=dtype)
+ y_tensor = ops.convert_to_tensor(y, dtype=dtype)
+ out_tensor, idx_tensor = diff_func(x_tensor, y_tensor,
+ index_dtype=index_dtype)
+ tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+ self.assertAllEqual(tf_out, out)
+ self.assertAllEqual(tf_idx, idx)
+ self.assertEqual(1, out_tensor.get_shape().ndims)
+ self.assertEqual(1, idx_tensor.get_shape().ndims)
def testBasic1(self):
x = [1, 2, 3, 4]