diff options
Diffstat (limited to 'tensorflow/core/kernels/batch_matmul_op_impl.h')
-rw-r--r-- | tensorflow/core/kernels/batch_matmul_op_impl.h | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index dfc81a960e..b87c98c374 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -39,6 +39,9 @@ namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif  // TENSORFLOW_USE_SYCL
 
 namespace {
 
@@ -413,6 +416,48 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
 
 #endif  // GOOGLE_CUDA
 
+#ifdef TENSORFLOW_USE_SYCL
+// Multiplies the entries of a batch one at a time on the SYCL device.
+template <typename Scalar>
+struct ParallelMatMulKernelSYCL {
+  // For every i in [start, limit), contracts entry i of in_x with entry i of
+  // in_y and writes the result to entry i of *out. All three tensors are
+  // viewed as rank 3 and sliced along dimension 0. adj_x and adj_y are
+  // forwarded to ContractionDims() to select the contracted dimensions.
+  // The bounds are int64 so an int64 batch count is not narrowed to int.
+  static void Run(const OpKernelContext* context, const Tensor& in_x,
+                  const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
+                  int64 start, int64 limit) {
+    auto Tx = in_x.tensor<Scalar, 3>();
+    auto Ty = in_y.tensor<Scalar, 3>();
+    auto Tz = out->tensor<Scalar, 3>();
+    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
+    contract_pairs[0] = ContractionDims(adj_x, adj_y);
+    auto d = context->eigen_sycl_device();
+    // One contraction expression is assigned through the device per entry.
+    for (int64 i = start; i < limit; ++i) {
+      auto x = Tx.template chip<0>(i);
+      auto y = Ty.template chip<0>(i);
+      auto z = Tz.template chip<0>(i);
+      z.device(d) = x.contract(y, contract_pairs);
+    }
+  }
+};
+
+// SYCL specialization of the launcher: hands the whole batch, i.e. the range
+// [0, in_x.dim_size(0)), to ParallelMatMulKernelSYCL in a single call.
+template <typename Scalar>
+struct LaunchBatchMatMul<SYCLDevice, Scalar> {
+  static void Launch(OpKernelContext* context, const Tensor& in_x,
+                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
+    // Number of matrix multiplies i.e. size of the batch.
+    const int64 num_units = in_x.dim_size(0);
+    ParallelMatMulKernelSYCL<Scalar>::Run(context, in_x, in_y, adj_x, adj_y,
+                                          out, 0, num_units);
+  }
+};
+#endif  // TENSORFLOW_USE_SYCL
+
 template <typename Device, typename Scalar>
 class BatchMatMul : public OpKernel {
  public:
@@ -492,4 +537,11 @@ class BatchMatMul : public OpKernel {
       Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
       BatchMatMul<GPUDevice, TYPE>)
 
+#ifdef TENSORFLOW_USE_SYCL
+// Registers BatchMatMul<SYCLDevice, TYPE> as "BatchMatMul" on DEVICE_SYCL.
+#define REGISTER_BATCH_MATMUL_SYCL(TYPE) \
+  REGISTER_KERNEL_BUILDER( \
+      Name("BatchMatMul").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
+      BatchMatMul<SYCLDevice, TYPE>)
+#endif  // TENSORFLOW_USE_SYCL
 }  // end namespace tensorflow