diff options
Diffstat (limited to 'tensorflow/core/kernels/transpose_functor_cpu.cc')
-rw-r--r-- | tensorflow/core/kernels/transpose_functor_cpu.cc | 51 |
1 files changed, 32 insertions, 19 deletions
diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc index 248c11976e..a004cb2293 100644 --- a/tensorflow/core/kernels/transpose_functor_cpu.cc +++ b/tensorflow/core/kernels/transpose_functor_cpu.cc @@ -16,6 +16,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "tensorflow/core/kernels/transpose_functor.h" +#include "tensorflow/core/kernels/ops_util.h" namespace tensorflow { namespace internal { @@ -24,10 +25,8 @@ template <typename Device, typename T> void TransposeSimple(const Device& d, const Tensor& in, const gtl::ArraySlice<int32> perm, Tensor* out) { const int ndims = in.dims(); - gtl::InlinedVector<int64, 8> in_strides(ndims); - ComputeStride(in.shape(), in_strides.data()); - gtl::InlinedVector<int64, 8> out_strides(ndims); - ComputeStride(out->shape(), out_strides.data()); + gtl::InlinedVector<int64, 8> in_strides = ComputeStride<int64>(in.shape()); + gtl::InlinedVector<int64, 8> out_strides = ComputeStride<int64>(out->shape()); const int64 nelem = in.NumElements(); const T* p = reinterpret_cast<const T*>(in.tensor_data().data()); T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data()))); @@ -45,20 +44,6 @@ void TransposeSimple(const Device& d, const Tensor& in, } } -template <typename Device, typename T, int NDIMS> -void TransposeUsingEigen(const Device& d, const Tensor& in, - const gtl::ArraySlice<int32> perm, Tensor* out) { - Eigen::array<int, NDIMS> p; - for (int i = 0; i < NDIMS; ++i) p[i] = perm[i]; - auto x = typename TTypes<T, NDIMS>::ConstTensor( - reinterpret_cast<const T*>(in.tensor_data().data()), - in.shape().AsEigenDSizes<NDIMS>()); - auto y = typename TTypes<T, NDIMS>::Tensor( - reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data())), - out->shape().AsEigenDSizes<NDIMS>()); - y.device(d) = x.shuffle(p); -} - } // end namespace internal typedef Eigen::ThreadPoolDevice CPUDevice; @@ -182,7 +167,35 @@ template <typename T> struct Transpose<SYCLDevice, T> { static void run(const SYCLDevice& d, const Tensor& in, const gtl::ArraySlice<int32> perm, Tensor* out) { - // Should add a specialized implementation for SYCLDevice here. + switch (in.dims()) { + case 1: + internal::TransposeUsingEigen<SYCLDevice, T, 1>(d, in, perm, out); + break; + case 2: + internal::TransposeUsingEigen<SYCLDevice, T, 2>(d, in, perm, out); + break; + case 3: + internal::TransposeUsingEigen<SYCLDevice, T, 3>(d, in, perm, out); + break; + case 4: + internal::TransposeUsingEigen<SYCLDevice, T, 4>(d, in, perm, out); + break; + case 5: + internal::TransposeUsingEigen<SYCLDevice, T, 5>(d, in, perm, out); + break; + case 6: + internal::TransposeUsingEigen<SYCLDevice, T, 6>(d, in, perm, out); + break; + case 7: + internal::TransposeUsingEigen<SYCLDevice, T, 7>(d, in, perm, out); + break; + case 8: + internal::TransposeUsingEigen<SYCLDevice, T, 8>(d, in, perm, out); + break; + default: + LOG(FATAL) << "Unsupported TransposeUsingEigen for: " << in.dims(); + break; + } } }; |