aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2019-10-02 12:44:06 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2019-10-02 12:44:06 -0700
commit60ae24ee1a6c16114de456d77fcfba6f5a1160ca (patch)
tree7b9d5463018055571a5050ca31a8d3df12a3e6fc /unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
parent6e40454a6e6cc57c07c7340148657c985ca6c928 (diff)
Add block evaluation to TensorReshaping/TensorCasting/TensorPadding/TensorSelect
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h55
1 files changed, 44 insertions, 11 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
index a8160e17e..cc3e67677 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
@@ -294,23 +294,45 @@ struct TensorEvaluator<const TensorConversionOp<TargetType, ArgType>, Device>
typedef typename Storage::Type EvaluatorPointerType;
enum {
- IsAligned = false,
- PacketAccess =
+ IsAligned = false,
+ PacketAccess =
#ifndef EIGEN_USE_SYCL
- true,
+ true,
#else
- TensorEvaluator<ArgType, Device>::PacketAccess &
- internal::type_casting_traits<SrcType, TargetType>::VectorizedCast,
+ TensorEvaluator<ArgType, Device>::PacketAccess &
+ internal::type_casting_traits<SrcType, TargetType>::VectorizedCast,
#endif
- BlockAccess = false,
- BlockAccessV2 = false,
- PreferBlockAccess = false,
- Layout = TensorEvaluator<ArgType, Device>::Layout,
- RawAccess = false
+ BlockAccess = false,
+ BlockAccessV2 = TensorEvaluator<ArgType, Device>::BlockAccessV2,
+ PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
+ Layout = TensorEvaluator<ArgType, Device>::Layout,
+ RawAccess = false
};
+ static const int NumDims = internal::array_size<Dimensions>::value;
+
//===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
- typedef internal::TensorBlockNotImplemented TensorBlockV2;
+ typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
+ typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
+
+ typedef typename TensorEvaluator<const ArgType, Device>::TensorBlockV2
+ ArgTensorBlock;
+
+ struct TensorConversionOpBlockFactory {
+ template <typename ArgXprType>
+ struct XprType {
+ typedef TensorConversionOp<TargetType, const ArgXprType> type;
+ };
+
+ template <typename ArgXprType>
+ typename XprType<ArgXprType>::type expr(const ArgXprType& expr) const {
+ return typename XprType<ArgXprType>::type(expr);
+ }
+ };
+
+ typedef internal::TensorUnaryExprBlock<TensorConversionOpBlockFactory,
+ ArgTensorBlock>
+ TensorBlockV2;
//===--------------------------------------------------------------------===//
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
@@ -376,6 +398,17 @@ struct TensorEvaluator<const TensorConversionOp<TargetType, ArgType>, Device>
}
}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void getResourceRequirements(
+ std::vector<internal::TensorOpResourceRequirements>* resources) const {
+ m_impl.getResourceRequirements(resources);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2
+ blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch) const {
+ return TensorBlockV2(m_impl.blockV2(desc, scratch),
+ TensorConversionOpBlockFactory());
+ }
+
EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
/// required by sycl in order to extract the sycl accessor