about | summary | refs | log | tree | commit | diff | homepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h
diff options
context:
space:
mode:
author Eugene Zhulenev <ezhulenev@google.com> 2019-10-16 17:14:37 -0700
committer Eugene Zhulenev <ezhulenev@google.com> 2019-10-16 17:14:37 -0700
commit 0d2a14ce11c85abdfc68ca37fc66e3cace088b24 (patch)
tree b7838dd2a8dfe6a36f2c85bf9af58c2048df5e5b /unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h
parent 02431cbe71eb036b1d6caa49c585db92a20b030f (diff)
Cleanup Tensor block destination and materialized block storage allocation
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h 58
1 files changed, 6 insertions, 52 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h b/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h
index bb9908b62..df4cd1eb3 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h
@@ -351,66 +351,20 @@ struct TensorEvaluator<const TensorShufflingOp<Shuffle, ArgType>, Device>
typedef typename TensorBlockIO::Dst TensorBlockIODst;
typedef typename TensorBlockIO::Src TensorBlockIOSrc;
- ScalarNoConst* block_buffer = NULL;
- typename TensorBlockIO::Dimensions block_strides;
-
- bool materialized_in_output = false;
- bool has_valid_materialized_expr = true;
-
- if (desc.HasDestinationBuffer()) {
- // Check if we can reuse destination buffer for block materialization.
- const typename TensorBlockDesc::DestinationBuffer& destination_buffer =
- desc.GetDestinationBuffer();
-
- const bool dims_match = dimensions_match(
- desc.dimensions(), destination_buffer.template dimensions<Scalar>());
-
- const bool strides_match =
- dimensions_match(internal::strides<Layout>(desc.dimensions()),
- destination_buffer.template strides<Scalar>());
-
- if (dims_match && strides_match) {
- // Destination buffer fits the block contiguously.
- materialized_in_output = true;
- has_valid_materialized_expr = true;
- block_buffer = destination_buffer.template data<ScalarNoConst>();
- block_strides = internal::strides<Layout>(desc.dimensions());
- eigen_assert(block_buffer != NULL);
-
- } else if (dims_match && root_of_expr_ast) {
- // Destination buffer has strides not matching the block strides, but
- // for the root of the expression tree it's safe to materialize anyway.
- materialized_in_output = true;
- has_valid_materialized_expr = false;
- block_buffer = destination_buffer.template data<ScalarNoConst>();
- block_strides = destination_buffer.template strides<ScalarNoConst>();
- eigen_assert(block_buffer != NULL);
- }
-
- if (materialized_in_output) desc.DropDestinationBuffer();
- }
-
- // If we were not able to reuse destination buffer, allocate temporary
- // buffer for block evaluation using scratch allocator.
- if (!materialized_in_output) {
- void* mem = scratch.allocate(desc.size() * sizeof(ScalarNoConst));
- block_buffer = static_cast<ScalarNoConst*>(mem);
- block_strides = internal::strides<Layout>(desc.dimensions());
- }
+ const typename TensorBlockV2::Storage block_storage =
+ TensorBlockV2::prepareStorage(
+ desc, scratch, /*allow_strided_storage=*/root_of_expr_ast);
typename TensorBlockIO::Dimensions input_strides(m_unshuffledInputStrides);
TensorBlockIOSrc src(input_strides, m_impl.data(), srcCoeff(desc.offset()));
- TensorBlockIODst dst(desc.dimensions(), block_strides, block_buffer);
+ TensorBlockIODst dst(block_storage.dimensions(), block_storage.strides(),
+ block_storage.data());
typename TensorBlockIO::DimensionsMap dst_to_src_dim_map(m_shuffle);
TensorBlockIO::Copy(dst, src, dst_to_src_dim_map);
- return TensorBlockV2(
- materialized_in_output
- ? internal::TensorBlockKind::kMaterializedInOutput
- : internal::TensorBlockKind::kMaterializedInScratch,
- block_buffer, desc.dimensions(), has_valid_materialized_expr);
+ return block_storage.AsTensorMaterializedBlock();
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {