about | summary | refs | log | tree | commit | diff | homepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
diff options
context:
space:
mode:
author: Eugene Zhulenev <ezhulenev@google.com> 2019-10-02 12:44:06 -0700
committer: Eugene Zhulenev <ezhulenev@google.com> 2019-10-02 12:44:06 -0700
commit: 60ae24ee1a6c16114de456d77fcfba6f5a1160ca (patch)
tree: 7b9d5463018055571a5050ca31a8d3df12a3e6fc /unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
parent: 6e40454a6e6cc57c07c7340148657c985ca6c928 (diff)
Add block evaluation to TensorReshaping/TensorCasting/TensorPadding/TensorSelect
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | 138
1 files changed, 68 insertions, 70 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
index c87075a72..b1d668744 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
@@ -176,11 +176,12 @@ struct TensorEvaluator
typedef internal::TensorBlockAssignment<Scalar, NumCoords, TensorBlockExpr,
Index>
TensorBlockAssign;
- typename TensorBlockAssign::Dst dst(desc.dimensions(),
- internal::strides<Layout>(m_dims),
- m_data, desc.offset());
- TensorBlockAssign::Run(dst, block.expr());
+ TensorBlockAssign::Run(
+ TensorBlockAssign::target(desc.dimensions(),
+ internal::strides<Layout>(m_dims), m_data,
+ desc.offset()),
+ block.expr());
}
// Accessor for the evaluator's stored data pointer `m_data`.
EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return m_data; }
@@ -349,62 +350,7 @@ struct TensorEvaluator<const Derived, Device>
// Produces a tensor block for the region described by `desc`, backed by the
// evaluator's `m_data` buffer. In this hunk the inline implementation (the `-`
// lines: a direct view when the block is contiguous in memory, otherwise a
// copy into scratch memory) is replaced by one call to
// `TensorBlockV2::materialize` (the `+` line).
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2
blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch) const {
assert(m_data != NULL);
-
- // TODO(ezhulenev): Move it to TensorBlockV2 and reuse in TensorForcedEval.
-
- // If a tensor block descriptor covers a contiguous block of the underlying
- // memory, we can skip block buffer memory allocation, and construct a block
- // from existing `m_data` memory buffer.
- //
- // Example: (RowMajor layout)
- // m_dims: [11, 12, 13, 14]
- // desc.dimensions(): [1, 1, 3, 14]
- //
- // In this case we can construct a TensorBlock starting at
- // `m_data + desc.offset()`, with a `desc.dimensions()` block sizes.
-
- static const bool
- is_col_major = static_cast<int>(Layout) == static_cast<int>(ColMajor);
-
- // Find out how many inner dimensions have a matching size.
- int num_matching_inner_dims = 0;
- for (int i = 0; i < NumCoords; ++i) {
- int dim = is_col_major ? i : NumCoords - i - 1;
- if (m_dims[dim] != desc.dimensions()[dim]) break;
- ++num_matching_inner_dims;
- }
-
- // All the outer dimensions must be of size `1`, except a single dimension
- // before the matching inner dimension (`3` in the example above).
- bool can_use_direct_access = true;
- for (int i = num_matching_inner_dims + 1; i < NumCoords; ++i) {
- int dim = is_col_major ? i : NumCoords - i - 1;
- if (desc.dimension(dim) != 1) {
- can_use_direct_access = false;
- break;
- }
- }
-
- if (can_use_direct_access) {
- EvaluatorPointerType block_start = m_data + desc.offset();
- return TensorBlockV2(internal::TensorBlockKind::kView, block_start,
- desc.dimensions());
-
- } else {
- void* mem = scratch.allocate(desc.size() * sizeof(Scalar));
- ScalarNoConst* block_buffer = static_cast<ScalarNoConst*>(mem);
-
- TensorBlockIOSrc src(internal::strides<Layout>(m_dims), m_data,
- desc.offset());
- TensorBlockIODst dst(desc.dimensions(),
- internal::strides<Layout>(desc.dimensions()),
- block_buffer);
-
- TensorBlockIO::Copy(dst, src);
-
- return TensorBlockV2(internal::TensorBlockKind::kMaterializedInScratch,
- block_buffer, desc.dimensions());
- }
// NOTE(review): `materialize` is given the same inputs the removed code read
// (`m_data`, `m_dims`, `desc`, `scratch`); presumably it performs the same
// view-or-copy-into-scratch selection, as the removed TODO suggested -- its
// body is not visible here, confirm against TensorBlock.h.
+ return TensorBlockV2::materialize(m_data, m_dims, desc, scratch);
}
// Accessor for the evaluator's stored data pointer `m_data`.
EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return m_data; }
@@ -923,15 +869,21 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
typedef typename XprType::Scalar Scalar;
// Compile-time traits of the TensorSelectOp evaluator. The `-` lines are the
// previous values, where BlockAccessV2 and PreferBlockAccess were hard-coded to
// false; the `+` lines re-wrap the unchanged entries and derive those two
// traits from the argument evaluators.
enum {
- IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned & TensorEvaluator<ElseArgType, Device>::IsAligned,
- PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess & TensorEvaluator<ElseArgType, Device>::PacketAccess &
- PacketType<Scalar, Device>::HasBlend,
- BlockAccess = false,
- BlockAccessV2 = false,
- PreferBlockAccess = false,
- Layout = TensorEvaluator<IfArgType, Device>::Layout,
- CoordAccess = false, // to be implemented
- RawAccess = false
// Aligned only when both the then- and else-branch evaluators are aligned.
+ IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned &
+ TensorEvaluator<ElseArgType, Device>::IsAligned,
// Packet access needs both branch evaluators to support it and the packet
// type to provide a blend operation.
+ PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess &
+ TensorEvaluator<ElseArgType, Device>::PacketAccess &
+ PacketType<Scalar, Device>::HasBlend,
+ BlockAccess = false,
// V2 block evaluation is enabled only if the condition, then and else
// evaluators all provide it.
+ BlockAccessV2 = TensorEvaluator<IfArgType, Device>::BlockAccessV2 &&
+ TensorEvaluator<ThenArgType, Device>::BlockAccessV2 &&
+ TensorEvaluator<ElseArgType, Device>::BlockAccessV2,
// Block evaluation is preferred as soon as any one argument evaluator
// prefers it.
+ PreferBlockAccess = TensorEvaluator<IfArgType, Device>::PreferBlockAccess ||
+ TensorEvaluator<ThenArgType, Device>::PreferBlockAccess ||
+ TensorEvaluator<ElseArgType, Device>::PreferBlockAccess,
// Layout is taken from the condition (if) argument's evaluator.
+ Layout = TensorEvaluator<IfArgType, Device>::Layout,
+ CoordAccess = false, // to be implemented
+ RawAccess = false
};
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
@@ -953,8 +905,36 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
typedef StorageMemory<CoeffReturnType, Device> Storage;
typedef typename Storage::Type EvaluatorPointerType;
+ static const int NumDims = internal::array_size<Dimensions>::value;
+
//===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
- typedef internal::TensorBlockNotImplemented TensorBlockV2;
+ typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
+ typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
+
+ typedef typename TensorEvaluator<const IfArgType, Device>::TensorBlockV2
+ IfArgTensorBlock;
+ typedef typename TensorEvaluator<const ThenArgType, Device>::TensorBlockV2
+ ThenArgTensorBlock;
+ typedef typename TensorEvaluator<const ElseArgType, Device>::TensorBlockV2
+ ElseArgTensorBlock;
+
// Expression factory for select blocks: given the block expressions of the
// condition, then and else arguments, it builds the TensorSelectOp expression
// that combines them. It is the factory type plugged into the
// TensorTernaryExprBlock typedef that follows this struct.
+ struct TensorSelectOpBlockFactory {
// Maps the three argument expression types to the resulting TensorSelectOp
// type; each argument type is const-qualified.
+ template <typename IfArgXprType, typename ThenArgXprType, typename ElseArgXprType>
+ struct XprType {
+ typedef TensorSelectOp<const IfArgXprType, const ThenArgXprType, const ElseArgXprType> type;
+ };
+
// Constructs and returns, by value, the select expression over `if_expr`,
// `then_expr` and `else_expr`.
+ template <typename IfArgXprType, typename ThenArgXprType, typename ElseArgXprType>
+ typename XprType<IfArgXprType, ThenArgXprType, ElseArgXprType>::type expr(
+ const IfArgXprType& if_expr, const ThenArgXprType& then_expr, const ElseArgXprType& else_expr) const {
+ return typename XprType<IfArgXprType, ThenArgXprType, ElseArgXprType>::type(if_expr, then_expr, else_expr);
+ }
+ };
+
+ typedef internal::TensorTernaryExprBlock<TensorSelectOpBlockFactory,
+ IfArgTensorBlock, ThenArgTensorBlock,
+ ElseArgTensorBlock>
+ TensorBlockV2;
//===--------------------------------------------------------------------===//
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
@@ -1000,6 +980,24 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
.cwiseMax(m_elseImpl.costPerCoeff(vectorized));
}
// Forwards the `resources` vector to the condition, then and else evaluators,
// in that order, so each can report its own requirements. The select
// evaluator adds no entry of its own here.
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void getResourceRequirements(
+ std::vector<internal::TensorOpResourceRequirements>* resources) const {
+ m_condImpl.getResourceRequirements(resources);
+ m_thenImpl.getResourceRequirements(resources);
+ m_elseImpl.getResourceRequirements(resources);
+ }
+
// Builds the select block for the region described by `desc`: requests a block
// from each of the condition, then and else evaluators with the same
// descriptor and scratch allocator, and wraps the three in a TensorBlockV2
// together with a TensorSelectOpBlockFactory.
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2
+ blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch) const {
+ // It's unsafe to pass destination buffer to underlying expressions, because
+ // output might be aliased with one of the inputs.
+ desc.DropDestinationBuffer();
+
// Note: `desc` is taken by non-const reference and is modified by the
// DropDestinationBuffer() call above before being passed on.
+ return TensorBlockV2(
+ m_condImpl.blockV2(desc, scratch), m_thenImpl.blockV2(desc, scratch),
+ m_elseImpl.blockV2(desc, scratch), TensorSelectOpBlockFactory());
+ }
+
// The select evaluator exposes no data pointer: this always returns NULL.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EvaluatorPointerType data() const { return NULL; }
#ifdef EIGEN_USE_SYCL