diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2019-10-16 17:14:37 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2019-10-16 17:14:37 -0700 |
commit | 0d2a14ce11c85abdfc68ca37fc66e3cace088b24 (patch) | |
tree | b7838dd2a8dfe6a36f2c85bf9af58c2048df5e5b /unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h | |
parent | 02431cbe71eb036b1d6caa49c585db92a20b030f (diff) |
Cleanup Tensor block destination and materialized block storage allocation
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h | 34 |
1 files changed, 9 insertions, 25 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h b/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h index 99c74fc67..1104f02c7 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h @@ -238,22 +238,6 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device desc.dimensions()); } - // Check if we can reuse `desc` destination, or allocate new scratch buffer. - ScalarNoConst* materialized_output = - desc.template destination<ScalarNoConst, Layout>(); - bool materialized_in_output; - - if (materialized_output != NULL) { - desc.DropDestinationBuffer(); - materialized_in_output = true; - - } else { - const size_t materialized_output_size = desc.size() * sizeof(Scalar); - void* output_scratch_mem = scratch.allocate(materialized_output_size); - materialized_output = static_cast<ScalarNoConst*>(output_scratch_mem); - materialized_in_output = false; - } - static const bool IsColMajor = Layout == static_cast<int>(ColMajor); Index offset = desc.offset(); @@ -363,6 +347,10 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device typedef internal::StridedLinearBufferCopy<ScalarNoConst, Index> LinCopy; + // Prepare storage for the materialized padding result. + const typename TensorBlockV2::Storage block_storage = + TensorBlockV2::prepareStorage(desc, scratch); + // Iterate copying data from `m_impl.data()` to the output buffer. for (Index size = 0; size < output_size; size += output_inner_dim_size) { // Detect if we are in the padded region (exclude innermost dimension). @@ -376,7 +364,7 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device if (is_padded) { // Fill with padding value. LinCopy::template Run<LinCopy::Kind::FillLinear>( - typename LinCopy::Dst(output_offset, 1, materialized_output), + typename LinCopy::Dst(output_offset, 1, block_storage.data()), typename LinCopy::Src(0, 0, &m_paddingValue), output_inner_dim_size); @@ -385,7 +373,7 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device const Index out = output_offset; LinCopy::template Run<LinCopy::Kind::FillLinear>( - typename LinCopy::Dst(out, 1, materialized_output), + typename LinCopy::Dst(out, 1, block_storage.data()), typename LinCopy::Src(0, 0, &m_paddingValue), output_inner_pad_before_size); } @@ -397,7 +385,7 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device eigen_assert(output_inner_copy_size == 0 || m_impl.data() != NULL); LinCopy::template Run<LinCopy::Kind::Linear>( - typename LinCopy::Dst(out, 1, materialized_output), + typename LinCopy::Dst(out, 1, block_storage.data()), typename LinCopy::Src(in, 1, m_impl.data()), output_inner_copy_size); } @@ -407,7 +395,7 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device output_inner_copy_size; LinCopy::template Run<LinCopy::Kind::FillLinear>( - typename LinCopy::Dst(out, 1, materialized_output), + typename LinCopy::Dst(out, 1, block_storage.data()), typename LinCopy::Src(0, 0, &m_paddingValue), output_inner_pad_after_size); } @@ -431,11 +419,7 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device } } - return TensorBlockV2(materialized_in_output - ? internal::TensorBlockKind::kMaterializedInOutput - : internal::TensorBlockKind::kMaterializedInScratch, - materialized_output, - desc.dimensions()); + return block_storage.AsTensorMaterializedBlock(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EvaluatorPointerType data() const { return NULL; } |