aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2019-10-16 13:26:28 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2019-10-16 13:26:28 -0700
commit02431cbe71eb036b1d6caa49c585db92a20b030f (patch)
treeae380ae099cf016ec1554befbab74111279e7857 /unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
parentd380c23b2cc0b02e10819e779c73cde2c62603b2 (diff)
TensorBroadcasting support for random/uniform blocks
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h616
1 files changed, 358 insertions, 258 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
index cc0a00e8d..9a1fc9217 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
@@ -884,60 +884,187 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2
blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch,
bool /*root_of_expr_ast*/ = false) const {
- static const bool
- is_col_major = static_cast<int>(Layout) == static_cast<int>(ColMajor);
+ BlockBroadcastingParams params = blockBroadcastingParams(desc);
- // Return a block with a single scalar.
- if (NumDims <= 0) return scalarBlock(scratch);
+ if (params.inner_dim_size == 0 || params.bcast_dim_size == 0) {
+ return emptyBlock();
+ }
- // Because we only support kSkewedInnerDims blocking, block size should be
- // equal to m_dimensions for inner dims, a smaller than m_dimensions[i] size
- // for the first outer dim, and 1 for other outer dims. This is guaranteed
- // by MergeResourceRequirements() in TensorBlock.h.
- const Dimensions& output_dims = desc.dimensions();
- const Dimensions output_strides = internal::strides<Layout>(output_dims);
+ // Check if we can reuse `desc` destination, or allocate new scratch buffer.
+ ScalarNoConst* materialized_output =
+ desc.template destination<ScalarNoConst, Layout>();
+ bool materialized_in_output;
- // Find where outer dims start.
- int outer_dim_start = 0;
- Index outer_dim_size = 1;
- Index inner_dim_size = 1;
+ if (materialized_output != NULL) {
+ desc.DropDestinationBuffer();
+ materialized_in_output = true;
- for (int i = 0; i < NumDims; ++i) {
- const int dim = is_col_major ? i : NumDims - i - 1;
+ } else {
+ materialized_in_output = false;
+ const size_t materialized_output_size = desc.size() * sizeof(Scalar);
+ void* output_scratch_mem = scratch.allocate(materialized_output_size);
+ materialized_output = static_cast<ScalarNoConst*>(output_scratch_mem);
+ }
- if (i > outer_dim_start) {
- eigen_assert(output_dims[dim] == 1);
- } else if (output_dims[dim] != m_dimensions[dim]) {
- eigen_assert(output_dims[dim] < m_dimensions[dim]);
- outer_dim_size = output_dims[dim];
- } else {
- inner_dim_size *= output_dims[dim];
- ++outer_dim_start;
- }
+ ScalarNoConst* materialized_input = NULL;
+ size_t materialized_input_size = 0;
+
+ // Initialize block broadcasting iterator state for outer dimensions (outer
+ // with regard to bcast dimension). Dimensions in this array are always in
+ // inner_most -> outer_most order (col major layout).
+ array<BlockBroadcastingIteratorState, NumDims> it;
+ int idx = 0;
+
+ for (int i = params.inner_dim_count + 1; i < NumDims; ++i) {
+ const Index dim = IsColMajor ? i : NumDims - 1 - i;
+ it[idx].size = params.output_dims[dim];
+ it[idx].count = 0;
+ it[idx].output_stride = m_outputStrides[dim];
+ it[idx].output_span = it[idx].output_stride * (it[idx].size - 1);
+ idx++;
}
- if (inner_dim_size == 0 || outer_dim_size == 0) {
- return emptyBlock();
+ // Write output into the beginning of `materialized_output`.
+ Index output_offset = 0;
+
+ // We will fill output block by broadcasting along the bcast dim, and
+ // iterating over outer dimension.
+ const Index output_size = NumDims == 0 ? 1 : params.output_dims.TotalSize();
+
+ for (Index num_output_coeffs = 0; num_output_coeffs < output_size;) {
+ ScalarNoConst* bcast_output = materialized_output + num_output_coeffs;
+ Index bcast_offset = desc.offset() + output_offset;
+
+ // Broadcast along the bcast dimension.
+ num_output_coeffs += BroadcastBlockAlongBcastDim(
+ params, bcast_offset, scratch, bcast_output, &materialized_input,
+ &materialized_input_size);
+
+ // Switch to the next outer dimension.
+ for (int j = 0; j < idx; ++j) {
+ if (++it[j].count < it[j].size) {
+ output_offset += it[j].output_stride;
+ break;
+ }
+ it[j].count = 0;
+ output_offset -= it[j].output_span;
+ }
}
- const Dimensions& input_dims = Dimensions(m_impl.dimensions());
+ return TensorBlockV2(
+ materialized_in_output
+ ? internal::TensorBlockKind::kMaterializedInOutput
+ : internal::TensorBlockKind::kMaterializedInScratch,
+ materialized_output, desc.dimensions());
+ }
- // Pre-fill input_block_sizes, broadcast_block_sizes,
- // broadcast_block_strides, and broadcast_tensor_strides. Later on we will
- // only modify the outer_dim_start-th dimension on these arrays.
+ EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
- // Calculate the input block size for looking into the input.
+ const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }
+
+ Broadcast functor() const { return m_broadcast; }
+#ifdef EIGEN_USE_SYCL
+ // binding placeholder accessors to a command group handler for SYCL
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(
+ cl::sycl::handler& cgh) const {
+ m_impl.bind(cgh);
+ }
+#endif
+ private:
+ static const bool IsColMajor =
+ static_cast<int>(Layout) == static_cast<int>(ColMajor);
+
+ // We will build a general case block broadcasting on top of broadcasting
+ // primitive that will do broadcasting only for the inner dimension(s) along
+ // the first dimension smaller than the input size (it's called `bcast_dim`).
+ //
+ // Example:
+ // dim: 0 1 2 (ColMajor)
+ // input size: [9, 3, 6]
+ // block size: [9, 2, 6]
+ //
+ // We will compute broadcasted block by iterating over the outer dimensions
+ // before `bcast_dim` (only dimension `2` in this example) and computing
+ // broadcasts along the `bcast_dim` (dimension `1` in this example).
+
+ // BlockBroadcastingParams holds precomputed parameters for broadcasting a
+ // single block along the broadcasting dimension. Sizes and strides along the
+ // `bcast_dim` might be invalid, they will be adjusted later in
+ // `BroadcastBlockAlongBcastDim`.
+ struct BlockBroadcastingParams {
+ Dimensions input_dims; // input expression dimensions
+ Dimensions output_dims; // output block sizes
+ Dimensions output_strides; // output block strides
+
+ int inner_dim_count; // count inner dimensions matching in size
+ int bcast_dim; // broadcasting dimension index
+ Index bcast_dim_size; // broadcasting dimension size
+ Index inner_dim_size; // inner dimensions size
+
+ // Block sizes and strides for the input block where all dimensions before
+ // `bcast_dim` are equal to `1`.
Dimensions input_block_sizes;
- for (int i = 0; i < outer_dim_start; ++i) {
- const int dim = is_col_major ? i : NumDims -i - 1;
- input_block_sizes[dim] = input_dims[dim];
+ Dimensions input_block_strides;
+
+ // Block sizes and strides for blocks with extra dimensions and strides `0`.
+ BroadcastDimensions bcast_block_sizes;
+ BroadcastDimensions bcast_block_strides;
+ BroadcastDimensions bcast_input_strides;
+ };
+
+ struct BlockBroadcastingIteratorState {
+ Index size;
+ Index count;
+ Index output_stride;
+ Index output_span;
+ };
+
+ BlockBroadcastingParams blockBroadcastingParams(TensorBlockDesc& desc) const {
+ BlockBroadcastingParams params;
+
+ params.input_dims = Dimensions(m_impl.dimensions());
+
+ // Output block sizes and strides.
+ params.output_dims = desc.dimensions();
+ params.output_strides = internal::strides<Layout>(params.output_dims);
+
+ // Find the broadcasting dimension (first dimension with output size smaller
+ // than the input size).
+ params.bcast_dim = 0;
+ params.bcast_dim_size = 1;
+ params.inner_dim_size = 1;
+
+ // Count the number of inner dimensions that have the same size in the block
+ // and in the broadcast expression.
+ params.inner_dim_count = 0;
+
+ for (int i = 0; i < NumDims; ++i) {
+ const int dim = IsColMajor ? i : NumDims - i - 1;
+
+ if (params.output_dims[dim] == m_dimensions[dim]) {
+ params.inner_dim_size *= params.output_dims[dim];
+ ++params.inner_dim_count;
+ continue;
+ }
+
+ // First non-matching dimension is the broadcasting dimension.
+ eigen_assert(params.output_dims[dim] < m_dimensions[dim]);
+ params.bcast_dim = dim;
+ params.bcast_dim_size = params.output_dims[dim];
+ break;
+ }
+
+ // Calculate the input block size for looking into the input.
+ for (int i = 0; i < params.inner_dim_count; ++i) {
+ const int dim = IsColMajor ? i : NumDims - i - 1;
+ params.input_block_sizes[dim] = params.input_dims[dim];
}
- for (int i = outer_dim_start; i < NumDims; ++i) {
- const int dim = is_col_major ? i : NumDims -i - 1;
- input_block_sizes[dim] = 1;
+ for (int i = params.inner_dim_count; i < NumDims; ++i) {
+ const int dim = IsColMajor ? i : NumDims - i - 1;
+ params.input_block_sizes[dim] = 1;
}
- Dimensions input_block_strides =
- internal::strides<Layout>(input_block_sizes);
+ params.input_block_strides =
+ internal::strides<Layout>(params.input_block_sizes);
// Broadcast with the 0-stride trick: Create 1 extra dim for each
// broadcast, set the input stride to 0.
@@ -957,268 +1084,241 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
// input_block_strides[1], 0,
// ...].
//
- BroadcastDimensions bcast_block_sizes;
- BroadcastDimensions bcast_block_strides;
- BroadcastDimensions bcast_input_strides;
-
- for (int i = 0; i < outer_dim_start; ++i) {
- const int dim = is_col_major ? i : NumDims - i - 1;
-
- const int copy_dim = is_col_major ? 2 * i : 2 * NumDims - 2 * i - 1;
- const int broadcast_dim = is_col_major ? copy_dim + 1 : copy_dim - 1;
-
- bcast_block_sizes[copy_dim] = input_dims[dim];
- bcast_block_sizes[broadcast_dim] = m_broadcast[dim];
- bcast_block_strides[copy_dim] = output_strides[dim];
- bcast_block_strides[broadcast_dim] =
- output_strides[dim] * input_dims[dim];
- bcast_input_strides[copy_dim] = input_block_strides[dim];
- bcast_input_strides[broadcast_dim] = 0;
+ for (int i = 0; i < params.inner_dim_count; ++i) {
+ const int dim = IsColMajor ? i : NumDims - i - 1;
+
+ const int copy_dim = IsColMajor ? 2 * i : 2 * NumDims - 2 * i - 1;
+ const int broadcast_dim = IsColMajor ? copy_dim + 1 : copy_dim - 1;
+
+ params.bcast_block_sizes[copy_dim] = params.input_dims[dim];
+ params.bcast_block_sizes[broadcast_dim] = m_broadcast[dim];
+ params.bcast_block_strides[copy_dim] = params.output_strides[dim];
+ params.bcast_block_strides[broadcast_dim] =
+ params.output_strides[dim] * params.input_dims[dim];
+ params.bcast_input_strides[copy_dim] = params.input_block_strides[dim];
+ params.bcast_input_strides[broadcast_dim] = 0;
}
- for (int i = 2 * outer_dim_start; i < 2 * NumDims; ++i) {
- const int dim = is_col_major ? i : 2 * NumDims - i - 1;
- bcast_block_sizes[dim] = 1;
- bcast_block_strides[dim] = 0;
- bcast_input_strides[dim] = 0;
+
+ for (int i = 2 * params.inner_dim_count; i < 2 * NumDims; ++i) {
+ const int dim = IsColMajor ? i : 2 * NumDims - i - 1;
+ params.bcast_block_sizes[dim] = 1;
+ params.bcast_block_strides[dim] = 0;
+ params.bcast_input_strides[dim] = 0;
}
- const int outer_dim =
- is_col_major ? outer_dim_start : NumDims - outer_dim_start - 1;
+ return params;
+ }
- // Check if we can reuse `desc` destination, or allocate new scratch buffer.
- ScalarNoConst* materialized_output =
- desc.template destination<ScalarNoConst, Layout>();
- bool materialized_in_output;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void BroadcastBlock(
+ const Dimensions& input_block_sizes,
+ const BroadcastDimensions& broadcast_block_sizes,
+ const BroadcastDimensions& broadcast_block_strides,
+ const BroadcastDimensions& broadcast_tensor_strides, Index offset,
+ TensorBlock* output_block) const {
+ TensorBlock input_view_block(
+ static_cast<int>(Layout) == static_cast<int>(ColMajor)
+ ? indexColMajor(output_block->first_coeff_index() + offset)
+ : indexRowMajor(output_block->first_coeff_index() + offset),
+ input_block_sizes, Dimensions(m_inputStrides),
+ Dimensions(m_inputStrides), NULL);
- if (materialized_output != NULL) {
- desc.DropDestinationBuffer();
- materialized_in_output = true;
+ internal::TensorBlockView<ArgType, Device> input_block(m_device, m_impl,
+ input_view_block);
+ BroadcastTensorBlock broadcast_block(
+ 0, broadcast_block_sizes, broadcast_block_strides,
+ broadcast_tensor_strides, output_block->data() + offset);
- } else {
- materialized_in_output = false;
- const size_t materialized_output_size = desc.size() * sizeof(Scalar);
- void* output_scratch_mem = scratch.allocate(materialized_output_size);
- materialized_output = static_cast<ScalarNoConst*>(output_scratch_mem);
- }
+ BroadcastTensorBlockReader::Run(&broadcast_block, input_block.data());
+ }
- size_t materialized_input_size = 0;
- ScalarNoConst* materialized_input = NULL;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2 emptyBlock() const {
+ DSizes<Index, NumDims> dimensions;
+ for (int i = 0; i < NumDims; ++i) dimensions[i] = 0;
+ return TensorBlockV2(internal::TensorBlockKind::kView, NULL, dimensions);
+ }
- if (outer_dim_size == 1) {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlockAlongBcastDim(
+ BlockBroadcastingParams params, Index bcast_offset,
+ TensorBlockScratch& scratch, ScalarNoConst* materialized_output,
+ ScalarNoConst** materialized_input,
+ size_t* materialized_input_size) const {
+ if (params.bcast_dim_size == 1) {
// We just need one block read using the ready-set values above.
- BroadcastBlockV2(
- input_block_sizes, input_block_strides, bcast_block_sizes,
- bcast_block_strides, bcast_input_strides, 0, desc, scratch,
- materialized_output, &materialized_input, &materialized_input_size);
-
- } else if (input_dims[outer_dim] == 1) {
- // Broadcast outer_dim_start-th dimension (< NumDims) by outer_dim_size.
- const int broadcast_outer_dim =
- is_col_major ? 2 * outer_dim_start + 1
- : 2 * NumDims - 2 * outer_dim_start - 2;
-
- bcast_block_sizes[broadcast_outer_dim] = outer_dim_size;
- bcast_input_strides[broadcast_outer_dim] = 0;
- bcast_block_strides[broadcast_outer_dim] = output_strides[outer_dim];
-
- BroadcastBlockV2(
- input_block_sizes, input_block_strides, bcast_block_sizes,
- bcast_block_strides, bcast_input_strides, 0, desc, scratch,
- materialized_output, &materialized_input, &materialized_input_size);
+ return BroadcastBlockV2(
+ params.input_block_sizes, params.input_block_strides,
+ params.bcast_block_sizes, params.bcast_block_strides,
+ params.bcast_input_strides, bcast_offset, 0, scratch,
+ materialized_output, materialized_input, materialized_input_size);
+
+ } else if (params.input_dims[params.bcast_dim] == 1) {
+ // Broadcast bcast dimension (< NumDims) by bcast_dim_size.
+ const int broadcast_bcast_dim =
+ IsColMajor ? 2 * params.inner_dim_count + 1
+ : 2 * NumDims - 2 * params.inner_dim_count - 2;
+
+ params.bcast_block_sizes[broadcast_bcast_dim] = params.bcast_dim_size;
+ params.bcast_input_strides[broadcast_bcast_dim] = 0;
+ params.bcast_block_strides[broadcast_bcast_dim] =
+ params.output_strides[params.bcast_dim];
+
+ return BroadcastBlockV2(
+ params.input_block_sizes, params.input_block_strides,
+ params.bcast_block_sizes, params.bcast_block_strides,
+ params.bcast_input_strides, bcast_offset, 0, scratch,
+ materialized_output, materialized_input, materialized_input_size);
} else {
- // The general case. Let's denote the output block as x[...,
- // a:a+outer_dim_size, :, ..., :], where a:a+outer_dim_size is a slice on
- // the outer_dim_start-th dimension (< NumDims). We need to split the
- // a:a+outer_dim_size into possibly 3 sub-blocks:
+ // Keep track of the total number of the coefficients written to the
+ // output block.
+ Index num_output_coeffs = 0;
+
+ // The general case. Let's denote the output block as
+ //
+ // x[..., a:a+bcast_dim_size, :, ..., :]
+ //
+ // where a:a+bcast_dim_size is a slice on the bcast_dim dimension
+ // (< NumDims). We need to split the a:a+bcast_dim_size into possibly 3
+ // sub-blocks:
//
// (1) a:b, where b is the smallest multiple of
- // input_dims[outer_dim_start] in [a, a+outer_dim_size].
+ // input_dims[bcast_dim] in [a, a+bcast_dim_size].
//
- // (2) b:c, where c is the largest multiple of input_dims[outer_dim_start]
- // in [a, a+outer_dim_size].
+ // (2) b:c, where c is the largest multiple of input_dims[bcast_dim]
+ // in [a, a+bcast_dim_size].
//
- // (3) c:a+outer_dim_size .
+ // (3) c:a+bcast_dim_size .
//
// Or, when b and c do not exist, we just need to process the whole block
// together.
// Find a.
- const Index outer_dim_left_index =
- desc.offset() / m_outputStrides[outer_dim];
+ const Index bcast_dim_left_index =
+ bcast_offset / m_outputStrides[params.bcast_dim];
// Find b and c.
- const Index input_outer_dim_size = input_dims[outer_dim];
+ const Index input_bcast_dim_size = params.input_dims[params.bcast_dim];
- // First multiple after a. This is b when <= outer_dim_left_index +
- // outer_dim_size.
+ // First multiple after a. This is b when <= bcast_dim_left_index +
+ // bcast_dim_size.
const Index first_multiple =
- divup<Index>(outer_dim_left_index, input_outer_dim_size) *
- input_outer_dim_size;
+ divup<Index>(bcast_dim_left_index, input_bcast_dim_size) *
+ input_bcast_dim_size;
- if (first_multiple <= outer_dim_left_index + outer_dim_size) {
+ if (first_multiple <= bcast_dim_left_index + params.bcast_dim_size) {
// b exists, so does c. Find it.
- const Index last_multiple = (outer_dim_left_index + outer_dim_size) /
- input_outer_dim_size * input_outer_dim_size;
- const int copy_outer_dim = is_col_major
- ? 2 * outer_dim_start
- : 2 * NumDims - 2 * outer_dim_start - 1;
- const int broadcast_outer_dim =
- is_col_major ? 2 * outer_dim_start + 1
- : 2 * NumDims - 2 * outer_dim_start - 2;
-
- if (first_multiple > outer_dim_left_index) {
- const Index head_size = first_multiple - outer_dim_left_index;
- input_block_sizes[outer_dim] = head_size;
- bcast_block_sizes[copy_outer_dim] = head_size;
- bcast_input_strides[copy_outer_dim] = input_block_strides[outer_dim];
- bcast_block_strides[copy_outer_dim] = output_strides[outer_dim];
- bcast_block_sizes[broadcast_outer_dim] = 1;
- bcast_input_strides[broadcast_outer_dim] = 0;
- bcast_block_strides[broadcast_outer_dim] =
- output_strides[outer_dim] * input_dims[outer_dim];
-
- BroadcastBlockV2(input_block_sizes, input_block_strides,
- bcast_block_sizes, bcast_block_strides,
- bcast_input_strides, 0, desc, scratch,
- materialized_output, &materialized_input,
- &materialized_input_size);
+ const Index last_multiple =
+ (bcast_dim_left_index + params.bcast_dim_size) /
+ input_bcast_dim_size * input_bcast_dim_size;
+ const int copy_bcast_dim =
+ IsColMajor ? 2 * params.inner_dim_count
+ : 2 * NumDims - 2 * params.inner_dim_count - 1;
+ const int broadcast_bcast_dim =
+ IsColMajor ? 2 * params.inner_dim_count + 1
+ : 2 * NumDims - 2 * params.inner_dim_count - 2;
+
+ if (first_multiple > bcast_dim_left_index) {
+ const Index head_size = first_multiple - bcast_dim_left_index;
+ params.input_block_sizes[params.bcast_dim] = head_size;
+ params.bcast_block_sizes[copy_bcast_dim] = head_size;
+ params.bcast_input_strides[copy_bcast_dim] =
+ params.input_block_strides[params.bcast_dim];
+ params.bcast_block_strides[copy_bcast_dim] =
+ params.output_strides[params.bcast_dim];
+ params.bcast_block_sizes[broadcast_bcast_dim] = 1;
+ params.bcast_input_strides[broadcast_bcast_dim] = 0;
+ params.bcast_block_strides[broadcast_bcast_dim] =
+ params.output_strides[params.bcast_dim] *
+ params.input_dims[params.bcast_dim];
+
+ num_output_coeffs += BroadcastBlockV2(
+ params.input_block_sizes, params.input_block_strides,
+ params.bcast_block_sizes, params.bcast_block_strides,
+ params.bcast_input_strides, bcast_offset, 0, scratch,
+ materialized_output, materialized_input, materialized_input_size);
}
if (first_multiple < last_multiple) {
- input_block_sizes[outer_dim] = input_outer_dim_size;
- bcast_block_sizes[copy_outer_dim] = input_outer_dim_size;
- bcast_input_strides[copy_outer_dim] = input_block_strides[outer_dim];
- bcast_block_strides[copy_outer_dim] = output_strides[outer_dim];
- bcast_block_sizes[broadcast_outer_dim] =
- (last_multiple - first_multiple) / input_outer_dim_size;
- bcast_input_strides[broadcast_outer_dim] = 0;
- bcast_block_strides[broadcast_outer_dim] =
- output_strides[outer_dim] * input_dims[outer_dim];
- const Index offset = (first_multiple - outer_dim_left_index) *
- m_outputStrides[outer_dim];
-
- BroadcastBlockV2(input_block_sizes, input_block_strides,
- bcast_block_sizes, bcast_block_strides,
- bcast_input_strides, offset, desc, scratch,
- materialized_output, &materialized_input,
- &materialized_input_size);
+ params.input_block_sizes[params.bcast_dim] = input_bcast_dim_size;
+ params.bcast_block_sizes[copy_bcast_dim] = input_bcast_dim_size;
+ params.bcast_input_strides[copy_bcast_dim] =
+ params.input_block_strides[params.bcast_dim];
+ params.bcast_block_strides[copy_bcast_dim] =
+ params.output_strides[params.bcast_dim];
+ params.bcast_block_sizes[broadcast_bcast_dim] =
+ (last_multiple - first_multiple) / input_bcast_dim_size;
+ params.bcast_input_strides[broadcast_bcast_dim] = 0;
+ params.bcast_block_strides[broadcast_bcast_dim] =
+ params.output_strides[params.bcast_dim] *
+ params.input_dims[params.bcast_dim];
+ const Index offset = (first_multiple - bcast_dim_left_index) *
+ m_outputStrides[params.bcast_dim];
+
+ num_output_coeffs += BroadcastBlockV2(
+ params.input_block_sizes, params.input_block_strides,
+ params.bcast_block_sizes, params.bcast_block_strides,
+ params.bcast_input_strides, bcast_offset, offset, scratch,
+ materialized_output, materialized_input, materialized_input_size);
}
- if (last_multiple < outer_dim_left_index + outer_dim_size) {
+ if (last_multiple < bcast_dim_left_index + params.bcast_dim_size) {
const Index tail_size =
- outer_dim_left_index + outer_dim_size - last_multiple;
- input_block_sizes[outer_dim] = tail_size;
- bcast_block_sizes[copy_outer_dim] = tail_size;
- bcast_input_strides[copy_outer_dim] = input_block_strides[outer_dim];
- bcast_block_strides[copy_outer_dim] = output_strides[outer_dim];
- bcast_block_sizes[broadcast_outer_dim] = 1;
- bcast_input_strides[broadcast_outer_dim] = 0;
- bcast_block_strides[broadcast_outer_dim] =
- output_strides[outer_dim] * input_dims[outer_dim];
- const Index offset = (last_multiple - outer_dim_left_index) *
- m_outputStrides[outer_dim];
-
- BroadcastBlockV2(input_block_sizes, input_block_strides,
- bcast_block_sizes, bcast_block_strides,
- bcast_input_strides, offset, desc, scratch,
- materialized_output, &materialized_input,
- &materialized_input_size);
+ bcast_dim_left_index + params.bcast_dim_size - last_multiple;
+ params.input_block_sizes[params.bcast_dim] = tail_size;
+ params.bcast_block_sizes[copy_bcast_dim] = tail_size;
+ params.bcast_input_strides[copy_bcast_dim] =
+ params.input_block_strides[params.bcast_dim];
+ params.bcast_block_strides[copy_bcast_dim] =
+ params.output_strides[params.bcast_dim];
+ params.bcast_block_sizes[broadcast_bcast_dim] = 1;
+ params.bcast_input_strides[broadcast_bcast_dim] = 0;
+ params.bcast_block_strides[broadcast_bcast_dim] =
+ params.output_strides[params.bcast_dim] *
+ params.input_dims[params.bcast_dim];
+ const Index offset = (last_multiple - bcast_dim_left_index) *
+ m_outputStrides[params.bcast_dim];
+
+ num_output_coeffs += BroadcastBlockV2(
+ params.input_block_sizes, params.input_block_strides,
+ params.bcast_block_sizes, params.bcast_block_strides,
+ params.bcast_input_strides, bcast_offset, offset, scratch,
+ materialized_output, materialized_input, materialized_input_size);
}
} else {
// b and c do not exist.
- const int copy_outer_dim = is_col_major
- ? 2 * outer_dim_start
- : 2 * NumDims - 2 * outer_dim_start - 1;
- input_block_sizes[outer_dim] = outer_dim_size;
- bcast_block_sizes[copy_outer_dim] = outer_dim_size;
- bcast_input_strides[copy_outer_dim] = input_block_strides[outer_dim];
- bcast_block_strides[copy_outer_dim] = output_strides[outer_dim];
-
- BroadcastBlockV2(
- input_block_sizes, input_block_strides, bcast_block_sizes,
- bcast_block_strides, bcast_input_strides, 0, desc, scratch,
- materialized_output, &materialized_input, &materialized_input_size);
+ const int copy_bcast_dim =
+ IsColMajor ? 2 * params.inner_dim_count
+ : 2 * NumDims - 2 * params.inner_dim_count - 1;
+ params.input_block_sizes[params.bcast_dim] = params.bcast_dim_size;
+ params.bcast_block_sizes[copy_bcast_dim] = params.bcast_dim_size;
+ params.bcast_input_strides[copy_bcast_dim] =
+ params.input_block_strides[params.bcast_dim];
+ params.bcast_block_strides[copy_bcast_dim] =
+ params.output_strides[params.bcast_dim];
+
+ num_output_coeffs += BroadcastBlockV2(
+ params.input_block_sizes, params.input_block_strides,
+ params.bcast_block_sizes, params.bcast_block_strides,
+ params.bcast_input_strides, bcast_offset, 0, scratch,
+ materialized_output, materialized_input, materialized_input_size);
}
- }
-
- return TensorBlockV2(materialized_in_output
- ? internal::TensorBlockKind::kMaterializedInOutput
- : internal::TensorBlockKind::kMaterializedInScratch,
- materialized_output,
- desc.dimensions());
- }
-
- // This is a special case for `NumDims == 0`, in practice this should not
- // happen often, so it's fine to do memory allocation just for a scalar.
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2
- scalarBlock(TensorBlockScratch& scratch) const {
- void* mem = scratch.allocate(sizeof(Scalar));
- ScalarNoConst* buf = static_cast<ScalarNoConst*>(mem);
- *buf = m_impl.coeff(0);
-
- DSizes<Index, NumDims> dimensions;
- for (int i = 0; i < NumDims; ++i) dimensions[i] = 0;
-
- return TensorBlockV2(internal::TensorBlockKind::kMaterializedInScratch, buf,
- dimensions);
- }
-
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2 emptyBlock() const {
- DSizes<Index, NumDims> dimensions;
- for (int i = 0; i < NumDims; ++i) dimensions[i] = 0;
- return TensorBlockV2(internal::TensorBlockKind::kView, NULL, dimensions);
- }
-
- EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
-
- const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }
- Broadcast functor() const { return m_broadcast; }
- #ifdef EIGEN_USE_SYCL
- // binding placeholder accessors to a command group handler for SYCL
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
- m_impl.bind(cgh);
- }
- #endif
- private:
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void BroadcastBlock(
- const Dimensions& input_block_sizes,
- const BroadcastDimensions& broadcast_block_sizes,
- const BroadcastDimensions& broadcast_block_strides,
- const BroadcastDimensions& broadcast_tensor_strides, Index offset,
- TensorBlock* output_block) const {
- TensorBlock input_view_block(
- static_cast<int>(Layout) == static_cast<int>(ColMajor)
- ? indexColMajor(output_block->first_coeff_index() + offset)
- : indexRowMajor(output_block->first_coeff_index() + offset),
- input_block_sizes, Dimensions(m_inputStrides),
- Dimensions(m_inputStrides), NULL);
-
- internal::TensorBlockView<ArgType, Device> input_block(m_device, m_impl,
- input_view_block);
- BroadcastTensorBlock broadcast_block(
- 0, broadcast_block_sizes, broadcast_block_strides,
- broadcast_tensor_strides, output_block->data() + offset);
-
- BroadcastTensorBlockReader::Run(&broadcast_block, input_block.data());
+ return num_output_coeffs;
+ }
}
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void BroadcastBlockV2(
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlockV2(
const Dimensions& input_block_sizes,
const Dimensions& input_block_strides,
const BroadcastDimensions& bcast_block_sizes,
const BroadcastDimensions& bcast_block_strides,
- const BroadcastDimensions& bcast_input_strides, Index offset,
- const TensorBlockDesc& output_desc, TensorBlockScratch& scratch,
+ const BroadcastDimensions& bcast_input_strides, Index bcast_offset,
+ Index offset, TensorBlockScratch& scratch,
ScalarNoConst* materialized_output, ScalarNoConst** materialized_input,
size_t* materialized_input_size) const {
// ---------------------------------------------------------------------- //
// Tensor block descriptor for reading block from the input.
- const Index input_offset = output_desc.offset() + offset;
- static const bool is_col_major = static_cast<int>(Layout) == static_cast<int>(ColMajor);
- TensorBlockDesc input_desc(is_col_major
- ? indexColMajor(input_offset)
- : indexRowMajor(input_offset),
- input_block_sizes);
+ const Index input_offset = bcast_offset + offset;
+ TensorBlockDesc input_desc(
+ IsColMajor ? indexColMajor(input_offset) : indexRowMajor(input_offset),
+ input_block_sizes);
ArgTensorBlock input_block = m_impl.blockV2(input_desc, scratch);
@@ -1266,7 +1366,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
typename TensorBlockIOV2::Dst dst(bcast_block_sizes, bcast_block_strides,
materialized_output + offset);
- TensorBlockIOV2::Copy(dst, src);
+ return TensorBlockIOV2::Copy(dst, src);
}
protected: