diff options
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops_sycl_common.h')
-rw-r--r-- | tensorflow/core/kernels/cwise_ops_sycl_common.h | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_sycl_common.h b/tensorflow/core/kernels/cwise_ops_sycl_common.h index 4c22cc4855..3fcf0759d4 100644 --- a/tensorflow/core/kernels/cwise_ops_sycl_common.h +++ b/tensorflow/core/kernels/cwise_ops_sycl_common.h @@ -21,12 +21,10 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_CWISE_OPS_SYCL_COMMON_H_ #define EIGEN_USE_SYCL +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/register_types.h" - -#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/cwise_ops.h" -#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -62,14 +60,14 @@ struct BinaryFunctor<SYCLDevice, Functor, NDIMS, has_errors> { void operator()(const SYCLDevice& d, typename Functor::tout_type out, typename Functor::tin_type in0, typename Functor::tin_type in1, bool* error) { - To32Bit(out).device(d) = To32Bit(in0).binaryExpr(in1, typename Functor::func()); + To32Bit(out).device(d) = To32Bit(in0).binaryExpr(To32Bit(in1), typename Functor::func()); } void Left(const SYCLDevice& d, typename Functor::tout_type out, typename Functor::tscalar_type scalar, typename Functor::tin_type in, bool* error) { typedef typename Functor::func Binary; - constexpr int NumDims = Functor::tin_type::NumDimensions; + constexpr int NumDims = Functor::tin_type::NumDimensions; typedef typename Functor::tin_type::Scalar T; typedef typename Functor::tin_type::Index Index; Eigen::array<Index, NumDims> scalar_dim = GenerateArrayOfOnes<Index, NumDims>(); |