-rw-r--r--  tensorflow/core/kernels/reverse_op.cc             | 60
-rw-r--r--  tensorflow/python/kernel_tests/array_ops_test.py  | 36
2 files changed, 69 insertions, 27 deletions
diff --git a/tensorflow/core/kernels/reverse_op.cc b/tensorflow/core/kernels/reverse_op.cc
index 7ac34d1c62..8f82784d93 100644
--- a/tensorflow/core/kernels/reverse_op.cc
+++ b/tensorflow/core/kernels/reverse_op.cc
@@ -182,9 +182,9 @@ class ReverseOp : public OpKernel {
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, input.shape(), &output));
 
-#define HANDLE_REVERSE(NDIMS)                                               \
-  case NDIMS:                                                               \
-    HandleReverseCase<Device, T, NDIMS>(context, dims.vec<bool>(), output); \
+#define HANDLE_REVERSE(NDIMS)                                                \
+  case NDIMS:                                                                \
+    HandleReverseCase<Device, T, NDIMS>(context, dims.vec<bool>(), output);  \
     return;
 
     switch (input_dims) {
@@ -228,7 +228,7 @@ void HandleReverseV2Case(OpKernelContext* context,
       result->tensor<T, NDIMS>());
 }
 
-template <typename Device, typename T>
+template <typename Device, typename T, typename Tidx>
 class ReverseV2Op : public OpKernel {
  public:
   explicit ReverseV2Op(OpKernelConstruction* context) : OpKernel(context) {}
@@ -242,15 +242,15 @@ class ReverseV2Op : public OpKernel {
     } else {
       const int input_dims = input.dims();
       const TensorShape& sparse_dims_shape = sparse_dims.shape();
-      const auto& axes_sparse_flat = sparse_dims.flat<int32>();
+      const auto& axes_sparse_flat = sparse_dims.flat<Tidx>();
 
       OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_dims_shape),
                   errors::InvalidArgument("'dims' must be 1-dimension, not ",
                                           sparse_dims.dims()));
       gtl::InlinedVector<bool, 8> axes_dense(input_dims, false);
       for (int dummy = 0; dummy < axes_sparse_flat.size(); dummy++) {
-        int32 axis = internal::SubtleMustCopy<int32>(axes_sparse_flat(dummy));
-        int32 canonical_axis = axis < 0 ? input_dims + axis : axis;
+        Tidx axis = internal::SubtleMustCopy<Tidx>(axes_sparse_flat(dummy));
+        Tidx canonical_axis = axis < 0 ? input_dims + axis : axis;
         OP_REQUIRES(context, canonical_axis >= 0 && canonical_axis < input_dims,
                     errors::InvalidArgument("'axis'[", dummy, "] = ", axis,
                                             " is out of valid range [", 0, ", ",
@@ -306,7 +306,13 @@ class ReverseV2Op : public OpKernel {
                               .TypeConstraint<T>("T")        \
                               .TypeConstraint<int32>("Tidx") \
                               .HostMemory("axis"),           \
-                          ReverseV2Op<CPUDevice, T>)
+                          ReverseV2Op<CPUDevice, T, int32>)  \
+  REGISTER_KERNEL_BUILDER(Name("ReverseV2")                  \
+                              .Device(DEVICE_CPU)            \
+                              .TypeConstraint<T>("T")        \
+                              .TypeConstraint<int64>("Tidx") \
+                              .HostMemory("axis"),           \
+                          ReverseV2Op<CPUDevice, T, int64>)
 TF_CALL_POD_TYPES(REGISTER_KERNELS);
 TF_CALL_string(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
@@ -358,7 +364,13 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
                               .TypeConstraint<T>("T")        \
                               .TypeConstraint<int32>("Tidx") \
                               .HostMemory("axis"),           \
-                          ReverseV2Op<GPUDevice, T>)
+                          ReverseV2Op<GPUDevice, T, int32>)  \
+  REGISTER_KERNEL_BUILDER(Name("ReverseV2")                  \
+                              .Device(DEVICE_GPU)            \
+                              .TypeConstraint<T>("T")        \
+                              .TypeConstraint<int64>("Tidx") \
+                              .HostMemory("axis"),           \
+                          ReverseV2Op<GPUDevice, T, int64>)
 TF_CALL_uint8(REGISTER_GPU_KERNELS);
 TF_CALL_int8(REGISTER_GPU_KERNELS);
 // TODO decide whether we want to enable the bool kernel.
@@ -387,7 +399,15 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2")
                             .HostMemory("tensor")
                             .HostMemory("axis")
                             .HostMemory("output"),
-                        ReverseV2Op<CPUDevice, int32>);
+                        ReverseV2Op<CPUDevice, int32, int32>);
+REGISTER_KERNEL_BUILDER(Name("ReverseV2")
+                            .Device(DEVICE_GPU)
+                            .TypeConstraint<int32>("T")
+                            .TypeConstraint<int64>("Tidx")
+                            .HostMemory("tensor")
+                            .HostMemory("axis")
+                            .HostMemory("output"),
+                        ReverseV2Op<CPUDevice, int32, int64>);
 #endif  // GOOGLE_CUDA
 
 #ifdef TENSORFLOW_USE_SYCL
@@ -402,7 +422,13 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2")
                               .TypeConstraint<T>("T")        \
                               .TypeConstraint<int32>("Tidx") \
                               .HostMemory("axis"),           \
-                          ReverseV2Op<SYCLDevice, T>)
+                          ReverseV2Op<SYCLDevice, T, int32>) \
+  REGISTER_KERNEL_BUILDER(Name("ReverseV2")                  \
+                              .Device(DEVICE_SYCL)           \
+                              .TypeConstraint<T>("T")        \
+                              .TypeConstraint<int64>("Tidx") \
+                              .HostMemory("axis"),           \
+                          ReverseV2Op<SYCLDevice, T, int64>)
 TF_CALL_uint8(REGISTER_SYCL_KERNELS);
 TF_CALL_int8(REGISTER_SYCL_KERNELS);
 TF_CALL_float(REGISTER_SYCL_KERNELS);
@@ -422,6 +448,14 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2")
                             .HostMemory("tensor")
                             .HostMemory("axis")
                             .HostMemory("output"),
-                        ReverseV2Op<CPUDevice, int32>);
-#endif  // TENSORFLOW_USE_SYCL
+                        ReverseV2Op<CPUDevice, int32, int32>);
+REGISTER_KERNEL_BUILDER(Name("ReverseV2")
+                            .Device(DEVICE_SYCL)
+                            .TypeConstraint<int32>("T")
+                            .TypeConstraint<int64>("Tidx")
+                            .HostMemory("tensor")
+                            .HostMemory("axis")
+                            .HostMemory("output"),
+                        ReverseV2Op<CPUDevice, int32, int64>);
+#endif  // TENSORFLOW_USE_SYCL
 }  // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 17492e9255..1dbe7deb97 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -277,26 +277,34 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
     x_np = np.array([1, 200, 3, 40, 5], dtype=np_dtype)
 
     for use_gpu in [False, True]:
-      with self.test_session(use_gpu=use_gpu):
-        x_tf = array_ops.reverse_v2(x_np, [0]).eval()
-        self.assertAllEqual(x_tf, np.asarray(x_np)[::-1])
+      for axis_dtype in [dtypes.int32, dtypes.int64]:
+        with self.test_session(use_gpu=use_gpu):
+          x_tf = array_ops.reverse_v2(x_np,
+              constant_op.constant([0], dtype=axis_dtype)).eval()
+          self.assertAllEqual(x_tf, np.asarray(x_np)[::-1])
 
   def _reverse2DimAuto(self, np_dtype):
     x_np = np.array([[1, 200, 3], [4, 5, 60]], dtype=np_dtype)
 
     for reverse_f in [array_ops.reverse_v2, array_ops.reverse]:
       for use_gpu in [False, True]:
-        with self.test_session(use_gpu=use_gpu):
-          x_tf_1 = reverse_f(x_np, [0]).eval()
-          x_tf_2 = reverse_f(x_np, [-2]).eval()
-          x_tf_3 = reverse_f(x_np, [1]).eval()
-          x_tf_4 = reverse_f(x_np, [-1]).eval()
-          x_tf_5 = reverse_f(x_np, [1, 0]).eval()
-          self.assertAllEqual(x_tf_1, np.asarray(x_np)[::-1, :])
-          self.assertAllEqual(x_tf_2, np.asarray(x_np)[::-1, :])
-          self.assertAllEqual(x_tf_3, np.asarray(x_np)[:, ::-1])
-          self.assertAllEqual(x_tf_4, np.asarray(x_np)[:, ::-1])
-          self.assertAllEqual(x_tf_5, np.asarray(x_np)[::-1, ::-1])
+        for axis_dtype in [dtypes.int32, dtypes.int64]:
+          with self.test_session(use_gpu=use_gpu):
+            x_tf_1 = reverse_f(x_np,
+                constant_op.constant([0], dtype=axis_dtype)).eval()
+            x_tf_2 = reverse_f(x_np,
+                constant_op.constant([-2], dtype=axis_dtype)).eval()
+            x_tf_3 = reverse_f(x_np,
+                constant_op.constant([1], dtype=axis_dtype)).eval()
+            x_tf_4 = reverse_f(x_np,
+                constant_op.constant([-1], dtype=axis_dtype)).eval()
+            x_tf_5 = reverse_f(x_np,
+                constant_op.constant([1, 0], dtype=axis_dtype)).eval()
+            self.assertAllEqual(x_tf_1, np.asarray(x_np)[::-1, :])
+            self.assertAllEqual(x_tf_2, np.asarray(x_np)[::-1, :])
+            self.assertAllEqual(x_tf_3, np.asarray(x_np)[:, ::-1])
+            self.assertAllEqual(x_tf_4, np.asarray(x_np)[:, ::-1])
+            self.assertAllEqual(x_tf_5, np.asarray(x_np)[::-1, ::-1])
 
   # This is the version of reverse that uses axis indices rather than
   # bool tensors
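
For context, a minimal usage sketch of what the new "Tidx" registrations enable: the reverse op can now be fed an int64 axis tensor as well as the previously supported int32 one. This sketch is not part of the patch; it assumes the TF 1.x session API used by the test file above, and the example tensor values are illustrative only.

# Minimal sketch (assumption: TF 1.x graph/session API, matching the test file above).
# With the int64 "Tidx" kernels registered, both axis dtypes dispatch to ReverseV2.
import numpy as np
import tensorflow as tf

x = np.array([[1, 200, 3], [4, 5, 60]], dtype=np.int32)
axis_i32 = tf.constant([0], dtype=tf.int32)
axis_i64 = tf.constant([0], dtype=tf.int64)  # before this change, no kernel was registered for int64 axes

with tf.Session() as sess:
  print(sess.run(tf.reverse(x, axis_i32)))  # [[  4   5  60] [  1 200   3]]
  print(sess.run(tf.reverse(x, axis_i64)))  # same result, now with an int64 axis tensor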