diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h | 26 |
1 files changed, 6 insertions, 20 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h b/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h index a51c88540..ae3ab5f81 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h @@ -370,21 +370,6 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device static const Index inner_dim_idx = isColMajor ? 0 : NumDims - 1; const bool inner_dim_reversed = m_reverse[inner_dim_idx]; - // Try to reuse destination as an output block buffer. - CoeffReturnType* block_buffer = - desc.template destination<CoeffReturnType, Layout>(); - bool materialized_in_output; - - if (block_buffer != NULL) { - desc.DropDestinationBuffer(); - materialized_in_output = true; - - } else { - materialized_in_output = false; - void* mem = scratch.allocate(desc.size() * sizeof(CoeffReturnType)); - block_buffer = static_cast<CoeffReturnType*>(mem); - } - // Offset in the output block. Index block_offset = 0; @@ -438,6 +423,11 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device const Index inner_dim_size = it[effective_inner_dim].size; + // Prepare storage for the materialized reverse result. + const typename TensorBlockV2::Storage block_storage = + TensorBlockV2::prepareStorage(desc, scratch); + CoeffReturnType* block_buffer = block_storage.data(); + while (it[NumDims - 1].count < it[NumDims - 1].size) { // Copy inner-most dimension data from reversed location in input. Index dst = block_offset; @@ -475,11 +465,7 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device } } - return TensorBlockV2( - materialized_in_output - ? internal::TensorBlockKind::kMaterializedInOutput - : internal::TensorBlockKind::kMaterializedInScratch, - block_buffer, desc.dimensions()); + return block_storage.AsTensorMaterializedBlock(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const { |