aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/transpose_op_functor.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-02-16 09:52:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-16 11:08:17 -0800
commit6804c9cafc11fa73be3fdb057e033f0304661622 (patch)
tree87a1c65806c0bf73263855c659e112977ddb27f1 /tensorflow/core/kernels/transpose_op_functor.h
parentcf661010261c80b97ab68c5aec383b454ef34f18 (diff)
Rewrite of transpose so that its compilation time is tolerable. Main
approach: 1. Do not instantiate templates for all tf types. Instead, the various types are cast to one of uint8/uint16/uint32/uint64/string. 2. Use eigen3 to transpose tensors of rank 2/3/4, and fall back to a naive routine which is templatized only on type T but not on NDIMS. Change: 114763098
Diffstat (limited to 'tensorflow/core/kernels/transpose_op_functor.h')
-rw-r--r--tensorflow/core/kernels/transpose_op_functor.h66
1 files changed, 52 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/transpose_op_functor.h b/tensorflow/core/kernels/transpose_op_functor.h
index e478c6d966..b79c3c7f2f 100644
--- a/tensorflow/core/kernels/transpose_op_functor.h
+++ b/tensorflow/core/kernels/transpose_op_functor.h
@@ -16,28 +16,66 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
-namespace functor {
-template <typename Device, typename T, int NDIMS>
-void Transpose(const Device& d, typename TTypes<T, NDIMS>::Tensor out,
- typename TTypes<T, NDIMS>::ConstTensor in, const int* perm) {
- // perm[] is a permutation of 0, 1, ..., NDIMS-1. perm[] is on CPU.
- Eigen::array<int, NDIMS> p;
- for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
- out.device(d) = in.shuffle(p);
+// Transpose tensor 'in' into tensor 'out' according to dimension
+// permutation 'perm'.
+//
+// REQUIRES: in.dtype() == out->dtype()
+// REQUIRES: in.dims() == out->dims()
+// REQUIRES: in.dims() == perm.size()
+// REQUIRES: in.dim_size(perm[i]) == out->dim_size(i)
+template <typename Device>
+Status DoTranspose(const Device& device, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out);
+
+// Implementation details.
+namespace internal {
+
+// Helper that fills strides[0 .. shape.dims()-1] from a tensor 'shape', i.e.
+// strides[i] = prod(shape.dim_size[(i+1):]), so the last dimension gets 1.
+template <typename Index>
+void ComputeStride(const TensorShape& shape, Index* strides) {
+ const int ndims = shape.dims();
+ Index stride = 1;  // product of the dimension sizes visited so far
+ for (int i = ndims - 1; i >= 0; --i) {  // last dimension first
+ strides[i] = stride;  // stored before dim i is folded in, so dim i is excluded
+ stride *= static_cast<Index>(shape.dim_size(i));
+ }
}
+// Device-specific naive implementation for transpose.
+template <typename Device, typename T>
+void TransposeSimple(const Device& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out);
+
+// Uses Eigen to transpose.
template <typename Device, typename T, int NDIMS>
-struct TransposeFunctor {
- void operator()(const Device& d, typename TTypes<T, NDIMS>::Tensor out,
- typename TTypes<T, NDIMS>::ConstTensor in, const int* perm);
-};
+void TransposeUsingEigen(const Device& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out);
-} // namespace functor
+template <typename Device, typename T>  // rank dispatcher over the two helpers declared above
+void Transpose(const Device& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ switch (in.dims()) {  // choose the implementation from the rank of 'in'
+ case 2:  // ranks 2, 3 and 4: Eigen path, NDIMS fixed at compile time
+ TransposeUsingEigen<Device, T, 2>(d, in, perm, out);
+ break;
+ case 3:
+ TransposeUsingEigen<Device, T, 3>(d, in, perm, out);
+ break;
+ case 4:
+ TransposeUsingEigen<Device, T, 4>(d, in, perm, out);
+ break;
+ default:  // any other rank: naive routine, not templated on rank
+ TransposeSimple<Device, T>(d, in, perm, out);
+ break;
+ }
+}
+} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_