about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/matmul_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/matmul_op.cc')
-rw-r--r--  tensorflow/core/kernels/matmul_op.cc  60
1 file changed, 56 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index a2b0127fac..94fe22ed31 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -46,6 +46,9 @@ perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T, bool USE_CUBLAS>
struct LaunchMatMul;
@@ -118,27 +121,42 @@ bool ExplicitVectorMatrixOptimization<Eigen::half>(
return false;
}
-// On CPUs, we ignore USE_CUBLAS
-template <typename T>
-struct LaunchMatMulCPU {
+template <typename Device, typename T>
+struct LaunchMatMulBase {
static void launch(
OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
Tensor* out) {
+#ifndef TENSORFLOW_USE_SYCL
// An explicit vector-matrix multiply is much better optimized than an
// implicit one and this is a bottleneck during non-batched inference.
bool was_vector = ExplicitVectorMatrixOptimization<T>(a, b, dim_pair, out);
if (!was_vector) {
- functor::MatMulFunctor<CPUDevice, T>()(ctx->eigen_device<CPUDevice>(),
+#endif // TENSORFLOW_USE_SYCL
+ functor::MatMulFunctor<Device, T>()(ctx->eigen_device<Device>(),
out->matrix<T>(), a.matrix<T>(),
b.matrix<T>(), dim_pair);
+#ifndef TENSORFLOW_USE_SYCL
}
+#endif // TENSORFLOW_USE_SYCL
}
};
+// On CPUs, we ignore USE_CUBLAS
+template <typename T>
+struct LaunchMatMulCPU : LaunchMatMulBase<CPUDevice, T> {};
+
template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<CPUDevice, T, USE_CUBLAS> : public LaunchMatMulCPU<T> {};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct LaunchMatMulSYCL : LaunchMatMulBase<SYCLDevice, T> {};
+
+template <typename T, bool USE_CUBLAS>
+struct LaunchMatMul<SYCLDevice, T, USE_CUBLAS> : public LaunchMatMulSYCL<T> {};
+#endif // TENSORFLOW_USE_SYCL
+
#if GOOGLE_CUDA
template <typename T>
@@ -256,6 +274,20 @@ struct MatMulFunctor<CPUDevice, T> {
}
};
+#ifdef TENSORFLOW_USE_SYCL
+// Partial specialization MatMulFunctor<Device=SYCLDevice, T>.
+template <typename T>
+struct MatMulFunctor<SYCLDevice, T> {
+ void operator()(
+ const SYCLDevice& d, typename MatMulTypes<T>::out_type out,
+ typename MatMulTypes<T>::in_type in0,
+ typename MatMulTypes<T>::in_type in1,
+ const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
+ MatMul<SYCLDevice>(d, out, in0, in1, dim_pair);
+ }
+};
+#endif // TENSORFLOW_USE_SYCL
+
} // end namespace functor
#define REGISTER_CPU(T) \
@@ -276,6 +308,12 @@ struct MatMulFunctor<CPUDevice, T> {
.Label("cublas"), \
MatMulOp<GPUDevice, T, true /* cublas */>)
+#if defined (INTEL_MKL)
+// MKL does not support half and int32 types for matrix-multiplication, so
+// register the kernel to use default Eigen based implementations for these types
+TF_CALL_half(REGISTER_CPU);
+TF_CALL_int32(REGISTER_CPU);
+#else
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);
@@ -283,6 +321,7 @@ TF_CALL_half(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
+#endif
#if GOOGLE_CUDA
TF_CALL_float(REGISTER_GPU);
@@ -294,4 +333,17 @@ TF_CALL_half(REGISTER_GPU);
#endif
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MatMul").Device(DEVICE_SYCL).TypeConstraint<T>("T"), \
+ MatMulOp<SYCLDevice, T, false /* xxblas */>); \
+ REGISTER_KERNEL_BUILDER(Name("MatMul") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<T>("T") \
+ .Label("eigen"), \
+ MatMulOp<SYCLDevice, T, false /* xxblas */>)
+TF_CALL_float(REGISTER_SYCL);
+
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow