diff options
author | A. Unique TensorFlower <nobody@tensorflow.org> | 2016-02-16 09:52:33 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-02-16 11:08:17 -0800 |
commit | 6804c9cafc11fa73be3fdb057e033f0304661622 (patch) | |
tree | 87a1c65806c0bf73263855c659e112977ddb27f1 /tensorflow/core/kernels/transpose_op_functor.h | |
parent | cf661010261c80b97ab68c5aec383b454ef34f18 (diff) |
Rewrite of transpose so that its compilation time is tolerable. Main
approach:
1. Do not instantiate templates for all tf types. Instead, various
types is casted to one of uint8/uint16/uint32/uint64/string.
2. Use eigen3 for 2/3/4 rank tensors' transpose and fallback to a
naive routine which is only templatized on type T but not on
NDIMS.
Change: 114763098
Diffstat (limited to 'tensorflow/core/kernels/transpose_op_functor.h')
-rw-r--r-- | tensorflow/core/kernels/transpose_op_functor.h | 66 |
1 files changed, 52 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/transpose_op_functor.h b/tensorflow/core/kernels/transpose_op_functor.h index e478c6d966..b79c3c7f2f 100644 --- a/tensorflow/core/kernels/transpose_op_functor.h +++ b/tensorflow/core/kernels/transpose_op_functor.h @@ -16,28 +16,66 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_ #define TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_ -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" namespace tensorflow { -namespace functor { -template <typename Device, typename T, int NDIMS> -void Transpose(const Device& d, typename TTypes<T, NDIMS>::Tensor out, - typename TTypes<T, NDIMS>::ConstTensor in, const int* perm) { - // perm[] is a permutation of 0, 1, ..., NDIMS-1. perm[] is on CPU. - Eigen::array<int, NDIMS> p; - for (int i = 0; i < NDIMS; ++i) p[i] = perm[i]; - out.device(d) = in.shuffle(p); +// Transpose tensor 'in' into tensor 'out' according to dimension +// permutation 'perm'. +// +// REQUIRES: in.dtype() == out->dtype() +// REQUIRES: in.dims() == out->dims() +// REQUIRES: in.dims() == perm.size() +// REQUIRES: in.dim_size(perm[i]) == out->dim_size(i) +template <typename Device> +Status DoTranspose(const Device& device, const Tensor& in, + const gtl::ArraySlice<int32> perm, Tensor* out); + +// Implementation details. +namespace internal { + +// Helper to compute 'strides' given a tensor 'shape'. I.e., +// strides[i] = prod(shape.dim_size[(i+1):]) +template <typename Index> +void ComputeStride(const TensorShape& shape, Index* strides) { + const int ndims = shape.dims(); + Index stride = 1; + for (int i = ndims - 1; i >= 0; --i) { + strides[i] = stride; + stride *= static_cast<Index>(shape.dim_size(i)); + } } +// Device-specific naive implementation for tranpose. +template <typename Device, typename T> +void TransposeSimple(const Device& d, const Tensor& in, + const gtl::ArraySlice<int32> perm, Tensor* out); + +// Uses Eigen to transpose. template <typename Device, typename T, int NDIMS> -struct TransposeFunctor { - void operator()(const Device& d, typename TTypes<T, NDIMS>::Tensor out, - typename TTypes<T, NDIMS>::ConstTensor in, const int* perm); -}; +void TransposeUsingEigen(const Device& d, const Tensor& in, + const gtl::ArraySlice<int32> perm, Tensor* out); -} // namespace functor +template <typename Device, typename T> +void Transpose(const Device& d, const Tensor& in, + const gtl::ArraySlice<int32> perm, Tensor* out) { + switch (in.dims()) { + case 2: + TransposeUsingEigen<Device, T, 2>(d, in, perm, out); + break; + case 3: + TransposeUsingEigen<Device, T, 3>(d, in, perm, out); + break; + case 4: + TransposeUsingEigen<Device, T, 4>(d, in, perm, out); + break; + default: + TransposeSimple<Device, T>(d, in, perm, out); + break; + } +} +} // namespace internal } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_ |