about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/reverse_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/reverse_op.cc')
-rw-r--r-- tensorflow/core/kernels/reverse_op.cc 60
1 files changed, 47 insertions, 13 deletions
diff --git a/tensorflow/core/kernels/reverse_op.cc b/tensorflow/core/kernels/reverse_op.cc
index 7ac34d1c62..8f82784d93 100644
--- a/tensorflow/core/kernels/reverse_op.cc
+++ b/tensorflow/core/kernels/reverse_op.cc
@@ -182,9 +182,9 @@ class ReverseOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
-#define HANDLE_REVERSE(NDIMS) \
- case NDIMS: \
- HandleReverseCase<Device, T, NDIMS>(context, dims.vec<bool>(), output); \
+#define HANDLE_REVERSE(NDIMS) \
+ case NDIMS: \
+ HandleReverseCase<Device, T, NDIMS>(context, dims.vec<bool>(), output); \
return;
switch (input_dims) {
@@ -228,7 +228,7 @@ void HandleReverseV2Case(OpKernelContext* context,
result->tensor<T, NDIMS>());
}
-template <typename Device, typename T>
+template <typename Device, typename T, typename Tidx>
class ReverseV2Op : public OpKernel {
public:
explicit ReverseV2Op(OpKernelConstruction* context) : OpKernel(context) {}
@@ -242,15 +242,15 @@ class ReverseV2Op : public OpKernel {
} else {
const int input_dims = input.dims();
const TensorShape& sparse_dims_shape = sparse_dims.shape();
- const auto& axes_sparse_flat = sparse_dims.flat<int32>();
+ const auto& axes_sparse_flat = sparse_dims.flat<Tidx>();
OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_dims_shape),
errors::InvalidArgument("'dims' must be 1-dimension, not ",
sparse_dims.dims()));
gtl::InlinedVector<bool, 8> axes_dense(input_dims, false);
for (int dummy = 0; dummy < axes_sparse_flat.size(); dummy++) {
- int32 axis = internal::SubtleMustCopy<int32>(axes_sparse_flat(dummy));
- int32 canonical_axis = axis < 0 ? input_dims + axis : axis;
+ Tidx axis = internal::SubtleMustCopy<Tidx>(axes_sparse_flat(dummy));
+ Tidx canonical_axis = axis < 0 ? input_dims + axis : axis;
OP_REQUIRES(context, canonical_axis >= 0 && canonical_axis < input_dims,
errors::InvalidArgument("'axis'[", dummy, "] = ", axis,
" is out of valid range [", 0, ", ",
@@ -306,7 +306,13 @@ class ReverseV2Op : public OpKernel {
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("axis"), \
- ReverseV2Op<CPUDevice, T>)
+ ReverseV2Op<CPUDevice, T, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("ReverseV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("axis"), \
+ ReverseV2Op<CPUDevice, T, int64>)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
TF_CALL_string(REGISTER_KERNELS);
#undef REGISTER_KERNELS
@@ -358,7 +364,13 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("axis"), \
- ReverseV2Op<GPUDevice, T>)
+ ReverseV2Op<GPUDevice, T, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("ReverseV2") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("axis"), \
+ ReverseV2Op<GPUDevice, T, int64>)
TF_CALL_uint8(REGISTER_GPU_KERNELS);
TF_CALL_int8(REGISTER_GPU_KERNELS);
// TODO decide whether we want to enable the bool kernel.
@@ -387,7 +399,15 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2")
.HostMemory("tensor")
.HostMemory("axis")
.HostMemory("output"),
- ReverseV2Op<CPUDevice, int32>);
+ ReverseV2Op<CPUDevice, int32, int32>);
+REGISTER_KERNEL_BUILDER(Name("ReverseV2")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int64>("Tidx")
+ .HostMemory("tensor")
+ .HostMemory("axis")
+ .HostMemory("output"),
+ ReverseV2Op<CPUDevice, int32, int64>);
#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
@@ -402,7 +422,13 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2")
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("axis"), \
- ReverseV2Op<SYCLDevice, T>)
+ ReverseV2Op<SYCLDevice, T, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("ReverseV2") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("axis"), \
+ ReverseV2Op<SYCLDevice, T, int64>)
TF_CALL_uint8(REGISTER_SYCL_KERNELS);
TF_CALL_int8(REGISTER_SYCL_KERNELS);
TF_CALL_float(REGISTER_SYCL_KERNELS);
@@ -422,6 +448,14 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2")
.HostMemory("tensor")
.HostMemory("axis")
.HostMemory("output"),
- ReverseV2Op<CPUDevice, int32>);
-#endif // TENSORFLOW_USE_SYCL
+ ReverseV2Op<CPUDevice, int32, int32>);
+REGISTER_KERNEL_BUILDER(Name("ReverseV2")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int64>("Tidx")
+ .HostMemory("tensor")
+ .HostMemory("axis")
+ .HostMemory("output"),
+ ReverseV2Op<CPUDevice, int32, int64>);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow