aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/transpose_functor.h
blob: add4635331ee55495f5bc0d79fd040812078f1f8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_

#include <numeric>
#include <string>
#include <vector>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
// Transposes tensor 'in' into tensor 'out' according to the dimension
// permutation 'perm': dimension i of 'out' corresponds to dimension perm[i]
// of 'in'.
//
// Only the declaration lives in this header; no definition for any Device is
// visible here.
//
// REQUIRES: in.dtype() == out->dtype()
// REQUIRES: in.dims() == out->dims()
// REQUIRES: in.dims() == perm.size()
// REQUIRES: in.dim_size(perm[i]) == out->dim_size(i)
// NOTE(review): internal::DoTransposeImpl in this header additionally
// CHECK-fails when in.dims() < 2. If the per-device definitions forward to
// it, rank >= 2 is an implicit requirement as well -- confirm against them.
template <typename Device>
Status DoTranspose(const Device& device, const Tensor& in,
                   const gtl::ArraySlice<int32> perm, Tensor* out);

// Conjugates and transposes tensor 'in' into tensor 'out' according to the
// dimension permutation 'perm': dimension i of 'out' corresponds to dimension
// perm[i] of 'in'.
//
// Only the declaration lives in this header; no definition for any Device is
// visible here.
//
// REQUIRES: in.dtype() == out->dtype()
// REQUIRES: in.dims() == out->dims()
// REQUIRES: in.dims() == perm.size()
// REQUIRES: in.dim_size(perm[i]) == out->dim_size(i)
// NOTE(review): internal::DoTransposeImpl in this header additionally
// CHECK-fails when in.dims() < 2. If the per-device definitions forward to
// it, rank >= 2 is an implicit requirement as well -- confirm against them.
template <typename Device>
Status DoConjugateTranspose(const Device& device, const Tensor& in,
                            const gtl::ArraySlice<int32> perm, Tensor* out);

// Convenience version of DoTranspose that only swaps the last (inner) two
// dimensions of 'in', writing the result to 'out'.
//
// Declared here only; the per-device definitions are not in this header.
// NOTE(review): internal::DoMatrixTransposeImpl in this header returns OK
// without touching 'out' for rank-0 input and indexes perm[ndims - 2], so it
// presumably expects rank 0 or rank >= 2 -- confirm whether the definitions
// forward to it.
template <typename Device>
Status DoMatrixTranspose(const Device& device, const Tensor& in, Tensor* out);

// Convenience version of DoConjugateTranspose that only swaps the last (inner)
// two dimensions of 'in', writing the conjugated result to 'out'.
//
// Declared here only; the per-device definitions are not in this header.
// NOTE(review): internal::DoMatrixTransposeImpl in this header returns OK
// without touching 'out' for rank-0 input and indexes perm[ndims - 2], so it
// presumably expects rank 0 or rank >= 2 -- confirm whether the definitions
// forward to it.
template <typename Device>
Status DoConjugateMatrixTranspose(const Device& device, const Tensor& in,
                                  Tensor* out);

// Primary device-specific functor, to be specialized for each device and type.
//
// Template parameters:
//   Device:    device type the transpose is executed on.
//   T:         element type. internal::DoTransposeImpl in this header
//              instantiates it with uint8, uint16, uint32, uint64, complex64,
//              complex128 and string.
//   conjugate: defaults to false. internal::DoTransposeImpl passes true only
//              for the complex types when a conjugate transpose is requested.
//
// run() is only declared here; its definitions (or specializations) are
// expected to come from elsewhere. It takes the same (device, in, perm, out)
// arguments as DoTranspose but returns nothing.
template <typename Device, typename T, bool conjugate = false>
struct Transpose {
  static void run(const Device& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out);
};

