aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/batch_matmul_op_impl.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/batch_matmul_op_impl.h')
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_impl.h43
1 files changed, 43 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index dfc81a960e..b87c98c374 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -39,6 +39,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
namespace {
@@ -413,6 +416,40 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+template <typename Scalar>
+struct ParallelMatMulKernelSYCL {
+ static void Run(const OpKernelContext* context, const Tensor& in_x,
+ const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
+ int start, int limit) {
+ auto Tx = in_x.tensor<Scalar, 3>();
+ auto Ty = in_y.tensor<Scalar, 3>();
+ auto Tz = out->tensor<Scalar, 3>();
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
+ contract_pairs[0] = ContractionDims(adj_x, adj_y);
+ auto d = context->eigen_sycl_device();
+ for (int i = start; i < limit; ++i) {
+ auto x = Tx.template chip<0>(i);
+ auto y = Ty.template chip<0>(i);
+ auto z = Tz.template chip<0>(i);
+ z.device(d) = x.contract(y, contract_pairs);
+ }
+ }
+};
+
+template <typename Scalar>
+struct LaunchBatchMatMul<SYCLDevice, Scalar> {
+ static void Launch(OpKernelContext* context, const Tensor& in_x,
+ const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
+
+ // Number of matrix multiplies i.e. size of the batch.
+ const int64 num_units = in_x.dim_size(0);
+ ParallelMatMulKernelSYCL<Scalar>::Run(context, in_x, in_y, adj_x, adj_y, out,
+ 0, num_units);
+ }
+};
+#endif // TENSORFLOW_USE_SYCL
+
template <typename Device, typename Scalar>
class BatchMatMul : public OpKernel {
public:
@@ -492,4 +529,10 @@ class BatchMatMul : public OpKernel {
Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
BatchMatMul<GPUDevice, TYPE>)
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_BATCH_MATMUL_SYCL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("BatchMatMul").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
+ BatchMatMul<SYCLDevice, TYPE>)
+#endif // TENSORFLOW_USE_SYCL
} // end namespace tensorflow