diff options
author | 2016-07-15 11:06:30 -0800 | |
---|---|---|
committer | 2016-07-15 12:19:10 -0700 | |
commit | bfb0abbedfd866b8adcccc776ef95e0a67c89d14 (patch) | |
tree | 1d281892a1c244711e4a1999c33fb02e5537a385 /tensorflow/core/kernels/transpose_op.cc | |
parent | 7b2cc1c90b450aac5d28262744c4a0044368a495 (diff) |
When transposing a Tensor with a single non-1 dimension, it can be optimized as a reshape.
This is a performance optimization that is important for mobile inference graphs. They typically run with single examples (batch_size=1) and so many operations are performed on Tensors with only 1 non-1 dimension and these are just a reshape. In performance measurements on a Nexus5, this reduced times spent in transpose nodes from around 1ms to 0.02ms for a 1x20000 tensor.
Change: 127564469
Diffstat (limited to 'tensorflow/core/kernels/transpose_op.cc')
-rw-r--r-- | tensorflow/core/kernels/transpose_op.cc | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index 8078671a98..fbde0c9626 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -123,13 +123,16 @@ void TransposeOp::Compute(OpKernelContext* ctx) { // Check whether permutation is a permutation of integers of [0 .. dims). gtl::InlinedVector<bool, 8> bits(dims); bool is_identity = true; + int32 non_singleton_dims = 0; for (int i = 0; i < dims; ++i) { const int32 d = permutation[i]; OP_REQUIRES( ctx, 0 <= d && d < dims, errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")")); bits[d] = true; - shape.AddDim(input.dim_size(d)); + const auto dim_size = input.dim_size(d); + shape.AddDim(dim_size); + non_singleton_dims += dim_size != 1 ? 1 : 0; if (d != i) { is_identity = false; } @@ -144,6 +147,12 @@ void TransposeOp::Compute(OpKernelContext* ctx) { if (dims <= 1 || is_identity) { ctx->set_output(0, input); return; + } else if (non_singleton_dims <= 1) { + Tensor output; + OP_REQUIRES(ctx, output.CopyFrom(input, shape), + errors::Unknown("Error reshaping Tensor.")); + ctx->set_output(0, output); + return; } Tensor* output = nullptr; |