// Implementation details.
namespace internal {

// Small-vector aliases used by the transpose helpers: dimension sizes are
// 64-bit, permutation entries are 32-bit. Up to 8 entries are stored inline.
using TransposeDimsVec = gtl::InlinedVector<int64, 8>;
using TransposePermsVec = gtl::InlinedVector<int32, 8>;

// Collapses every run of input dimensions that stays adjacent, and in the same
// order, under 'perm' into a single dimension, producing an equivalent
// transpose of lower rank.
//
// Example: Tensor shape {2, 3, 4, 5, 120} and permutation {0, 4, 1, 2, 3} will
// produce new shape {2, 60, 120} and new permutation {0, 2, 1}.
//
// Args:
//   shape:    shape of the tensor to be transposed.
//   perm:     permutation of {0, .., shape.dims() - 1}; CHECK-fails unless
//             perm.size() == shape.dims().
//   new_perm: output; resized by this function. (*new_perm)[i] is the index
//             (into 'new_dims') of the combined input dimension that ends up
//             at output position i, i.e. the same convention as 'perm'.
//   new_dims: output; resized by this function. Sizes of the combined
//             dimensions, listed in input order.
//
// For a rank-0 shape both outputs are emptied.
inline void ReduceTransposeDimensions(const TensorShape& shape,
                                      gtl::ArraySlice<int32> perm,
                                      TransposePermsVec* new_perm,
                                      TransposeDimsVec* new_dims) {
  CHECK_EQ(shape.dims(), perm.size());
  const int rank = shape.dims();
  if (rank == 0) {
    // Nothing to permute; reading perm[0] below would be out of bounds.
    new_perm->clear();
    new_dims->clear();
    return;
  }
  if (rank == 1) {
    // If input dimension is already 1, no need to reduce dimension. Both
    // outputs must be sized before they are indexed: callers may pass empty
    // vectors.
    new_perm->resize(1);
    new_dims->resize(1);
    (*new_perm)[0] = perm[0];
    (*new_dims)[0] = shape.dim_size(0);
    return;
  }
  // new_dim_position[d] is the output position of the group whose first
  // (lowest) input dimension is d, or -1 if d was folded into a group that
  // starts earlier.
  TransposePermsVec new_dim_position(rank, -1);
  // combined_dims[k] is the number of elements spanned by the group at output
  // position k.
  TransposeDimsVec combined_dims(rank, 0);
  int cur_head = perm[0];
  new_dim_position[cur_head] = 0;
  combined_dims[0] = shape.dim_size(cur_head);
  int dim_idx = 0;
  for (int perm_idx = 1; perm_idx < rank; ++perm_idx) {
    // If two indices in permutation are consecutive numbers, combine their
    // dimensions.
    if (cur_head + 1 == perm[perm_idx]) {
      cur_head = perm[perm_idx];
      combined_dims[dim_idx] *= shape.dim_size(cur_head);
    } else {
      // Else start a new dimension.
      cur_head = perm[perm_idx];
      dim_idx++;
      new_dim_position[cur_head] = dim_idx;
      combined_dims[dim_idx] = shape.dim_size(cur_head);
    }
  }
  // Compact the new permutations and dimension sizes. Walking the input
  // dimensions in ascending order numbers the groups in input order.
  new_perm->resize(dim_idx + 1);
  new_dims->resize(dim_idx + 1);
  dim_idx = 0;
  for (int i = 0; i < rank; ++i) {
    if (new_dim_position[i] >= 0) {
      const int new_perm_idx = new_dim_position[i];
      // Output position new_perm_idx reads from combined input dimension
      // dim_idx. Writing it the other way round would yield the inverse
      // permutation, which only coincides for self-inverse permutations.
      (*new_perm)[new_perm_idx] = dim_idx;
      (*new_dims)[dim_idx] = combined_dims[new_perm_idx];
      dim_idx++;
    }
  }
}

// If all non-singleton dimensions remain in ascending order, the shuffled
// singletons can be transposed by a reshape, saving a memory allocation & copy.
// |permutation| must be a permutation of {0, .., input_shape.dims() - 1}.
// That is, for all i, 0 <= perm[i] < input_shape.dims().
// In practice, this is checked in TransposeOp::Compute prior to calling this
// function, and the function sits here to facilitate unit testing.
inline bool NonSingletonDimensionsAlign(const TensorShape& input_shape,
                                        const std::vector<int32>& permutation) {
  int last_nonsingleton_perm_dim = -1;
  for (int perm_dim : permutation) {
    if (input_shape.dim_size(perm_dim) == 1) {
      continue;
    }
    if (perm_dim < last_nonsingleton_perm_dim) {
      return false;
    }
    last_nonsingleton_perm_dim = perm_dim;
  }
  return true;
}

// Transposes 'in' into 'out' with an Eigen shuffle expression evaluated on
// device 'd'. Both buffers are viewed as rank-NDIMS tensors of T, so only the
// first NDIMS entries of 'perm' are read. When 'conjugate' is true the
// elements are conjugated before being shuffled.
template <typename Device, typename T, int NDIMS>
void TransposeUsingEigen(const Device& d, const Tensor& in,
                         const gtl::ArraySlice<int32> perm, bool conjugate,
                         Tensor* out) {
  typedef typename TTypes<T, NDIMS>::ConstTensor InputView;
  typedef typename TTypes<T, NDIMS>::Tensor OutputView;

  // Eigen wants the permutation as a fixed-size array.
  Eigen::array<int, NDIMS> shuffle_order;
  for (int axis = 0; axis < NDIMS; ++axis) {
    shuffle_order[axis] = perm[axis];
  }

  // Reinterpret the raw tensor bytes as T; tensor_data() exposes them as
  // const char, hence the const_cast for the destination.
  const T* src = reinterpret_cast<const T*>(in.tensor_data().data());
  T* dst = reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data()));

  InputView input_view(src, in.shape().AsEigenDSizes<NDIMS>());
  OutputView output_view(dst, out->shape().AsEigenDSizes<NDIMS>());

  if (!conjugate) {
    output_view.device(d) = input_view.shuffle(shuffle_order);
  } else {
    output_view.device(d) = input_view.conjugate().shuffle(shuffle_order);
  }
}

