path: root/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
author    Eugene Zhulenev <ezhulenev@google.com>  2019-09-24 12:52:45 -0700
committer Eugene Zhulenev <ezhulenev@google.com>  2019-09-24 12:52:45 -0700
commit    ef9dfee7bdc8e0d82c9b7ddf9414ef99d866d7ba (patch)
tree      490a8ae1f247cf226475f504ea1d3ab305b98097 /unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
parent    efd9867ff0e8df23016ac6c9828d0d7bf8bec1b1 (diff)
Tensor block evaluation V2 support for unary/binary/broadcasting
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h')
-rw-r--r--  unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h  367
1 file changed, 367 insertions(+), 0 deletions(-)
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
index b290de311..9e4fae99a 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
@@ -115,6 +115,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
IsAligned = true,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
BlockAccess = TensorEvaluator<ArgType, Device>::BlockAccess,
+ BlockAccessV2 = TensorEvaluator<ArgType, Device>::BlockAccessV2,
PreferBlockAccess = true,
Layout = TensorEvaluator<ArgType, Device>::Layout,
RawAccess = false
@@ -131,11 +132,24 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
// We do block based broadcasting using a trick with 2x tensor rank and 0
// strides. See block method implementation for details.
typedef DSizes<Index, 2 * NumDims> BroadcastDimensions;
+
typedef internal::TensorBlock<ScalarNoConst, Index, 2 * NumDims, Layout>
BroadcastTensorBlock;
typedef internal::TensorBlockReader<ScalarNoConst, Index, 2 * NumDims, Layout>
BroadcastTensorBlockReader;
+ //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
+ typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
+ typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
+
+ typedef typename TensorEvaluator<const ArgType, Device>::TensorBlockV2
+ ArgTensorBlock;
+
+ typedef typename internal::TensorMaterializedBlock<ScalarNoConst, NumDims,
+ Layout, Index>
+ TensorBlockV2;
+ //===--------------------------------------------------------------------===//
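From the usage further down, a `TensorBlockDescriptor` carries the offset of the block within the output tensor, the block dimensions, and an optional caller-provided destination buffer, while the scratch allocator hands out temporary memory for a single block evaluation. A rough conceptual sketch of that interface (the real types live in TensorBlock.h; this is not Eigen's definition):

    // Conceptual sketch only, inferred from how `desc` and `scratch` are used
    // in blockV2 below; the actual classes are internal::TensorBlockDescriptor
    // and internal::TensorBlockScratchAllocator in TensorBlock.h.
    template <int NumDims, typename IndexType>
    struct BlockDescriptorSketch {
      IndexType offset;            // offset of the block in the output tensor
      IndexType dims[NumDims];     // extent of the block in every dimension
      void* destination;           // optional pre-allocated output buffer
    };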
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op,
const Device& device)
: isCopy(false), nByOne(false), oneByN(false),
@@ -867,6 +881,292 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
}
}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2
+ blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch) const {
+ static const bool
+ is_col_major = static_cast<int>(Layout) == static_cast<int>(ColMajor);
+
+ // Return a block with a single scalar.
+ if (NumDims <= 0) return scalarBlock(scratch);
+
+ // Because we only support kSkewedInnerDims blocking, the block sizes must be
+ // equal to m_dimensions for the inner dims, may be smaller than
+ // m_dimensions[i] for the first outer dim, and must be 1 for the remaining
+ // outer dims. This is guaranteed by MergeResourceRequirements() in
+ // TensorBlock.h.
+ const Dimensions& output_dims = desc.dimensions();
+ const Dimensions output_strides = internal::strides<Layout>(output_dims);
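For reference, `internal::strides<Layout>(dims)` produces the contiguous strides of a dense tensor with the given dimensions. A minimal standalone sketch of the column-major case (an illustration, not Eigen's implementation):

    #include <array>
    #include <cstddef>

    // Illustration only: contiguous column-major strides for a rank-N shape.
    // strides[0] == 1 and strides[i] == strides[i - 1] * dims[i - 1].
    template <std::size_t N>
    std::array<std::ptrdiff_t, N> colMajorStrides(
        const std::array<std::ptrdiff_t, N>& dims) {
      std::array<std::ptrdiff_t, N> strides{};
      std::ptrdiff_t stride = 1;
      for (std::size_t i = 0; i < N; ++i) {
        strides[i] = stride;
        stride *= dims[i];
      }
      return strides;
    }

    // Example: dims [2, 3, 4] -> strides [1, 2, 6]; the row-major variant
    // walks the dimensions in reverse order.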
+
+ // Find where outer dims start.
+ int outer_dim_start = 0;
+ Index outer_dim_size = 1;
+ Index inner_dim_size = 1;
+
+ for (int i = 0; i < NumDims; ++i) {
+ const int dim = is_col_major ? i : NumDims - i - 1;
+
+ if (i > outer_dim_start) {
+ eigen_assert(output_dims[dim] == 1);
+ } else if (output_dims[dim] != m_dimensions[dim]) {
+ eigen_assert(output_dims[dim] < m_dimensions[dim]);
+ outer_dim_size = output_dims[dim];
+ } else {
+ inner_dim_size *= output_dims[dim];
+ ++outer_dim_start;
+ }
+ }
+
+ if (inner_dim_size == 0 || outer_dim_size == 0) {
+ return emptyBlock();
+ }
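The classification loop above can be exercised in isolation. A small sketch with hypothetical shapes, assuming column-major layout (it mirrors the loop but is not part of the patch):

    #include <array>
    #include <cassert>

    // Illustration only: classify a block shape produced by skewed-inner-dims
    // blocking (ColMajor) - full inner dims, one partial outer dim, then 1s.
    struct BlockShapeInfo {
      int outer_dim_start;    // index of the first (partial) outer dimension
      long inner_dim_size;    // product of the full inner dimensions
      long outer_dim_size;    // extent of the partial outer dimension
    };

    template <std::size_t N>
    BlockShapeInfo classify(const std::array<long, N>& block_dims,
                            const std::array<long, N>& output_dims) {
      BlockShapeInfo info{0, 1, 1};
      for (std::size_t i = 0; i < N; ++i) {
        if (static_cast<int>(i) > info.outer_dim_start) {
          assert(block_dims[i] == 1);              // trailing outer dims are 1
        } else if (block_dims[i] != output_dims[i]) {
          assert(block_dims[i] < output_dims[i]);  // partial outer dim
          info.outer_dim_size = block_dims[i];
        } else {
          info.inner_dim_size *= block_dims[i];    // full inner dim
          ++info.outer_dim_start;
        }
      }
      return info;
    }

    // Example: block_dims [4, 5, 2, 1] over output_dims [4, 5, 7, 9]
    // -> outer_dim_start == 2, inner_dim_size == 20, outer_dim_size == 2.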
+
+ const Dimensions& input_dims = Dimensions(m_impl.dimensions());
+
+ // Pre-fill input_block_sizes, bcast_block_sizes, bcast_block_strides, and
+ // bcast_input_strides. Later on we will only modify the outer_dim_start-th
+ // dimension on these arrays.
+
+ // Calculate the input block size for looking into the input.
+ Dimensions input_block_sizes;
+ for (int i = 0; i < outer_dim_start; ++i) {
+ const int dim = is_col_major ? i : NumDims - i - 1;
+ input_block_sizes[dim] = input_dims[dim];
+ }
+ for (int i = outer_dim_start; i < NumDims; ++i) {
+ const int dim = is_col_major ? i : NumDims - i - 1;
+ input_block_sizes[dim] = 1;
+ }
+ Dimensions input_block_strides =
+ internal::strides<Layout>(input_block_sizes);
+
+ // Broadcast with the 0-stride trick: Create 1 extra dim for each
+ // broadcast, set the input stride to 0.
+ //
+ // When ColMajor:
+ //
+ // - bcast_block_sizes:
+ // [d_0, b_0, d_1, b_1, ...]
+ //
+ // - bcast_block_strides:
+ // [output_block_strides[0], output_block_strides[0] * d_0,
+ // output_block_strides[1], output_block_strides[1] * d_1,
+ // ...]
+ //
+ // - bcast_input_strides:
+ // [input_block_strides[0], 0,
+ // input_block_strides[1], 0,
+ // ...].
+ //
+ BroadcastDimensions bcast_block_sizes;
+ BroadcastDimensions bcast_block_strides;
+ BroadcastDimensions bcast_input_strides;
+
+ for (int i = 0; i < outer_dim_start; ++i) {
+ const int dim = is_col_major ? i : NumDims - i - 1;
+
+ const int copy_dim = is_col_major ? 2 * i : 2 * NumDims - 2 * i - 1;
+ const int broadcast_dim = is_col_major ? copy_dim + 1 : copy_dim - 1;
+
+ bcast_block_sizes[copy_dim] = input_dims[dim];
+ bcast_block_sizes[broadcast_dim] = m_broadcast[dim];
+ bcast_block_strides[copy_dim] = output_strides[dim];
+ bcast_block_strides[broadcast_dim] =
+ output_strides[dim] * input_dims[dim];
+ bcast_input_strides[copy_dim] = input_block_strides[dim];
+ bcast_input_strides[broadcast_dim] = 0;
+ }
+ for (int i = 2 * outer_dim_start; i < 2 * NumDims; ++i) {
+ const int dim = is_col_major ? i : 2 * NumDims - i - 1;
+ bcast_block_sizes[dim] = 1;
+ bcast_block_strides[dim] = 0;
+ bcast_input_strides[dim] = 0;
+ }
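To make the zero-stride layout concrete, here is a self-contained sketch for a hypothetical column-major case where the block covers the whole output: a 2x3 input broadcast by factors [4, 5] into an 8x15 output. The three arrays below are the analogues of the `bcast_*` triples built above for this shape, and the nested loops stand in for the strided copy that `TensorBlockIOV2` performs later:

    #include <cstddef>

    // Illustration only (hypothetical sizes, not Eigen internals): broadcast a
    // 2x3 column-major input by factors [4, 5] into an 8x15 output, using a
    // 4-dimensional iteration space with zero input strides on the
    // "broadcast" dimensions.
    void broadcast2x3By4x5(const float* src, float* dst) {
      const std::ptrdiff_t sizes[4]       = {2, 4, 3, 5};   // [d_0, b_0, d_1, b_1]
      const std::ptrdiff_t dst_strides[4] = {1, 2, 8, 24};  // output strides
      const std::ptrdiff_t src_strides[4] = {1, 0, 2, 0};   // 0 => replicate

      for (std::ptrdiff_t j1 = 0; j1 < sizes[3]; ++j1)
        for (std::ptrdiff_t i1 = 0; i1 < sizes[2]; ++i1)
          for (std::ptrdiff_t j0 = 0; j0 < sizes[1]; ++j0)
            for (std::ptrdiff_t i0 = 0; i0 < sizes[0]; ++i0)
              dst[i0 * dst_strides[0] + j0 * dst_strides[1] +
                  i1 * dst_strides[2] + j1 * dst_strides[3]] =
                  src[i0 * src_strides[0] + j0 * src_strides[1] +
                      i1 * src_strides[2] + j1 * src_strides[3]];
    }

    // Every output element dst(p, q) ends up equal to src(p % 2, q % 3),
    // i.e. the plain broadcast of the 2x3 block to 8x15, without ever
    // materializing an intermediate copy of the input.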
+
+ const int outer_dim =
+ is_col_major ? outer_dim_start : NumDims - outer_dim_start - 1;
+
+ // Check if we can reuse `desc` destination, or allocate new scratch buffer.
+ ScalarNoConst* materialized_output =
+ desc.template destination<ScalarNoConst, Layout>();
+ bool materialized_in_output;
+
+ if (materialized_output != NULL) {
+ desc.DropDestinationBuffer();
+ materialized_in_output = true;
+
+ } else {
+ materialized_in_output = false;
+ const size_t materialized_output_size = desc.size() * sizeof(Scalar);
+ void* output_scratch_mem = scratch.allocate(materialized_output_size);
+ materialized_output = static_cast<ScalarNoConst*>(output_scratch_mem);
+ }
+
+ size_t materialized_input_size = 0;
+ ScalarNoConst* materialized_input = NULL;
+
+ if (outer_dim_size == 1) {
+ // We just need one block read using the ready-set values above.
+ BroadcastBlockV2(
+ input_block_sizes, input_block_strides, bcast_block_sizes,
+ bcast_block_strides, bcast_input_strides, 0, desc, scratch,
+ materialized_output, &materialized_input, &materialized_input_size);
+
+ } else if (input_dims[outer_dim] == 1) {
+ // Broadcast outer_dim_start-th dimension (< NumDims) by outer_dim_size.
+ const int broadcast_outer_dim =
+ is_col_major ? 2 * outer_dim_start + 1
+ : 2 * NumDims - 2 * outer_dim_start - 2;
+
+ bcast_block_sizes[broadcast_outer_dim] = outer_dim_size;
+ bcast_input_strides[broadcast_outer_dim] = 0;
+ bcast_block_strides[broadcast_outer_dim] = output_strides[outer_dim];
+
+ BroadcastBlockV2(
+ input_block_sizes, input_block_strides, bcast_block_sizes,
+ bcast_block_strides, bcast_input_strides, 0, desc, scratch,
+ materialized_output, &materialized_input, &materialized_input_size);
+
+ } else {
+ // The general case. Let's denote the output block as x[...,
+ // a:a+outer_dim_size, :, ..., :], where a:a+outer_dim_size is a slice on
+ // the outer_dim_start-th dimension (< NumDims). We need to split the
+ // a:a+outer_dim_size into possibly 3 sub-blocks:
+ //
+ // (1) a:b, where b is the smallest multiple of
+ // input_dims[outer_dim_start] in [a, a+outer_dim_size].
+ //
+ // (2) b:c, where c is the largest multiple of input_dims[outer_dim_start]
+ // in [a, a+outer_dim_size].
+ //
+ // (3) c:a+outer_dim_size.
+ //
+ // Or, when b and c do not exist, we just need to process the whole block
+ // together.
+
+ // Find a.
+ const Index outer_dim_left_index =
+ desc.offset() / m_outputStrides[outer_dim];
+
+ // Find b and c.
+ const Index input_outer_dim_size = input_dims[outer_dim];
+
+ // First multiple after a. This is b when <= outer_dim_left_index +
+ // outer_dim_size.
+ const Index first_multiple =
+ divup<Index>(outer_dim_left_index, input_outer_dim_size) *
+ input_outer_dim_size;
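With concrete numbers the split is easy to see. A throwaway sketch (hypothetical extents, not part of the patch): the input extent along the outer dim is 4, the block starts at a = 5 and covers 10 elements, so b = 8 and c = 12, giving a head [5, 8), one full period [8, 12), and a tail [12, 15):

    #include <cstdio>

    // Illustration only: split the outer-dim range [a, a + block_size) into
    // the head/middle/tail sub-blocks described in the comment above. The
    // middle part covers whole periods of the input extent and can be
    // broadcast in one shot.
    int main() {
      const long input_extent = 4;   // input_dims[outer_dim]
      const long a = 5;              // outer_dim_left_index
      const long block_size = 10;    // outer_dim_size

      const long b = ((a + input_extent - 1) / input_extent) * input_extent;
      const long c = ((a + block_size) / input_extent) * input_extent;

      if (b <= a + block_size) {
        std::printf("head   [%ld, %ld)\n", a, b);               // [5, 8)
        std::printf("middle [%ld, %ld) = %ld full period(s)\n",
                    b, c, (c - b) / input_extent);              // [8, 12)
        std::printf("tail   [%ld, %ld)\n", c, a + block_size);  // [12, 15)
      } else {
        std::printf("single sub-block [%ld, %ld)\n", a, a + block_size);
      }
      return 0;
    }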
+
+ if (first_multiple <= outer_dim_left_index + outer_dim_size) {
+ // b exists, so does c. Find it.
+ const Index last_multiple = (outer_dim_left_index + outer_dim_size) /
+ input_outer_dim_size * input_outer_dim_size;
+ const int copy_outer_dim = is_col_major
+ ? 2 * outer_dim_start
+ : 2 * NumDims - 2 * outer_dim_start - 1;
+ const int broadcast_outer_dim =
+ is_col_major ? 2 * outer_dim_start + 1
+ : 2 * NumDims - 2 * outer_dim_start - 2;
+
+ if (first_multiple > outer_dim_left_index) {
+ const Index head_size = first_multiple - outer_dim_left_index;
+ input_block_sizes[outer_dim] = head_size;
+ bcast_block_sizes[copy_outer_dim] = head_size;
+ bcast_input_strides[copy_outer_dim] = input_block_strides[outer_dim];
+ bcast_block_strides[copy_outer_dim] = output_strides[outer_dim];
+ bcast_block_sizes[broadcast_outer_dim] = 1;
+ bcast_input_strides[broadcast_outer_dim] = 0;
+ bcast_block_strides[broadcast_outer_dim] =
+ output_strides[outer_dim] * input_dims[outer_dim];
+
+ BroadcastBlockV2(input_block_sizes, input_block_strides,
+ bcast_block_sizes, bcast_block_strides,
+ bcast_input_strides, 0, desc, scratch,
+ materialized_output, &materialized_input,
+ &materialized_input_size);
+ }
+ if (first_multiple < last_multiple) {
+ input_block_sizes[outer_dim] = input_outer_dim_size;
+ bcast_block_sizes[copy_outer_dim] = input_outer_dim_size;
+ bcast_input_strides[copy_outer_dim] = input_block_strides[outer_dim];
+ bcast_block_strides[copy_outer_dim] = output_strides[outer_dim];
+ bcast_block_sizes[broadcast_outer_dim] =
+ (last_multiple - first_multiple) / input_outer_dim_size;
+ bcast_input_strides[broadcast_outer_dim] = 0;
+ bcast_block_strides[broadcast_outer_dim] =
+ output_strides[outer_dim] * input_dims[outer_dim];
+ const Index offset = (first_multiple - outer_dim_left_index) *
+ m_outputStrides[outer_dim];
+
+ BroadcastBlockV2(input_block_sizes, input_block_strides,
+ bcast_block_sizes, bcast_block_strides,
+ bcast_input_strides, offset, desc, scratch,
+ materialized_output, &materialized_input,
+ &materialized_input_size);
+ }
+ if (last_multiple < outer_dim_left_index + outer_dim_size) {
+ const Index tail_size =
+ outer_dim_left_index + outer_dim_size - last_multiple;
+ input_block_sizes[outer_dim] = tail_size;
+ bcast_block_sizes[copy_outer_dim] = tail_size;
+ bcast_input_strides[copy_outer_dim] = input_block_strides[outer_dim];
+ bcast_block_strides[copy_outer_dim] = output_strides[outer_dim];
+ bcast_block_sizes[broadcast_outer_dim] = 1;
+ bcast_input_strides[broadcast_outer_dim] = 0;
+ bcast_block_strides[broadcast_outer_dim] =
+ output_strides[outer_dim] * input_dims[outer_dim];
+ const Index offset = (last_multiple - outer_dim_left_index) *
+ m_outputStrides[outer_dim];
+
+ BroadcastBlockV2(input_block_sizes, input_block_strides,
+ bcast_block_sizes, bcast_block_strides,
+ bcast_input_strides, offset, desc, scratch,
+ materialized_output, &materialized_input,
+ &materialized_input_size);
+ }
+ } else {
+ // b and c do not exist.
+ const int copy_outer_dim = is_col_major
+ ? 2 * outer_dim_start
+ : 2 * NumDims - 2 * outer_dim_start - 1;
+ input_block_sizes[outer_dim] = outer_dim_size;
+ bcast_block_sizes[copy_outer_dim] = outer_dim_size;
+ bcast_input_strides[copy_outer_dim] = input_block_strides[outer_dim];
+ bcast_block_strides[copy_outer_dim] = output_strides[outer_dim];
+
+ BroadcastBlockV2(
+ input_block_sizes, input_block_strides, bcast_block_sizes,
+ bcast_block_strides, bcast_input_strides, 0, desc, scratch,
+ materialized_output, &materialized_input, &materialized_input_size);
+ }
+ }
+
+ return TensorBlockV2(materialized_in_output
+ ? internal::TensorBlockKind::kMaterializedInOutput
+ : internal::TensorBlockKind::kMaterializedInScratch,
+ materialized_output,
+ desc.dimensions());
+ }
+
+ // This is a special case for `NumDims == 0`; in practice this should not
+ // happen often, so it's fine to do memory allocation just for a scalar.
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2
+ scalarBlock(TensorBlockScratch& scratch) const {
+ void* mem = scratch.allocate(sizeof(Scalar));
+ ScalarNoConst* buf = static_cast<ScalarNoConst*>(mem);
+ *buf = m_impl.coeff(0);
+
+ DSizes<Index, NumDims> dimensions;
+ for (int i = 0; i < NumDims; ++i) dimensions[i] = 0;
+
+ return TensorBlockV2(internal::TensorBlockKind::kMaterializedInScratch, buf,
+ dimensions);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2 emptyBlock() const {
+ DSizes<Index, NumDims> dimensions;
+ for (int i = 0; i < NumDims; ++i) dimensions[i] = 0;
+ return TensorBlockV2(internal::TensorBlockKind::kView, NULL, dimensions);
+ }
+
EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }
@@ -901,6 +1201,73 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
BroadcastTensorBlockReader::Run(&broadcast_block, input_block.data());
}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void BroadcastBlockV2(
+ const Dimensions& input_block_sizes,
+ const Dimensions& input_block_strides,
+ const BroadcastDimensions& bcast_block_sizes,
+ const BroadcastDimensions& bcast_block_strides,
+ const BroadcastDimensions& bcast_input_strides, Index offset,
+ const TensorBlockDesc& output_desc, TensorBlockScratch& scratch,
+ ScalarNoConst* materialized_output, ScalarNoConst** materialized_input,
+ size_t* materialized_input_size) const {
+ // ---------------------------------------------------------------------- //
+ // Tensor block descriptor for reading block from the input.
+ const Index input_offset = output_desc.offset() + offset;
+ static const bool is_col_major = static_cast<int>(Layout) == static_cast<int>(ColMajor);
+ TensorBlockDesc input_desc(is_col_major
+ ? indexColMajor(input_offset)
+ : indexRowMajor(input_offset),
+ input_block_sizes);
+
+ ArgTensorBlock input_block = m_impl.blockV2(input_desc, scratch);
+
+ // ---------------------------------------------------------------------- //
+ // Materialize input block into a temporary memory buffer only if it's not
+ // already available in the arg block.
+ const ScalarNoConst* input_buffer = NULL;
+
+ if (input_block.data() != NULL) {
+ // Input block already has raw data, there is no need to materialize it.
+ input_buffer = input_block.data();
+
+ } else {
+ // Otherwise we have to do block assignment into a temporary buffer.
+
+ // Maybe reuse previously allocated buffer, or allocate a new one with a
+ // scratch allocator.
+ const size_t input_total_size = input_block_sizes.TotalSize();
+ if (*materialized_input == NULL ||
+ *materialized_input_size < input_total_size) {
+ *materialized_input_size = input_total_size;
+ void* mem = scratch.allocate(*materialized_input_size * sizeof(Scalar));
+ *materialized_input = static_cast<ScalarNoConst*>(mem);
+ }
+
+ typedef internal::TensorBlockAssignment<
+ ScalarNoConst, NumDims, typename ArgTensorBlock::XprType, Index>
+ TensorBlockAssignment;
+
+ typename TensorBlockAssignment::Dst assignment_dst(
+ input_block_sizes, input_block_strides, *materialized_input);
+
+ TensorBlockAssignment::Run(assignment_dst, input_block.expr());
+
+ input_buffer = *materialized_input;
+ }
+
+ // ---------------------------------------------------------------------- //
+ // Copy data from materialized input block to the materialized output, using
+ // given broadcast strides (strides with zeroes).
+ typedef internal::TensorBlockIOV2<ScalarNoConst, Index, 2 * NumDims, Layout>
+ TensorBlockIOV2;
+
+ typename TensorBlockIOV2::Src src(bcast_input_strides, input_buffer);
+ typename TensorBlockIOV2::Dst dst(bcast_block_sizes, bcast_block_strides,
+ materialized_output + offset);
+
+ TensorBlockIOV2::Copy(dst, src);
+ }
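The final `TensorBlockIOV2::Copy` call amounts to a rank-generic strided copy in which a zero source stride replicates data along a broadcast dimension. A rough standalone sketch of that operation (an illustration, not Eigen's `TensorBlockIOV2`):

    #include <cstddef>

    // Illustration only: recursive strided copy over an arbitrary number of
    // dimensions. Zero entries in src_strides replicate the source along that
    // dimension, which is exactly how the broadcast dimensions behave above.
    void stridedCopy(int rank, const std::ptrdiff_t* sizes,
                     const std::ptrdiff_t* dst_strides,
                     const std::ptrdiff_t* src_strides,
                     const float* src, float* dst) {
      if (rank == 0) {
        *dst = *src;  // all dimensions consumed: copy one coefficient
        return;
      }
      const int d = rank - 1;  // iterate the outermost remaining dimension
      for (std::ptrdiff_t i = 0; i < sizes[d]; ++i) {
        stridedCopy(d, sizes, dst_strides, src_strides,
                    src + i * src_strides[d], dst + i * dst_strides[d]);
      }
    }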
+
protected:
const Device EIGEN_DEVICE_REF m_device;
const typename internal::remove_reference<Broadcast>::type m_broadcast;