aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yong Tang <yong.tang.github@outlook.com>2017-10-19 20:34:53 -0700
committerGravatar Vijay Vasudevan <vrv@google.com>2017-10-19 20:34:53 -0700
commit513f7df42e4eadfcd241a3be695af6fd426b734e (patch)
tree032f0e5ff3e78c16a625fa37dff361ba0c59d663
parent7a1ddf26aed9166af69a560e644abd3f0d4f8ecf (diff)
Add `int64` out_idx` support for `listdiff`/`list_diff`/`setdiff1d` (#13839)
* Add `int64` out_idx` support for `listdiff`/`list_diff`/`setdiff1d` This fix tries to add `int64` `out_idx` support for `listdiff`/`list_diff`/`setdiff1d`. As was specified in docs (`tf.setdiff1d.__doc__`), it is possible to specify `tf.int32` or `tf.int64` for the type of the output idx. However, the `tf.int64` kernel has not been registered. As a consequence, an error will be thrown out if `tf.int64` is used. This fix adds `int64` out_idx` support for `listdiff`/`list_diff`/`setdiff1d` Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add template for signature matching of ListDiff kernel. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test cases for `int64` out_idx support for `tf.listdiff`/`setdiff1d` Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test case for int32 (missed in the last commit) Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
-rw-r--r--tensorflow/core/kernels/listdiff_op.cc16
-rw-r--r--tensorflow/python/kernel_tests/listdiff_op_test.py20
2 files changed, 22 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/listdiff_op.cc b/tensorflow/core/kernels/listdiff_op.cc
index d303bdd560..d28a2729d4 100644
--- a/tensorflow/core/kernels/listdiff_op.cc
+++ b/tensorflow/core/kernels/listdiff_op.cc
@@ -24,12 +24,13 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
-template <typename T>
+template <typename T, typename Tidx>
class ListDiffOp : public OpKernel {
public:
explicit ListDiffOp(OpKernelConstruction* context) : OpKernel(context) {
const DataType dt = DataTypeToEnum<T>::v();
- OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, DT_INT32}));
+ const DataType dtidx = DataTypeToEnum<Tidx>::v();
+ OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, dtidx}));
}
void Compute(OpKernelContext* context) override {
@@ -72,9 +73,9 @@ class ListDiffOp : public OpKernel {
Tensor* indices = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(1, {out_size}, &indices));
- auto Tindices = indices->vec<int32>();
+ auto Tindices = indices->vec<Tidx>();
- for (int i = 0, p = 0; i < static_cast<int32>(x_size); ++i) {
+ for (Tidx i = 0, p = 0; i < static_cast<Tidx>(x_size); ++i) {
if (y_set.count(Tx(i)) == 0) {
OP_REQUIRES(context, p < out_size,
errors::InvalidArgument(
@@ -95,7 +96,12 @@ class ListDiffOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("out_idx"), \
- ListDiffOp<type>)
+ ListDiffOp<type, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("ListDiff") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_idx"), \
+ ListDiffOp<type, int64>)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_LISTDIFF);
REGISTER_LISTDIFF(string);
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py
index 4f053d2a21..ee86cf0b24 100644
--- a/tensorflow/python/kernel_tests/listdiff_op_test.py
+++ b/tensorflow/python/kernel_tests/listdiff_op_test.py
@@ -41,15 +41,17 @@ class ListDiffTest(test.TestCase):
y = [compat.as_bytes(str(a)) for a in y]
out = [compat.as_bytes(str(a)) for a in out]
for diff_func in [array_ops.setdiff1d]:
- with self.test_session() as sess:
- x_tensor = ops.convert_to_tensor(x, dtype=dtype)
- y_tensor = ops.convert_to_tensor(y, dtype=dtype)
- out_tensor, idx_tensor = diff_func(x_tensor, y_tensor)
- tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
- self.assertAllEqual(tf_out, out)
- self.assertAllEqual(tf_idx, idx)
- self.assertEqual(1, out_tensor.get_shape().ndims)
- self.assertEqual(1, idx_tensor.get_shape().ndims)
+ for index_dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session() as sess:
+ x_tensor = ops.convert_to_tensor(x, dtype=dtype)
+ y_tensor = ops.convert_to_tensor(y, dtype=dtype)
+ out_tensor, idx_tensor = diff_func(x_tensor, y_tensor,
+ index_dtype=index_dtype)
+ tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+ self.assertAllEqual(tf_out, out)
+ self.assertAllEqual(tf_idx, idx)
+ self.assertEqual(1, out_tensor.get_shape().ndims)
+ self.assertEqual(1, idx_tensor.get_shape().ndims)
def testBasic1(self):
x = [1, 2, 3, 4]