aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/reverse_op.cc60
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py36
2 files changed, 69 insertions, 27 deletions
diff --git a/tensorflow/core/kernels/reverse_op.cc b/tensorflow/core/kernels/reverse_op.cc
index 7ac34d1c62..8f82784d93 100644
--- a/tensorflow/core/kernels/reverse_op.cc
+++ b/tensorflow/core/kernels/reverse_op.cc
@@ -182,9 +182,9 @@ class ReverseOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
-#define HANDLE_REVERSE(NDIMS) \
- case NDIMS: \
- HandleReverseCase<Device, T, NDIMS>(context, dims.vec<bool>(), output); \
+#define HANDLE_REVERSE(NDIMS) \
+ case NDIMS: \
+ HandleReverseCase<Device, T, NDIMS>(context, dims.vec<bool>(), output); \
return;
switch (input_dims) {
@@ -228,7 +228,7 @@ void HandleReverseV2Case(OpKernelContext* context,
result->tensor<T, NDIMS>());
}
-template <typename Device, typename T>
+template <typename Device, typename T, typename Tidx>
class ReverseV2Op : public OpKernel {
public:
explicit ReverseV2Op(OpKernelConstruction* context) : OpKernel(context) {}
@@ -242,15 +242,15 @@ class ReverseV2Op : public OpKernel {
} else {
const int input_dims = input.dims();
const TensorShape& sparse_dims_shape = sparse_dims.shape();
- const auto& axes_sparse_flat = sparse_dims.flat<int32>();
+ const auto& axes_sparse_flat = sparse_dims.flat<Tidx>();
OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_dims_shape),
errors::InvalidArgument("'dims' must be 1-dimension, not ",
sparse_dims.dims()));
gtl::InlinedVector<bool, 8> axes_dense(input_dims, false);
for (int dummy = 0; dummy < axes_sparse_flat.size(); dummy++) {
- int32 axis = internal::SubtleMustCopy<int32>(axes_sparse_flat(dummy));
- int32 canonical_axis = axis < 0 ? input_dims + axis : axis;
+ Tidx axis = internal::SubtleMustCopy<Tidx>(axes_sparse_flat(dummy));
+ Tidx canonical_axis = axis < 0 ? input_dims + axis : axis;
OP_REQUIRES(context, canonical_axis >= 0 && canonical_axis < input_dims,
errors::InvalidArgument("'axis'[", dummy, "] = ", axis,
" is out of valid range [", 0, ", ",
@@ -306,7 +306,13 @@ class ReverseV2Op : public OpKernel {
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("axis"), \
- ReverseV2Op<CPUDevice, T>)
+ ReverseV2Op<CPUDevice, T, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("ReverseV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("axis"), \
+ ReverseV2Op<CPUDevice, T, int64>)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
TF_CALL_string(REGISTER_KERNELS);
#undef REGISTER_KERNELS
@@ -358,7 +364,13 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("axis"), \
- ReverseV2Op<GPUDevice, T>)
+ ReverseV2Op<GPUDevice, T, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("ReverseV2") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("axis"), \
+ ReverseV2Op<GPUDevice, T, int64>)
TF_CALL_uint8(REGISTER_GPU_KERNELS);
TF_CALL_int8(REGISTER_GPU_KERNELS);
// TODO decide whether we want to enable the bool kernel.
@@ -387,7 +399,15 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2")
.HostMemory("tensor")
.HostMemory("axis")
.HostMemory("output"),
- ReverseV2Op<CPUDevice, int32>);
+ ReverseV2Op<CPUDevice, int32, int32>);
+REGISTER_KERNEL_BUILDER(Name("ReverseV2")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int64>("Tidx")
+ .HostMemory("tensor")
+ .HostMemory("axis")
+ .HostMemory("output"),
+ ReverseV2Op<CPUDevice, int32, int64>);
#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
@@ -402,7 +422,13 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2")
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("axis"), \
- ReverseV2Op<SYCLDevice, T>)
+ ReverseV2Op<SYCLDevice, T, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("ReverseV2") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("axis"), \
+ ReverseV2Op<SYCLDevice, T, int64>)
TF_CALL_uint8(REGISTER_SYCL_KERNELS);
TF_CALL_int8(REGISTER_SYCL_KERNELS);
TF_CALL_float(REGISTER_SYCL_KERNELS);
@@ -422,6 +448,14 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2")
.HostMemory("tensor")
.HostMemory("axis")
.HostMemory("output"),
- ReverseV2Op<CPUDevice, int32>);
-#endif // TENSORFLOW_USE_SYCL
+ ReverseV2Op<CPUDevice, int32, int32>);
+REGISTER_KERNEL_BUILDER(Name("ReverseV2")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int64>("Tidx")
+ .HostMemory("tensor")
+ .HostMemory("axis")
+ .HostMemory("output"),
+ ReverseV2Op<CPUDevice, int32, int64>);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 17492e9255..1dbe7deb97 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -277,26 +277,34 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
x_np = np.array([1, 200, 3, 40, 5], dtype=np_dtype)
for use_gpu in [False, True]:
- with self.test_session(use_gpu=use_gpu):
- x_tf = array_ops.reverse_v2(x_np, [0]).eval()
- self.assertAllEqual(x_tf, np.asarray(x_np)[::-1])
+ for axis_dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = array_ops.reverse_v2(x_np,
+ constant_op.constant([0], dtype=axis_dtype)).eval()
+ self.assertAllEqual(x_tf, np.asarray(x_np)[::-1])
def _reverse2DimAuto(self, np_dtype):
x_np = np.array([[1, 200, 3], [4, 5, 60]], dtype=np_dtype)
for reverse_f in [array_ops.reverse_v2, array_ops.reverse]:
for use_gpu in [False, True]:
- with self.test_session(use_gpu=use_gpu):
- x_tf_1 = reverse_f(x_np, [0]).eval()
- x_tf_2 = reverse_f(x_np, [-2]).eval()
- x_tf_3 = reverse_f(x_np, [1]).eval()
- x_tf_4 = reverse_f(x_np, [-1]).eval()
- x_tf_5 = reverse_f(x_np, [1, 0]).eval()
- self.assertAllEqual(x_tf_1, np.asarray(x_np)[::-1, :])
- self.assertAllEqual(x_tf_2, np.asarray(x_np)[::-1, :])
- self.assertAllEqual(x_tf_3, np.asarray(x_np)[:, ::-1])
- self.assertAllEqual(x_tf_4, np.asarray(x_np)[:, ::-1])
- self.assertAllEqual(x_tf_5, np.asarray(x_np)[::-1, ::-1])
+ for axis_dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf_1 = reverse_f(x_np,
+ constant_op.constant([0], dtype=axis_dtype)).eval()
+ x_tf_2 = reverse_f(x_np,
+ constant_op.constant([-2], dtype=axis_dtype)).eval()
+ x_tf_3 = reverse_f(x_np,
+ constant_op.constant([1], dtype=axis_dtype)).eval()
+ x_tf_4 = reverse_f(x_np,
+ constant_op.constant([-1], dtype=axis_dtype)).eval()
+ x_tf_5 = reverse_f(x_np,
+ constant_op.constant([1, 0], dtype=axis_dtype)).eval()
+ self.assertAllEqual(x_tf_1, np.asarray(x_np)[::-1, :])
+ self.assertAllEqual(x_tf_2, np.asarray(x_np)[::-1, :])
+ self.assertAllEqual(x_tf_3, np.asarray(x_np)[:, ::-1])
+ self.assertAllEqual(x_tf_4, np.asarray(x_np)[:, ::-1])
+ self.assertAllEqual(x_tf_5, np.asarray(x_np)[::-1, ::-1])
# This is the version of reverse that uses axis indices rather than
# bool tensors