about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/transpose_functor_cpu.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/transpose_functor_cpu.cc')
-rw-r--r-- tensorflow/core/kernels/transpose_functor_cpu.cc 51
1 files changed, 32 insertions, 19 deletions
diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc
index 248c11976e..a004cb2293 100644
--- a/tensorflow/core/kernels/transpose_functor_cpu.cc
+++ b/tensorflow/core/kernels/transpose_functor_cpu.cc
@@ -16,6 +16,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/transpose_functor.h"
+#include "tensorflow/core/kernels/ops_util.h"
namespace tensorflow {
namespace internal {
@@ -24,10 +25,8 @@ template <typename Device, typename T>
void TransposeSimple(const Device& d, const Tensor& in,
const gtl::ArraySlice<int32> perm, Tensor* out) {
const int ndims = in.dims();
- gtl::InlinedVector<int64, 8> in_strides(ndims);
- ComputeStride(in.shape(), in_strides.data());
- gtl::InlinedVector<int64, 8> out_strides(ndims);
- ComputeStride(out->shape(), out_strides.data());
+ gtl::InlinedVector<int64, 8> in_strides = ComputeStride<int64>(in.shape());
+ gtl::InlinedVector<int64, 8> out_strides = ComputeStride<int64>(out->shape());
const int64 nelem = in.NumElements();
const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
@@ -45,20 +44,6 @@ void TransposeSimple(const Device& d, const Tensor& in,
}
}
-template <typename Device, typename T, int NDIMS>
-void TransposeUsingEigen(const Device& d, const Tensor& in,
- const gtl::ArraySlice<int32> perm, Tensor* out) {
- Eigen::array<int, NDIMS> p;
- for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
- auto x = typename TTypes<T, NDIMS>::ConstTensor(
- reinterpret_cast<const T*>(in.tensor_data().data()),
- in.shape().AsEigenDSizes<NDIMS>());
- auto y = typename TTypes<T, NDIMS>::Tensor(
- reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data())),
- out->shape().AsEigenDSizes<NDIMS>());
- y.device(d) = x.shuffle(p);
-}
-
} // end namespace internal
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -182,7 +167,35 @@ template <typename T>
struct Transpose<SYCLDevice, T> {
static void run(const SYCLDevice& d, const Tensor& in,
const gtl::ArraySlice<int32> perm, Tensor* out) {
- // Should add a specialized implementation for SYCLDevice here.
+ switch (in.dims()) {
+ case 1:
+ internal::TransposeUsingEigen<SYCLDevice, T, 1>(d, in, perm, out);
+ break;
+ case 2:
+ internal::TransposeUsingEigen<SYCLDevice, T, 2>(d, in, perm, out);
+ break;
+ case 3:
+ internal::TransposeUsingEigen<SYCLDevice, T, 3>(d, in, perm, out);
+ break;
+ case 4:
+ internal::TransposeUsingEigen<SYCLDevice, T, 4>(d, in, perm, out);
+ break;
+ case 5:
+ internal::TransposeUsingEigen<SYCLDevice, T, 5>(d, in, perm, out);
+ break;
+ case 6:
+ internal::TransposeUsingEigen<SYCLDevice, T, 6>(d, in, perm, out);
+ break;
+ case 7:
+ internal::TransposeUsingEigen<SYCLDevice, T, 7>(d, in, perm, out);
+ break;
+ case 8:
+ internal::TransposeUsingEigen<SYCLDevice, T, 8>(d, in, perm, out);
+ break;
+ default:
+ LOG(FATAL) << "Unsupported TransposeUsingEigen for: " << in.dims();
+ break;
+ }
}
};