aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/transpose_functor_cpu.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/transpose_functor_cpu.cc')
-rw-r--r--tensorflow/core/kernels/transpose_functor_cpu.cc24
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc
index f8c87e7e2e..30b82f1843 100644
--- a/tensorflow/core/kernels/transpose_functor_cpu.cc
+++ b/tensorflow/core/kernels/transpose_functor_cpu.cc
@@ -114,4 +114,28 @@ Status DoTranspose<Device>(const Device& d, const Tensor& in,
return Status::OK();
}
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+
+template <>
+Status DoTranspose<SYCLDevice>(const SYCLDevice& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ CHECK_GE(in.dims(), 2);
+ CHECK_EQ(in.dims(), out->dims());
+ CHECK_EQ(in.dims(), perm.size());
+ CHECK_EQ(in.dtype(), out->dtype());
+ switch (in.dtype()) {
+
+ case DT_FLOAT:
+ case DT_INT32:
+ internal::Transpose<SYCLDevice, uint32>(d, in, perm, out);
+ break;
+
+ default:
+ return errors::Unimplemented("Unsupported dtype on SYCL: ", in.dtype());
+ }
+ return Status::OK();
+}
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow