diff options
author | 2016-03-17 16:44:00 -0800 | |
---|---|---|
committer | 2016-03-18 08:47:13 -0700 | |
commit | 82035055f81b207b2f93b0c00877e407e5b7399f (patch) | |
tree | 51fd9fe02db1beab6b75bfe85b7456c6d1242a57 | |
parent | e677983885784f8e246f66a3fb1ee5439cf605ef (diff) |
Add fast path for identity transpose.
Change: 117505457
-rw-r--r-- | tensorflow/core/kernels/transpose_op.cc | 15 |
1 files changed, 10 insertions, 5 deletions
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index 5ecef9c6f9..3eaa4777af 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -107,18 +107,23 @@ void TransposeOp::Compute(OpKernelContext* ctx) { errors::InvalidArgument( "transpose expects a vector of size ", input.dims(), ". But input(1) is a vector of size ", Vperm.size())); - gtl::ArraySlice<int32> permutation( - reinterpret_cast<const int32*>(Vperm.data()), dims); + const int32* perm_begin = reinterpret_cast<const int32*>(Vperm.data()); + const std::vector<int32> permutation(perm_begin, perm_begin + dims); TensorShape shape; // Check whether permutation is a permutation of integers of [0 .. dims). gtl::InlinedVector<bool, 8> bits(dims); - for (const int32 d : permutation) { + bool is_identity = true; + for (int i = 0; i < dims; ++i) { + const int32 d = permutation[i]; OP_REQUIRES( ctx, 0 <= d && d < dims, errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")")); bits[d] = true; shape.AddDim(input.dim_size(d)); + if (d != i) { + is_identity = false; + } } for (int i = 0; i < dims; ++i) { OP_REQUIRES(ctx, bits[i], errors::InvalidArgument( @@ -126,8 +131,8 @@ void TransposeOp::Compute(OpKernelContext* ctx) { str_util::Join(permutation, ","), "}.")); } - // 0-D and 1-D transposes do nothing - if (dims <= 1) { + // 0-D, 1-D, and identity transposes do nothing. + if (dims <= 1 || is_identity) { ctx->set_output(0, input); return; } |