aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/transpose_op.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-15 11:06:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-15 12:19:10 -0700
commitbfb0abbedfd866b8adcccc776ef95e0a67c89d14 (patch)
tree1d281892a1c244711e4a1999c33fb02e5537a385 /tensorflow/core/kernels/transpose_op.cc
parent7b2cc1c90b450aac5d28262744c4a0044368a495 (diff)
When transposing a Tensor with a single non-1 dimension, it can be optimized as a reshape.
This is a performance optimization that is important for mobile inference graphs. They typically run with single examples (batch_size=1) and so many operations are performed on Tensors with only 1 non-1 dimension and these are just a reshape. In performance measurements on a Nexus5, this reduced times spent in transpose nodes from around 1ms to 0.02ms for a 1x20000 tensor. Change: 127564469
Diffstat (limited to 'tensorflow/core/kernels/transpose_op.cc')
-rw-r--r--tensorflow/core/kernels/transpose_op.cc11
1 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 8078671a98..fbde0c9626 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -123,13 +123,16 @@ void TransposeOp::Compute(OpKernelContext* ctx) {
// Check whether permutation is a permutation of integers of [0 .. dims).
gtl::InlinedVector<bool, 8> bits(dims);
bool is_identity = true;
+ int32 non_singleton_dims = 0;
for (int i = 0; i < dims; ++i) {
const int32 d = permutation[i];
OP_REQUIRES(
ctx, 0 <= d && d < dims,
errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
bits[d] = true;
- shape.AddDim(input.dim_size(d));
+ const auto dim_size = input.dim_size(d);
+ shape.AddDim(dim_size);
+ non_singleton_dims += dim_size != 1 ? 1 : 0;
if (d != i) {
is_identity = false;
}
@@ -144,6 +147,12 @@ void TransposeOp::Compute(OpKernelContext* ctx) {
if (dims <= 1 || is_identity) {
ctx->set_output(0, input);
return;
+ } else if (non_singleton_dims <= 1) {
+ Tensor output;
+ OP_REQUIRES(ctx, output.CopyFrom(input, shape),
+ errors::Unknown("Error reshaping Tensor."));
+ ctx->set_output(0, output);
+ return;
}
Tensor* output = nullptr;