about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/transpose_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/transpose_op.cc')
-rw-r--r-- tensorflow/core/kernels/transpose_op.cc 35
1 file changed, 25 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 20f0edf309..96c051c636 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -31,13 +31,14 @@ limitations under the License.
namespace tensorflow {
-// inv = InvertPermutationOp(T<int32> p) takes a permutation of
+// inv = InvertPermutationOp(T<int32/int64> p) takes a permutation of
// integers 0, 1, ..., n - 1 and returns the inverted
// permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
//
-// REQUIRES: input is a vector of int32.
+// REQUIRES: input is a vector of int32 or int64.
// REQUIRES: input is a permutation of 0, 1, ..., n-1.
+template <typename T>
class InvertPermutationOp : public OpKernel {
public:
explicit InvertPermutationOp(OpKernelConstruction* context)
@@ -48,20 +49,19 @@ class InvertPermutationOp : public OpKernel {
OP_REQUIRES(
context, TensorShapeUtils::IsVector(input.shape()),
errors::InvalidArgument("invert_permutation expects a 1D vector."));
- auto Tin = input.vec<int32>();
+ auto Tin = input.vec<T>();
OP_REQUIRES(context,
FastBoundsCheck(Tin.size(), std::numeric_limits<int32>::max()),
errors::InvalidArgument("permutation of nonnegative int32s "
"must have <= int32 max elements"));
- const int32 N =
- static_cast<int32>(Tin.size()); // Safe: bounds-checked above.
+ const T N = static_cast<T>(Tin.size()); // Safe: bounds-checked above.
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
- auto Tout = output->vec<int32>();
+ auto Tout = output->vec<T>();
std::fill_n(Tout.data(), N, -1);
for (int i = 0; i < N; ++i) {
- const int32 d = internal::SubtleMustCopy(Tin(i));
+ const T d = internal::SubtleMustCopy(Tin(i));
OP_REQUIRES(context, FastBoundsCheck(d, N),
errors::InvalidArgument(d, " is not between 0 and ", N));
OP_REQUIRES(context, Tout(d) == -1,
@@ -73,14 +73,23 @@ class InvertPermutationOp : public OpKernel {
REGISTER_KERNEL_BUILDER(
Name("InvertPermutation").Device(DEVICE_CPU).TypeConstraint<int32>("T"),
- InvertPermutationOp);
+ InvertPermutationOp<int32>);
+REGISTER_KERNEL_BUILDER(
+ Name("InvertPermutation").Device(DEVICE_CPU).TypeConstraint<int64>("T"),
+ InvertPermutationOp<int64>);
REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
.Device(DEVICE_GPU)
.TypeConstraint<int32>("T")
.HostMemory("x")
.HostMemory("y"),
- InvertPermutationOp);
+ InvertPermutationOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int64>("T")
+ .HostMemory("x")
+ .HostMemory("y"),
+ InvertPermutationOp<int64>);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
@@ -88,7 +97,13 @@ REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
.TypeConstraint<int32>("T")
.HostMemory("x")
.HostMemory("y"),
- InvertPermutationOp);
+ InvertPermutationOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int64>("T")
+ .HostMemory("x")
+ .HostMemory("y"),
+ InvertPermutationOp<int64>);
#endif // TENSORFLOW_USE_SYCL
namespace {