diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h | 27 |
1 files changed, 20 insertions, 7 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h index 3880e7ed3..b8c592543 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h @@ -418,12 +418,22 @@ class TensorMaterializedBlock { if (can_use_direct_access) { const Scalar* block_start = data + desc.offset(); - return TensorMaterializedBlock(internal::TensorBlockKind::kView, block_start, - desc.dimensions()); + return TensorMaterializedBlock(internal::TensorBlockKind::kView, + block_start, desc.dimensions()); } else { - void* mem = scratch.allocate(desc.size() * sizeof(Scalar)); - Scalar* block_buffer = static_cast<Scalar*>(mem); + // Try to reuse destination as an output block buffer. + Scalar* block_buffer = desc.template destination<Scalar, Layout>(); + bool materialized_in_output; + + if (block_buffer != NULL) { + materialized_in_output = true; + + } else { + materialized_in_output = false; + void* mem = scratch.allocate(desc.size() * sizeof(Scalar)); + block_buffer = static_cast<Scalar*>(mem); + } typedef internal::TensorBlockIOV2<Scalar, IndexType, NumDims, Layout> TensorBlockIO; @@ -438,8 +448,11 @@ class TensorMaterializedBlock { TensorBlockIO::Copy(dst, src); - return TensorMaterializedBlock(internal::TensorBlockKind::kMaterializedInScratch, - block_buffer, desc.dimensions()); + return TensorMaterializedBlock( + materialized_in_output + ? internal::TensorBlockKind::kMaterializedInOutput + : internal::TensorBlockKind::kMaterializedInScratch, + block_buffer, desc.dimensions()); } } @@ -1141,7 +1154,7 @@ class TensorBlockAssignment { it[idx].count = 0; it[idx].size = target.dims[dim]; it[idx].output_stride = target.strides[dim]; - it[idx].output_span = it[i].output_stride * (it[i].size - 1); + it[idx].output_span = it[idx].output_stride * (it[idx].size - 1); idx++; } |