aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2019-10-09 12:45:31 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2019-10-09 12:45:31 -0700
commit33e174613987cfc6c83576dc0fe8086c7a5d1b1f (patch)
tree4f4c62eab5c0feca0f233624c9c1fc571c491781 /unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
parentf0a4642baba70a64128964d96c4ede012614925e (diff)
Block evaluation for TensorChipping + fixed bugs in TensorPadding and TensorSlicing
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h27
1 files changed, 16 insertions, 11 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
index b1d668744..b77d8fe84 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
@@ -53,18 +53,22 @@ struct TensorEvaluator
RawAccess = true
};
- typedef typename internal::TensorBlock<
- typename internal::remove_const<Scalar>::type, Index, NumCoords, Layout>
+ typedef typename internal::remove_const<Scalar>::type ScalarNoConst;
+
+ typedef typename internal::TensorBlock<ScalarNoConst, Index, NumCoords, Layout>
TensorBlock;
- typedef typename internal::TensorBlockReader<
- typename internal::remove_const<Scalar>::type, Index, NumCoords, Layout>
+ typedef typename internal::TensorBlockReader<ScalarNoConst, Index, NumCoords, Layout>
TensorBlockReader;
- typedef typename internal::TensorBlockWriter<
- typename internal::remove_const<Scalar>::type, Index, NumCoords, Layout>
+ typedef typename internal::TensorBlockWriter<ScalarNoConst, Index, NumCoords, Layout>
TensorBlockWriter;
//===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
typedef internal::TensorBlockDescriptor<NumCoords, Index> TensorBlockDesc;
+ typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
+
+ typedef typename internal::TensorMaterializedBlock<ScalarNoConst, NumCoords,
+ Layout, Index>
+ TensorBlockV2;
//===--------------------------------------------------------------------===//
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device)
@@ -161,6 +165,12 @@ struct TensorEvaluator
TensorBlockReader::Run(block, m_data);
}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2
+ blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch) const {
+ assert(m_data != NULL);
+ return TensorBlockV2::materialize(m_data, m_dims, desc, scratch);
+ }
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void writeBlock(
const TensorBlock& block) {
assert(m_data != NULL);
@@ -269,11 +279,6 @@ struct TensorEvaluator<const Derived, Device>
typedef internal::TensorBlockDescriptor<NumCoords, Index> TensorBlockDesc;
typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
- typedef internal::TensorBlockIOV2<ScalarNoConst, Index, NumCoords, Layout>
- TensorBlockIO;
- typedef typename TensorBlockIO::Dst TensorBlockIODst;
- typedef typename TensorBlockIO::Src TensorBlockIOSrc;
-
typedef typename internal::TensorMaterializedBlock<ScalarNoConst, NumCoords,
Layout, Index>
TensorBlockV2;