about summary refs log tree commit diff homepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h 27
1 files changed, 20 insertions, 7 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h
index 3880e7ed3..b8c592543 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h
@@ -418,12 +418,22 @@ class TensorMaterializedBlock {
if (can_use_direct_access) {
const Scalar* block_start = data + desc.offset();
- return TensorMaterializedBlock(internal::TensorBlockKind::kView, block_start,
- desc.dimensions());
+ return TensorMaterializedBlock(internal::TensorBlockKind::kView,
+ block_start, desc.dimensions());
} else {
- void* mem = scratch.allocate(desc.size() * sizeof(Scalar));
- Scalar* block_buffer = static_cast<Scalar*>(mem);
+ // Try to reuse destination as an output block buffer.
+ Scalar* block_buffer = desc.template destination<Scalar, Layout>();
+ bool materialized_in_output;
+
+ if (block_buffer != NULL) {
+ materialized_in_output = true;
+
+ } else {
+ materialized_in_output = false;
+ void* mem = scratch.allocate(desc.size() * sizeof(Scalar));
+ block_buffer = static_cast<Scalar*>(mem);
+ }
typedef internal::TensorBlockIOV2<Scalar, IndexType, NumDims, Layout>
TensorBlockIO;
@@ -438,8 +448,11 @@ class TensorMaterializedBlock {
TensorBlockIO::Copy(dst, src);
- return TensorMaterializedBlock(internal::TensorBlockKind::kMaterializedInScratch,
- block_buffer, desc.dimensions());
+ return TensorMaterializedBlock(
+ materialized_in_output
+ ? internal::TensorBlockKind::kMaterializedInOutput
+ : internal::TensorBlockKind::kMaterializedInScratch,
+ block_buffer, desc.dimensions());
}
}
@@ -1141,7 +1154,7 @@ class TensorBlockAssignment {
it[idx].count = 0;
it[idx].size = target.dims[dim];
it[idx].output_stride = target.strides[dim];
- it[idx].output_span = it[i].output_stride * (it[i].size - 1);
+ it[idx].output_span = it[idx].output_stride * (it[idx].size - 1);
idx++;
}