diff options
Diffstat (limited to 'tensorflow/core/kernels/transpose_functor_cpu.cc')
-rw-r--r-- | tensorflow/core/kernels/transpose_functor_cpu.cc | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc index f8c87e7e2e..30b82f1843 100644 --- a/tensorflow/core/kernels/transpose_functor_cpu.cc +++ b/tensorflow/core/kernels/transpose_functor_cpu.cc @@ -114,4 +114,28 @@ Status DoTranspose<Device>(const Device& d, const Tensor& in, return Status::OK(); } +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; + +template <> +Status DoTranspose<SYCLDevice>(const SYCLDevice& d, const Tensor& in, + const gtl::ArraySlice<int32> perm, Tensor* out) { + CHECK_GE(in.dims(), 2); + CHECK_EQ(in.dims(), out->dims()); + CHECK_EQ(in.dims(), perm.size()); + CHECK_EQ(in.dtype(), out->dtype()); + switch (in.dtype()) { + + case DT_FLOAT: + case DT_INT32: + internal::Transpose<SYCLDevice, uint32>(d, in, perm, out); + break; + + default: + return errors::Unimplemented("Unsupported dtype on SYCL: ", in.dtype()); + } + return Status::OK(); +} +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow |