diff options
Diffstat (limited to 'tensorflow/core/kernels/reverse_op.cc')
-rw-r--r-- | tensorflow/core/kernels/reverse_op.cc | 60 |
1 files changed, 47 insertions, 13 deletions
diff --git a/tensorflow/core/kernels/reverse_op.cc b/tensorflow/core/kernels/reverse_op.cc index 7ac34d1c62..8f82784d93 100644 --- a/tensorflow/core/kernels/reverse_op.cc +++ b/tensorflow/core/kernels/reverse_op.cc @@ -182,9 +182,9 @@ class ReverseOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &output)); -#define HANDLE_REVERSE(NDIMS) \ - case NDIMS: \ - HandleReverseCase<Device, T, NDIMS>(context, dims.vec<bool>(), output); \ +#define HANDLE_REVERSE(NDIMS) \ + case NDIMS: \ + HandleReverseCase<Device, T, NDIMS>(context, dims.vec<bool>(), output); \ return; switch (input_dims) { @@ -228,7 +228,7 @@ void HandleReverseV2Case(OpKernelContext* context, result->tensor<T, NDIMS>()); } -template <typename Device, typename T> +template <typename Device, typename T, typename Tidx> class ReverseV2Op : public OpKernel { public: explicit ReverseV2Op(OpKernelConstruction* context) : OpKernel(context) {} @@ -242,15 +242,15 @@ class ReverseV2Op : public OpKernel { } else { const int input_dims = input.dims(); const TensorShape& sparse_dims_shape = sparse_dims.shape(); - const auto& axes_sparse_flat = sparse_dims.flat<int32>(); + const auto& axes_sparse_flat = sparse_dims.flat<Tidx>(); OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_dims_shape), errors::InvalidArgument("'dims' must be 1-dimension, not ", sparse_dims.dims())); gtl::InlinedVector<bool, 8> axes_dense(input_dims, false); for (int dummy = 0; dummy < axes_sparse_flat.size(); dummy++) { - int32 axis = internal::SubtleMustCopy<int32>(axes_sparse_flat(dummy)); - int32 canonical_axis = axis < 0 ? input_dims + axis : axis; + Tidx axis = internal::SubtleMustCopy<Tidx>(axes_sparse_flat(dummy)); + Tidx canonical_axis = axis < 0 ? input_dims + axis : axis; OP_REQUIRES(context, canonical_axis >= 0 && canonical_axis < input_dims, errors::InvalidArgument("'axis'[", dummy, "] = ", axis, " is out of valid range [", 0, ", ", @@ -306,7 +306,13 @@ class ReverseV2Op : public OpKernel { .TypeConstraint<T>("T") \ .TypeConstraint<int32>("Tidx") \ .HostMemory("axis"), \ - ReverseV2Op<CPUDevice, T>) + ReverseV2Op<CPUDevice, T, int32>) \ + REGISTER_KERNEL_BUILDER(Name("ReverseV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .TypeConstraint<int64>("Tidx") \ + .HostMemory("axis"), \ + ReverseV2Op<CPUDevice, T, int64>) TF_CALL_POD_TYPES(REGISTER_KERNELS); TF_CALL_string(REGISTER_KERNELS); #undef REGISTER_KERNELS @@ -358,7 +364,13 @@ TF_CALL_complex128(DECLARE_GPU_SPEC); .TypeConstraint<T>("T") \ .TypeConstraint<int32>("Tidx") \ .HostMemory("axis"), \ - ReverseV2Op<GPUDevice, T>) + ReverseV2Op<GPUDevice, T, int32>) \ + REGISTER_KERNEL_BUILDER(Name("ReverseV2") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<T>("T") \ + .TypeConstraint<int64>("Tidx") \ + .HostMemory("axis"), \ + ReverseV2Op<GPUDevice, T, int64>) TF_CALL_uint8(REGISTER_GPU_KERNELS); TF_CALL_int8(REGISTER_GPU_KERNELS); // TODO decide whether we want to enable the bool kernel. @@ -387,7 +399,15 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2") .HostMemory("tensor") .HostMemory("axis") .HostMemory("output"), - ReverseV2Op<CPUDevice, int32>); + ReverseV2Op<CPUDevice, int32, int32>); +REGISTER_KERNEL_BUILDER(Name("ReverseV2") + .Device(DEVICE_GPU) + .TypeConstraint<int32>("T") + .TypeConstraint<int64>("Tidx") + .HostMemory("tensor") + .HostMemory("axis") + .HostMemory("output"), + ReverseV2Op<CPUDevice, int32, int64>); #endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL @@ -402,7 +422,13 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2") .TypeConstraint<T>("T") \ .TypeConstraint<int32>("Tidx") \ .HostMemory("axis"), \ - ReverseV2Op<SYCLDevice, T>) + ReverseV2Op<SYCLDevice, T, int32>) \ + REGISTER_KERNEL_BUILDER(Name("ReverseV2") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint<T>("T") \ + .TypeConstraint<int64>("Tidx") \ + .HostMemory("axis"), \ + ReverseV2Op<SYCLDevice, T, int64>) TF_CALL_uint8(REGISTER_SYCL_KERNELS); TF_CALL_int8(REGISTER_SYCL_KERNELS); TF_CALL_float(REGISTER_SYCL_KERNELS); @@ -422,6 +448,14 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2") .HostMemory("tensor") .HostMemory("axis") .HostMemory("output"), - ReverseV2Op<CPUDevice, int32>); -#endif // TENSORFLOW_USE_SYCL + ReverseV2Op<CPUDevice, int32, int32>); +REGISTER_KERNEL_BUILDER(Name("ReverseV2") + .Device(DEVICE_SYCL) + .TypeConstraint<int32>("T") + .TypeConstraint<int64>("Tidx") + .HostMemory("tensor") + .HostMemory("axis") + .HostMemory("output"), + ReverseV2Op<CPUDevice, int32, int64>); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow |