diff options
Diffstat (limited to 'tensorflow/core/kernels/transpose_op.cc')
-rw-r--r-- | tensorflow/core/kernels/transpose_op.cc | 35 |
1 files changed, 10 insertions, 25 deletions
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index 96c051c636..20f0edf309 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -31,14 +31,13 @@ limitations under the License. namespace tensorflow { -// inv = InvertPermutationOp(T<int32/int64> p) takes a permutation of +// inv = InvertPermutationOp(T<int32> p) takes a permutation of // integers 0, 1, ..., n - 1 and returns the inverted // permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n). // -// REQUIRES: input is a vector of int32 or int64. +// REQUIRES: input is a vector of int32. // REQUIRES: input is a permutation of 0, 1, ..., n-1. -template <typename T> class InvertPermutationOp : public OpKernel { public: explicit InvertPermutationOp(OpKernelConstruction* context) @@ -49,19 +48,20 @@ class InvertPermutationOp : public OpKernel { OP_REQUIRES( context, TensorShapeUtils::IsVector(input.shape()), errors::InvalidArgument("invert_permutation expects a 1D vector.")); - auto Tin = input.vec<T>(); + auto Tin = input.vec<int32>(); OP_REQUIRES(context, FastBoundsCheck(Tin.size(), std::numeric_limits<int32>::max()), errors::InvalidArgument("permutation of nonnegative int32s " "must have <= int32 max elements")); - const T N = static_cast<T>(Tin.size()); // Safe: bounds-checked above. + const int32 N = + static_cast<int32>(Tin.size()); // Safe: bounds-checked above. Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &output)); - auto Tout = output->vec<T>(); + auto Tout = output->vec<int32>(); std::fill_n(Tout.data(), N, -1); for (int i = 0; i < N; ++i) { - const T d = internal::SubtleMustCopy(Tin(i)); + const int32 d = internal::SubtleMustCopy(Tin(i)); OP_REQUIRES(context, FastBoundsCheck(d, N), errors::InvalidArgument(d, " is not between 0 and ", N)); OP_REQUIRES(context, Tout(d) == -1, @@ -73,23 +73,14 @@ class InvertPermutationOp : public OpKernel { REGISTER_KERNEL_BUILDER( Name("InvertPermutation").Device(DEVICE_CPU).TypeConstraint<int32>("T"), - InvertPermutationOp<int32>); -REGISTER_KERNEL_BUILDER( - Name("InvertPermutation").Device(DEVICE_CPU).TypeConstraint<int64>("T"), - InvertPermutationOp<int64>); + InvertPermutationOp); REGISTER_KERNEL_BUILDER(Name("InvertPermutation") .Device(DEVICE_GPU) .TypeConstraint<int32>("T") .HostMemory("x") .HostMemory("y"), - InvertPermutationOp<int32>); -REGISTER_KERNEL_BUILDER(Name("InvertPermutation") - .Device(DEVICE_GPU) - .TypeConstraint<int64>("T") - .HostMemory("x") - .HostMemory("y"), - InvertPermutationOp<int64>); + InvertPermutationOp); #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER(Name("InvertPermutation") @@ -97,13 +88,7 @@ REGISTER_KERNEL_BUILDER(Name("InvertPermutation") .TypeConstraint<int32>("T") .HostMemory("x") .HostMemory("y"), - InvertPermutationOp<int32>); -REGISTER_KERNEL_BUILDER(Name("InvertPermutation") - .Device(DEVICE_SYCL) - .TypeConstraint<int64>("T") - .HostMemory("x") - .HostMemory("y"), - InvertPermutationOp<int64>); + InvertPermutationOp); #endif // TENSORFLOW_USE_SYCL namespace { |