aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2019-10-16 17:14:37 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2019-10-16 17:14:37 -0700
commit0d2a14ce11c85abdfc68ca37fc66e3cace088b24 (patch)
treeb7838dd2a8dfe6a36f2c85bf9af58c2048df5e5b /unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
parent02431cbe71eb036b1d6caa49c585db92a20b030f (diff)
Cleanup Tensor block destination and materialized block storage allocation
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h31
1 files changed, 9 insertions, 22 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
index 9a1fc9217..58164c13a 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
@@ -890,24 +890,14 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
return emptyBlock();
}
- // Check if we can reuse `desc` destination, or allocate new scratch buffer.
- ScalarNoConst* materialized_output =
- desc.template destination<ScalarNoConst, Layout>();
- bool materialized_in_output;
+ // Prepare storage for the materialized broadcasting result.
+ const typename TensorBlockV2::Storage block_storage =
+ TensorBlockV2::prepareStorage(desc, scratch);
+ ScalarNoConst* materialized_output = block_storage.data();
- if (materialized_output != NULL) {
- desc.DropDestinationBuffer();
- materialized_in_output = true;
-
- } else {
- materialized_in_output = false;
- const size_t materialized_output_size = desc.size() * sizeof(Scalar);
- void* output_scratch_mem = scratch.allocate(materialized_output_size);
- materialized_output = static_cast<ScalarNoConst*>(output_scratch_mem);
- }
-
- ScalarNoConst* materialized_input = NULL;
+ // We potentially will need to materialize input blocks.
size_t materialized_input_size = 0;
+ ScalarNoConst* materialized_input = NULL;
// Initialize block broadcating iterator state for outer dimensions (outer
// with regard to bcast dimension). Dimension in this array are always in
@@ -951,11 +941,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
}
}
- return TensorBlockV2(
- materialized_in_output
- ? internal::TensorBlockKind::kMaterializedInOutput
- : internal::TensorBlockKind::kMaterializedInScratch,
- materialized_output, desc.dimensions());
+ return block_storage.AsTensorMaterializedBlock();
}
EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
@@ -1019,7 +1005,8 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
Index output_span;
};
- BlockBroadcastingParams blockBroadcastingParams(TensorBlockDesc& desc) const {
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlockBroadcastingParams
+ blockBroadcastingParams(TensorBlockDesc& desc) const {
BlockBroadcastingParams params;
params.input_dims = Dimensions(m_impl.dimensions());