path: root/tensorflow/core/kernels/cwise_ops_common.h
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops_common.h')
-rw-r--r--  tensorflow/core/kernels/cwise_ops_common.h | 35
1 file changed, 21 insertions(+), 14 deletions(-)
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index 5ad6b1fd4a..c825a91fb1 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -20,6 +20,10 @@ limitations under the License.
#define EIGEN_USE_THREADS
+#ifdef TENSORFLOW_USE_SYCL
+#include "tensorflow/core/kernels/cwise_ops_sycl_common.h"
+#endif
+
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_gradients.h"
@@ -33,6 +37,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif
class BinaryOpShared : public OpKernel {
public:
@@ -96,45 +103,45 @@ class BinaryOp : public BinaryOpShared {
if (state.in1_num_elements == 1) {
// tensor op scalar
functor::BinaryFunctor<Device, Functor, 1>().Right(
- eigen_device, out_flat, in0.flat<Tin>(), in1.scalar<Tin>(),
- error_ptr);
+ eigen_device, out_flat, in0.template flat<Tin>(),
+ in1.template scalar<Tin>(), error_ptr);
} else if (state.in0_num_elements == 1) {
// scalar op tensor
functor::BinaryFunctor<Device, Functor, 1>().Left(
- eigen_device, out_flat, in0.scalar<Tin>(), in1.flat<Tin>(),
- error_ptr);
+ eigen_device, out_flat, in0.template scalar<Tin>(),
+ in1.template flat<Tin>(), error_ptr);
} else {
functor::BinaryFunctor<Device, Functor, 1>()(
- eigen_device, out_flat, in0.flat<Tin>(), in1.flat<Tin>(),
- error_ptr);
+ eigen_device, out_flat, in0.template flat<Tin>(),
+ in1.template flat<Tin>(), error_ptr);
}
} else if (ndims == 2) {
functor::BinaryFunctor<Device, Functor, 2>().BCast(
eigen_device, out->shaped<Tout, 2>(bcast->result_shape()),
- in0.shaped<Tin, 2>(bcast->x_reshape()),
+ in0.template shaped<Tin, 2>(bcast->x_reshape()),
BCast::ToIndexArray<2>(bcast->x_bcast()),
- in1.shaped<Tin, 2>(bcast->y_reshape()),
+ in1.template shaped<Tin, 2>(bcast->y_reshape()),
BCast::ToIndexArray<2>(bcast->y_bcast()), error_ptr);
} else if (ndims == 3) {
functor::BinaryFunctor<Device, Functor, 3>().BCast(
eigen_device, out->shaped<Tout, 3>(bcast->result_shape()),
- in0.shaped<Tin, 3>(bcast->x_reshape()),
+ in0.template shaped<Tin, 3>(bcast->x_reshape()),
BCast::ToIndexArray<3>(bcast->x_bcast()),
- in1.shaped<Tin, 3>(bcast->y_reshape()),
+ in1.template shaped<Tin, 3>(bcast->y_reshape()),
BCast::ToIndexArray<3>(bcast->y_bcast()), error_ptr);
} else if (ndims == 4) {
functor::BinaryFunctor<Device, Functor, 4>().BCast(
eigen_device, out->shaped<Tout, 4>(bcast->result_shape()),
- in0.shaped<Tin, 4>(bcast->x_reshape()),
+ in0.template shaped<Tin, 4>(bcast->x_reshape()),
BCast::ToIndexArray<4>(bcast->x_bcast()),
- in1.shaped<Tin, 4>(bcast->y_reshape()),
+ in1.template shaped<Tin, 4>(bcast->y_reshape()),
BCast::ToIndexArray<4>(bcast->y_bcast()), error_ptr);
} else if (ndims == 5) {
functor::BinaryFunctor<Device, Functor, 5>().BCast(
eigen_device, out->shaped<Tout, 5>(bcast->result_shape()),
- in0.shaped<Tin, 5>(bcast->x_reshape()),
+ in0.template shaped<Tin, 5>(bcast->x_reshape()),
BCast::ToIndexArray<5>(bcast->x_bcast()),
- in1.shaped<Tin, 5>(bcast->y_reshape()),
+ in1.template shaped<Tin, 5>(bcast->y_reshape()),
BCast::ToIndexArray<5>(bcast->y_bcast()), error_ptr);
} else {
SetUnimplementedError(ctx);
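
For context on the first two hunks: these kernels are parameterized on an Eigen device type, so wiring in a new backend amounts to pulling in that backend's functor specializations and adding a device typedef (here Eigen::SyclDevice as SYCLDevice), all guarded by TENSORFLOW_USE_SYCL. Below is a minimal standalone sketch of that device-dispatch pattern; the names CPUTag, SYCLTag, and DeviceFunctor are illustrative stand-ins, not TensorFlow's.

#include <iostream>

struct CPUTag {};    // stand-in for Eigen::ThreadPoolDevice
struct SYCLTag {};   // stand-in for Eigen::SyclDevice

template <typename Device>
struct DeviceFunctor;  // primary template: one specialization per device

template <>
struct DeviceFunctor<CPUTag> {
  void operator()() const { std::cout << "CPU implementation\n"; }
};

template <>
struct DeviceFunctor<SYCLTag> {
  void operator()() const { std::cout << "SYCL implementation\n"; }
};

int main() {
  // Callers select the implementation purely via the Device type argument,
  // which is how BinaryFunctor<Device, Functor, NDIMS> picks its backend.
  DeviceFunctor<CPUTag>{}();
  DeviceFunctor<SYCLTag>{}();
}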
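The remaining hunk rewrites calls such as in0.flat&lt;Tin&gt;() as in0.template flat&lt;Tin&gt;(). flat, scalar, and shaped are member templates of Tensor, and in a templated context some compilers (the SYCL toolchain in particular) insist on the template disambiguator so that the following &lt; is parsed as the start of a template argument list rather than a less-than comparison; the keyword is always legal, and the standard requires it whenever the object's type is itself dependent. A minimal sketch of the rule follows; FakeTensor and ReadFlat are hypothetical names, not TensorFlow APIs.

#include <iostream>

struct FakeTensor {
  template <typename T>
  T flat() const { return T{42}; }  // stand-in for Tensor::flat<T>()
};

template <typename TensorT, typename Tin>
Tin ReadFlat(const TensorT& t) {
  // TensorT is a dependent type here, so without `.template` the compiler
  // must parse `t.flat < Tin` as a comparison and the call fails to compile.
  return t.template flat<Tin>();
}

int main() {
  std::cout << ReadFlat<FakeTensor, int>(FakeTensor{}) << "\n";  // prints 42
}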