diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2017-06-28 17:55:23 +0000 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2017-06-28 17:55:23 +0000 |
commit | 53725c10b80dabd2a536f66e854c50f892496946 (patch) | |
tree | bc49b7e7f3372c92a6c8ea07068ed54ef38ab29d /unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h | |
parent | b8e805497e446e7159f231238b4a8fd22fe70749 (diff) |
Merged in mehdi_goli/opencl/DataDependancy (pull request PR-10)
DataDependancy
* Wrapping data type to the pointer class for sycl in non-terminal nodes; not having that breaks Tensorflow Conv2d code.
* Applying Ronnan's Comments.
* Applying benoit's comments
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h | 18 |
1 files changed, 15 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h index 85dfc7a69..744cf7cc4 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h @@ -38,7 +38,7 @@ struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> > typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; static const int NumDimensions = XprTraits::NumDimensions; static const int Layout = XprTraits::Layout; - + typedef typename XprTraits::PointerType PointerType; enum { Flags = 0 }; @@ -89,6 +89,7 @@ struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> > typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; static const int NumDimensions = XprTraits::NumDimensions; static const int Layout = XprTraits::Layout; + typedef typename ::Eigen::internal::TypeConversion<Scalar, typename XprTraits::PointerType>::type PointerType; }; template<typename UnaryOp, typename XprType> @@ -161,7 +162,11 @@ struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> > typedef typename remove_reference<RhsNested>::type _RhsNested; static const int NumDimensions = XprTraits::NumDimensions; static const int Layout = XprTraits::Layout; - + typedef typename ::Eigen::internal::TypeConversion<Scalar, + typename conditional<::Eigen::internal::Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val, + typename traits<LhsXprType>::PointerType, + typename traits<RhsXprType>::PointerType>::type + >::type PointerType; enum { Flags = 0 }; @@ -238,7 +243,11 @@ struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprT typedef typename remove_reference<Arg3Nested>::type _Arg3Nested; static const int NumDimensions = XprTraits::NumDimensions; static const int Layout = XprTraits::Layout; - + typedef typename ::Eigen::internal::TypeConversion<Scalar, + typename conditional<::Eigen::internal::Pointer_type_promotion<typename Arg2XprType::Scalar, Scalar>::val, + typename traits<Arg2XprType>::PointerType, + typename traits<Arg3XprType>::PointerType>::type + >::type PointerType; enum { Flags = 0 }; @@ -314,6 +323,9 @@ struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > typedef typename ElseXprType::Nested ElseNested; static const int NumDimensions = XprTraits::NumDimensions; static const int Layout = XprTraits::Layout; + typedef typename conditional<::Eigen::internal::Pointer_type_promotion<typename ThenXprType::Scalar, Scalar>::val, + typename traits<ThenXprType>::PointerType, + typename traits<ElseXprType>::PointerType>::type PointerType; }; template<typename IfXprType, typename ThenXprType, typename ElseXprType> |