aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cwise_ops_sycl_common.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops_sycl_common.h')
-rw-r--r--tensorflow/core/kernels/cwise_ops_sycl_common.h8
1 files changed, 3 insertions, 5 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_sycl_common.h b/tensorflow/core/kernels/cwise_ops_sycl_common.h
index 4c22cc4855..3fcf0759d4 100644
--- a/tensorflow/core/kernels/cwise_ops_sycl_common.h
+++ b/tensorflow/core/kernels/cwise_ops_sycl_common.h
@@ -21,12 +21,10 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_SYCL_COMMON_H_
#define EIGEN_USE_SYCL
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
-
-#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cwise_ops.h"
-#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -62,14 +60,14 @@ struct BinaryFunctor<SYCLDevice, Functor, NDIMS, has_errors> {
void operator()(const SYCLDevice& d, typename Functor::tout_type out,
typename Functor::tin_type in0,
typename Functor::tin_type in1, bool* error) {
- To32Bit(out).device(d) = To32Bit(in0).binaryExpr(in1, typename Functor::func());
+ To32Bit(out).device(d) = To32Bit(in0).binaryExpr(To32Bit(in1), typename Functor::func());
}
void Left(const SYCLDevice& d, typename Functor::tout_type out,
typename Functor::tscalar_type scalar,
typename Functor::tin_type in, bool* error) {
typedef typename Functor::func Binary;
- constexpr int NumDims = Functor::tin_type::NumDimensions;
+ constexpr int NumDims = Functor::tin_type::NumDimensions;
typedef typename Functor::tin_type::Scalar T;
typedef typename Functor::tin_type::Index Index;
Eigen::array<Index, NumDims> scalar_dim = GenerateArrayOfOnes<Index, NumDims>();