aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-03-17 16:44:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-03-18 08:47:13 -0700
commit82035055f81b207b2f93b0c00877e407e5b7399f (patch)
tree51fd9fe02db1beab6b75bfe85b7456c6d1242a57
parente677983885784f8e246f66a3fb1ee5439cf605ef (diff)
Add fast path for identity transpose.
Change: 117505457
-rw-r--r--tensorflow/core/kernels/transpose_op.cc15
1 files changed, 10 insertions, 5 deletions
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 5ecef9c6f9..3eaa4777af 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -107,18 +107,23 @@ void TransposeOp::Compute(OpKernelContext* ctx) {
errors::InvalidArgument(
"transpose expects a vector of size ", input.dims(),
". But input(1) is a vector of size ", Vperm.size()));
- gtl::ArraySlice<int32> permutation(
- reinterpret_cast<const int32*>(Vperm.data()), dims);
+ const int32* perm_begin = reinterpret_cast<const int32*>(Vperm.data());
+ const std::vector<int32> permutation(perm_begin, perm_begin + dims);
TensorShape shape;
// Check whether permutation is a permutation of integers of [0 .. dims).
gtl::InlinedVector<bool, 8> bits(dims);
- for (const int32 d : permutation) {
+ bool is_identity = true;
+ for (int i = 0; i < dims; ++i) {
+ const int32 d = permutation[i];
OP_REQUIRES(
ctx, 0 <= d && d < dims,
errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
bits[d] = true;
shape.AddDim(input.dim_size(d));
+ if (d != i) {
+ is_identity = false;
+ }
}
for (int i = 0; i < dims; ++i) {
OP_REQUIRES(ctx, bits[i], errors::InvalidArgument(
@@ -126,8 +131,8 @@ void TransposeOp::Compute(OpKernelContext* ctx) {
str_util::Join(permutation, ","), "}."));
}
- // 0-D and 1-D transposes do nothing
- if (dims <= 1) {
+ // 0-D, 1-D, and identity transposes do nothing.
+ if (dims <= 1 || is_identity) {
ctx->set_output(0, input);
return;
}