// Dispatches a (possibly conjugating) transpose of 'in' into 'out' to the
// Transpose functor that matches the element type of 'in'.
//
// Non-complex types, and complex64 without conjugation, are moved as opaque
// unsigned integers of the same byte width, so a single functor instantiation
// serves every dtype of that width. 'conjugate' only has an effect for
// DT_COMPLEX64 and DT_COMPLEX128.
//
// CHECK-fails unless in.dims() >= 2, in.dims() == out->dims(),
// in.dims() == perm.size() and in.dtype() == out->dtype().
//
// Returns OK on success, or Unimplemented for a dtype that has no case below
// (and for conjugated complex64 when built with GCC for Android).
template <typename Device>
Status DoTransposeImpl(const Device& d, const Tensor& in,
                       const gtl::ArraySlice<int32> perm, bool conjugate,
                       Tensor* out) {
  CHECK_GE(in.dims(), 2);
  CHECK_EQ(in.dims(), out->dims());
  CHECK_EQ(in.dims(), perm.size());
  CHECK_EQ(in.dtype(), out->dtype());
  switch (in.dtype()) {
    // 1-byte elements.
    case DT_UINT8:
    case DT_INT8:
    case DT_QUINT8:
    case DT_QINT8:
    case DT_BOOL:
      Transpose<Device, uint8>::run(d, in, perm, out);
      break;

    // 2-byte elements.
    case DT_UINT16:
    case DT_INT16:
    case DT_QUINT16:
    case DT_QINT16:
    case DT_HALF:
    case DT_BFLOAT16:
      Transpose<Device, uint16>::run(d, in, perm, out);
      break;

    // 4-byte elements.
    case DT_INT32:
    case DT_QINT32:
    case DT_FLOAT:
      Transpose<Device, uint32>::run(d, in, perm, out);
      break;

    // 8-byte elements.
    case DT_INT64:
    case DT_DOUBLE:
      Transpose<Device, uint64>::run(d, in, perm, out);
      break;

    case DT_COMPLEX64:
      if (!conjugate) {
        // Without conjugation a complex64 is just another 8-byte element.
        Transpose<Device, uint64>::run(d, in, perm, out);
      } else {
#if defined(__ANDROID__) and !defined(__clang__)
        // Workaround for GCC compiler bug in Android toolchain.
        return errors::Unimplemented(
            "Conjugate transpose of complex64 not supported for GCC on "
            "Android.");
#else
        Transpose<Device, complex64, /*conjugate=*/true>::run(d, in, perm, out);
#endif
      }
      break;

    case DT_COMPLEX128:
      if (!conjugate) {
        Transpose<Device, complex128, /*conjugate=*/false>::run(d, in, perm,
                                                                out);
      } else {
        Transpose<Device, complex128, /*conjugate=*/true>::run(d, in, perm,
                                                               out);
      }
      break;

    case DT_STRING:
      Transpose<Device, string>::run(d, in, perm, out);
      break;

    default:
      return errors::Unimplemented("Unsupported dtype on CPU: ", in.dtype());
  }
  return Status::OK();
}

// Swaps the last two dimensions of 'in' into 'out', conjugating the elements
// when 'conjugate' is true, by building the matching permutation and
// forwarding to DoTransposeImpl.
//
// Returns OK without touching 'out' for a rank-0 input, InvalidArgument for a
// rank-1 input (there is no second-to-last dimension to swap with), and
// otherwise whatever DoTransposeImpl returns.
template <typename Device>
inline Status DoMatrixTransposeImpl(const Device& device, const Tensor& in,
                                    bool conjugate, Tensor* out) {
  const int ndims = in.dims();
  if (ndims == 0) return Status::OK();
  if (ndims < 2) {
    // perm[ndims - 2] would index before the start of the permutation, and
    // DoTransposeImpl would CHECK-fail on the rank anyway; report it instead.
    return errors::InvalidArgument(
        "Matrix transpose requires an input of rank >= 2, got rank ", ndims);
  }
  // Identity permutation with the two innermost dimensions exchanged.
  TransposePermsVec perm(ndims);
  std::iota(perm.begin(), perm.end(), 0);
  std::swap(perm[ndims - 2], perm[ndims - 1]);
  return DoTransposeImpl(device, in, perm, conjugate, out);
}

#ifdef TENSORFLOW_USE_SYCL
// For SYCL, let's always go through Eigen.
//
// Declared only when TENSORFLOW_USE_SYCL is defined; the definition is not in
// this header. Takes the same (device, in, perm, conjugate, out) arguments as
// TransposeUsingEigen but without a compile-time rank parameter.
template <typename Device, typename T>
void TransposeSYCL(const Device& d, const Tensor& in,
                   const gtl::ArraySlice<int32> perm, bool conjugate,
                   Tensor* out);
#endif  // TENSORFLOW_USE_SYCL

}  // namespace internal
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_