diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2019-10-09 12:45:31 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2019-10-09 12:45:31 -0700 |
commit | 33e174613987cfc6c83576dc0fe8086c7a5d1b1f (patch) | |
tree | 4f4c62eab5c0feca0f233624c9c1fc571c491781 /unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | |
parent | f0a4642baba70a64128964d96c4ede012614925e (diff) |
Block evaluation for TensorChipping + fixed bugs in TensorPadding and TensorSlicing
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | 27 |
1 files changed, 16 insertions, 11 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h index b1d668744..b77d8fe84 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h @@ -53,18 +53,22 @@ struct TensorEvaluator RawAccess = true }; - typedef typename internal::TensorBlock< - typename internal::remove_const<Scalar>::type, Index, NumCoords, Layout> + typedef typename internal::remove_const<Scalar>::type ScalarNoConst; + + typedef typename internal::TensorBlock<ScalarNoConst, Index, NumCoords, Layout> TensorBlock; - typedef typename internal::TensorBlockReader< - typename internal::remove_const<Scalar>::type, Index, NumCoords, Layout> + typedef typename internal::TensorBlockReader<ScalarNoConst, Index, NumCoords, Layout> TensorBlockReader; - typedef typename internal::TensorBlockWriter< - typename internal::remove_const<Scalar>::type, Index, NumCoords, Layout> + typedef typename internal::TensorBlockWriter<ScalarNoConst, Index, NumCoords, Layout> TensorBlockWriter; //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===// typedef internal::TensorBlockDescriptor<NumCoords, Index> TensorBlockDesc; + typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch; + + typedef typename internal::TensorMaterializedBlock<ScalarNoConst, NumCoords, + Layout, Index> + TensorBlockV2; //===--------------------------------------------------------------------===// EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device) @@ -161,6 +165,12 @@ struct TensorEvaluator TensorBlockReader::Run(block, m_data); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2 + blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch) const { + assert(m_data != NULL); + return TensorBlockV2::materialize(m_data, m_dims, desc, scratch); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void writeBlock( const TensorBlock& block) { assert(m_data != NULL); @@ -269,11 +279,6 @@ struct TensorEvaluator<const Derived, Device> typedef internal::TensorBlockDescriptor<NumCoords, Index> TensorBlockDesc; typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch; - typedef internal::TensorBlockIOV2<ScalarNoConst, Index, NumCoords, Layout> - TensorBlockIO; - typedef typename TensorBlockIO::Dst TensorBlockIODst; - typedef typename TensorBlockIO::Src TensorBlockIOSrc; - typedef typename internal::TensorMaterializedBlock<ScalarNoConst, NumCoords, Layout, Index> TensorBlockV2; |