diff options
Diffstat (limited to 'tensorflow/core/kernels/transpose_op.cc')
-rw-r--r-- | tensorflow/core/kernels/transpose_op.cc | 35 |
1 files changed, 25 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 20f0edf309..96c051c636 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -31,13 +31,14 @@ limitations under the License.
 
 namespace tensorflow {
 
-// inv = InvertPermutationOp(T<int32> p) takes a permutation of
+// inv = InvertPermutationOp(T<int32/int64> p) takes a permutation of
 // integers 0, 1, ..., n - 1 and returns the inverted
 // permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
 //
-// REQUIRES: input is a vector of int32.
+// REQUIRES: input is a vector of int32 or int64.
 // REQUIRES: input is a permutation of 0, 1, ..., n-1.
 
+template <typename T>
 class InvertPermutationOp : public OpKernel {
  public:
   explicit InvertPermutationOp(OpKernelConstruction* context)
@@ -48,20 +49,19 @@ class InvertPermutationOp : public OpKernel {
     OP_REQUIRES(
         context, TensorShapeUtils::IsVector(input.shape()),
         errors::InvalidArgument("invert_permutation expects a 1D vector."));
-    auto Tin = input.vec<int32>();
+    auto Tin = input.vec<T>();
     OP_REQUIRES(context,
                 FastBoundsCheck(Tin.size(), std::numeric_limits<int32>::max()),
                 errors::InvalidArgument("permutation of nonnegative int32s "
                                         "must have <= int32 max elements"));
-    const int32 N =
-        static_cast<int32>(Tin.size());  // Safe: bounds-checked above.
+    const T N = static_cast<T>(Tin.size());  // Safe: bounds-checked above.
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, input.shape(), &output));
-    auto Tout = output->vec<int32>();
+    auto Tout = output->vec<T>();
     std::fill_n(Tout.data(), N, -1);
     for (int i = 0; i < N; ++i) {
-      const int32 d = internal::SubtleMustCopy(Tin(i));
+      const T d = internal::SubtleMustCopy(Tin(i));
       OP_REQUIRES(context, FastBoundsCheck(d, N),
                   errors::InvalidArgument(d, " is not between 0 and ", N));
       OP_REQUIRES(context, Tout(d) == -1,
@@ -73,14 +73,23 @@ class InvertPermutationOp : public OpKernel {
 
 REGISTER_KERNEL_BUILDER(
     Name("InvertPermutation").Device(DEVICE_CPU).TypeConstraint<int32>("T"),
-    InvertPermutationOp);
+    InvertPermutationOp<int32>);
+REGISTER_KERNEL_BUILDER(
+    Name("InvertPermutation").Device(DEVICE_CPU).TypeConstraint<int64>("T"),
+    InvertPermutationOp<int64>);
 
 REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
                             .Device(DEVICE_GPU)
                             .TypeConstraint<int32>("T")
                             .HostMemory("x")
                             .HostMemory("y"),
-                        InvertPermutationOp);
+                        InvertPermutationOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
+                            .Device(DEVICE_GPU)
+                            .TypeConstraint<int64>("T")
+                            .HostMemory("x")
+                            .HostMemory("y"),
+                        InvertPermutationOp<int64>);
 
 #ifdef TENSORFLOW_USE_SYCL
 REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
@@ -88,7 +97,13 @@ REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
                             .TypeConstraint<int32>("T")
                             .HostMemory("x")
                             .HostMemory("y"),
-                        InvertPermutationOp);
+                        InvertPermutationOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
+                            .Device(DEVICE_SYCL)
+                            .TypeConstraint<int64>("T")
+                            .HostMemory("x")
+                            .HostMemory("y"),
+                        InvertPermutationOp<int64>);
 #endif  // TENSORFLOW_USE_SYCL
 
 namespace {