diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2019-10-16 17:14:37 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2019-10-16 17:14:37 -0700 |
commit | 0d2a14ce11c85abdfc68ca37fc66e3cace088b24 (patch) | |
tree | b7838dd2a8dfe6a36f2c85bf9af58c2048df5e5b /unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h | |
parent | 02431cbe71eb036b1d6caa49c585db92a20b030f (diff) |
Cleanup Tensor block destination and materialized block storage allocation
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h | 58 |
1 files changed, 6 insertions, 52 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h b/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h index bb9908b62..df4cd1eb3 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h @@ -351,66 +351,20 @@ struct TensorEvaluator<const TensorShufflingOp<Shuffle, ArgType>, Device> typedef typename TensorBlockIO::Dst TensorBlockIODst; typedef typename TensorBlockIO::Src TensorBlockIOSrc; - ScalarNoConst* block_buffer = NULL; - typename TensorBlockIO::Dimensions block_strides; - - bool materialized_in_output = false; - bool has_valid_materialized_expr = true; - - if (desc.HasDestinationBuffer()) { - // Check if we can reuse destination buffer for block materialization. - const typename TensorBlockDesc::DestinationBuffer& destination_buffer = - desc.GetDestinationBuffer(); - - const bool dims_match = dimensions_match( - desc.dimensions(), destination_buffer.template dimensions<Scalar>()); - - const bool strides_match = - dimensions_match(internal::strides<Layout>(desc.dimensions()), - destination_buffer.template strides<Scalar>()); - - if (dims_match && strides_match) { - // Destination buffer fits the block contiguously. - materialized_in_output = true; - has_valid_materialized_expr = true; - block_buffer = destination_buffer.template data<ScalarNoConst>(); - block_strides = internal::strides<Layout>(desc.dimensions()); - eigen_assert(block_buffer != NULL); - - } else if (dims_match && root_of_expr_ast) { - // Destination buffer has strides not matching the block strides, but - // for the root of the expression tree it's safe to materialize anyway. - materialized_in_output = true; - has_valid_materialized_expr = false; - block_buffer = destination_buffer.template data<ScalarNoConst>(); - block_strides = destination_buffer.template strides<ScalarNoConst>(); - eigen_assert(block_buffer != NULL); - } - - if (materialized_in_output) desc.DropDestinationBuffer(); - } - - // If we were not able to reuse destination buffer, allocate temporary - // buffer for block evaluation using scratch allocator. - if (!materialized_in_output) { - void* mem = scratch.allocate(desc.size() * sizeof(ScalarNoConst)); - block_buffer = static_cast<ScalarNoConst*>(mem); - block_strides = internal::strides<Layout>(desc.dimensions()); - } + const typename TensorBlockV2::Storage block_storage = + TensorBlockV2::prepareStorage( + desc, scratch, /*allow_strided_storage=*/root_of_expr_ast); typename TensorBlockIO::Dimensions input_strides(m_unshuffledInputStrides); TensorBlockIOSrc src(input_strides, m_impl.data(), srcCoeff(desc.offset())); - TensorBlockIODst dst(desc.dimensions(), block_strides, block_buffer); + TensorBlockIODst dst(block_storage.dimensions(), block_storage.strides(), + block_storage.data()); typename TensorBlockIO::DimensionsMap dst_to_src_dim_map(m_shuffle); TensorBlockIO::Copy(dst, src, dst_to_src_dim_map); - return TensorBlockV2( - materialized_in_output - ? internal::TensorBlockKind::kMaterializedInOutput - : internal::TensorBlockKind::kMaterializedInScratch, - block_buffer, desc.dimensions(), has_valid_materialized_expr); + return block_storage.AsTensorMaterializedBlock(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const { |