diff options
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops_common.h')
-rw-r--r-- | tensorflow/core/kernels/cwise_ops_common.h | 35 |
1 files changed, 21 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h index 5ad6b1fd4a..c825a91fb1 100644 --- a/tensorflow/core/kernels/cwise_ops_common.h +++ b/tensorflow/core/kernels/cwise_ops_common.h @@ -20,6 +20,10 @@ limitations under the License. #define EIGEN_USE_THREADS +#ifdef TENSORFLOW_USE_SYCL +#include "tensorflow/core/kernels/cwise_ops_sycl_common.h" +#endif + #include "tensorflow/core/kernels/cwise_ops.h" #include "tensorflow/core/kernels/cwise_ops_gradients.h" @@ -33,6 +37,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif class BinaryOpShared : public OpKernel { public: @@ -96,45 +103,45 @@ class BinaryOp : public BinaryOpShared { if (state.in1_num_elements == 1) { // tensor op scalar functor::BinaryFunctor<Device, Functor, 1>().Right( - eigen_device, out_flat, in0.flat<Tin>(), in1.scalar<Tin>(), - error_ptr); + eigen_device, out_flat, in0.template flat<Tin>(), + in1.template scalar<Tin>(), error_ptr); } else if (state.in0_num_elements == 1) { // scalar op tensor functor::BinaryFunctor<Device, Functor, 1>().Left( - eigen_device, out_flat, in0.scalar<Tin>(), in1.flat<Tin>(), - error_ptr); + eigen_device, out_flat, in0.template scalar<Tin>(), + in1.template flat<Tin>(), error_ptr); } else { functor::BinaryFunctor<Device, Functor, 1>()( - eigen_device, out_flat, in0.flat<Tin>(), in1.flat<Tin>(), - error_ptr); + eigen_device, out_flat, in0.template flat<Tin>(), + in1.template flat<Tin>(), error_ptr); } } else if (ndims == 2) { functor::BinaryFunctor<Device, Functor, 2>().BCast( eigen_device, out->shaped<Tout, 2>(bcast->result_shape()), - in0.shaped<Tin, 2>(bcast->x_reshape()), + in0.template shaped<Tin, 2>(bcast->x_reshape()), BCast::ToIndexArray<2>(bcast->x_bcast()), - in1.shaped<Tin, 2>(bcast->y_reshape()), + in1.template shaped<Tin, 2>(bcast->y_reshape()), BCast::ToIndexArray<2>(bcast->y_bcast()), error_ptr); } else if (ndims == 3) { functor::BinaryFunctor<Device, Functor, 3>().BCast( eigen_device, out->shaped<Tout, 3>(bcast->result_shape()), - in0.shaped<Tin, 3>(bcast->x_reshape()), + in0.template shaped<Tin, 3>(bcast->x_reshape()), BCast::ToIndexArray<3>(bcast->x_bcast()), - in1.shaped<Tin, 3>(bcast->y_reshape()), + in1.template shaped<Tin, 3>(bcast->y_reshape()), BCast::ToIndexArray<3>(bcast->y_bcast()), error_ptr); } else if (ndims == 4) { functor::BinaryFunctor<Device, Functor, 4>().BCast( eigen_device, out->shaped<Tout, 4>(bcast->result_shape()), - in0.shaped<Tin, 4>(bcast->x_reshape()), + in0.template shaped<Tin, 4>(bcast->x_reshape()), BCast::ToIndexArray<4>(bcast->x_bcast()), - in1.shaped<Tin, 4>(bcast->y_reshape()), + in1.template shaped<Tin, 4>(bcast->y_reshape()), BCast::ToIndexArray<4>(bcast->y_bcast()), error_ptr); } else if (ndims == 5) { functor::BinaryFunctor<Device, Functor, 5>().BCast( eigen_device, out->shaped<Tout, 5>(bcast->result_shape()), - in0.shaped<Tin, 5>(bcast->x_reshape()), + in0.template shaped<Tin, 5>(bcast->x_reshape()), BCast::ToIndexArray<5>(bcast->x_bcast()), - in1.shaped<Tin, 5>(bcast->y_reshape()), + in1.template shaped<Tin, 5>(bcast->y_reshape()), BCast::ToIndexArray<5>(bcast->y_bcast()), error_ptr); } else { SetUnimplementedError(ctx); |