about | summary | refs | log | tree | commit | diff | homepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
diff options
context:
space:
mode:
author: Eugene Zhulenev <ezhulenev@google.com> 2018-09-27 14:49:26 -0700
committer: Eugene Zhulenev <ezhulenev@google.com> 2018-09-27 14:49:26 -0700
commit: e95696acb313a84b33a18cc300de418b05dc58e5 (patch)
tree: 5aabd1314d1af823115a5d9b3bbc4d49172b36f4 /unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
parent: 9f33e71e9d33b51735841e40dfa49bda9d7fe5ff (diff)
Optimize TensorBlockCopyOp
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h 86
1 file changed, 75 insertions, 11 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
index 558130300..35523ec73 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
@@ -144,24 +144,88 @@ class TensorBlock {
template <typename Scalar, typename StorageIndex>
struct TensorBlockCopyOp {
+
+ typedef typename packet_traits<Scalar>::type Packet;
+ enum {
+ Vectorizable = internal::packet_traits<Scalar>::Vectorizable,
+ PacketSize = internal::packet_traits<Scalar>::size
+ };
+
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void Run(
const StorageIndex num_coeff_to_copy, const StorageIndex dst_index,
const StorageIndex dst_stride, Scalar* EIGEN_RESTRICT dst_data,
const StorageIndex src_index, const StorageIndex src_stride,
const Scalar* EIGEN_RESTRICT src_data) {
- const Scalar* src_base = &src_data[src_index];
- Scalar* dst_base = &dst_data[dst_index];
-
- typedef const Array<Scalar, Dynamic, 1> Src;
- typedef Array<Scalar, Dynamic, 1> Dst;
+ const Scalar* src = &src_data[src_index];
+ Scalar* dst = &dst_data[dst_index];
- typedef Map<Src, 0, InnerStride<> > SrcMap;
- typedef Map<Dst, 0, InnerStride<> > DstMap;
-
- const SrcMap src(src_base, num_coeff_to_copy, InnerStride<>(src_stride));
- DstMap dst(dst_base, num_coeff_to_copy, InnerStride<>(dst_stride));
+ if (!Vectorizable) {
+ for (Index i = 0; i < num_coeff_to_copy; ++i) {
+ dst[i * dst_stride] = src[i * src_stride];
+ }
+ return;
+ }
- dst = src;
+ if (src_stride == 1) {
+ const Index vectorized_size = (num_coeff_to_copy / PacketSize) * PacketSize;
+ if (dst_stride == 1) {
+ // LINEAR
+ for (Index i = 0; i < vectorized_size; i += PacketSize) {
+ Packet p = internal::ploadu<Packet>(src + i);
+ internal::pstoreu<Scalar, Packet>(dst + i, p);
+ }
+ for (Index i = vectorized_size; i < num_coeff_to_copy; ++i) {
+ dst[i] = src[i];
+ }
+ } else {
+ // SCATTER
+ for (Index i = 0; i < vectorized_size; i += PacketSize) {
+ Packet p = internal::ploadu<Packet>(src + i);
+ internal::pscatter<Scalar, Packet>(dst + i * dst_stride, p, dst_stride);
+ }
+ for (Index i = vectorized_size; i < num_coeff_to_copy; ++i) {
+ dst[i * dst_stride] = src[i];
+ }
+ }
+ } else if (src_stride == 0) {
+ const Index vectorized_size = (num_coeff_to_copy / PacketSize) * PacketSize;
+ if (dst_stride == 1) {
+ // LINEAR
+ for (Index i = 0; i < vectorized_size; i += PacketSize) {
+ Packet p = internal::pload1<Packet>(src);
+ internal::pstoreu<Scalar, Packet>(dst + i, p);
+ }
+ for (Index i = vectorized_size; i < num_coeff_to_copy; ++i) {
+ dst[i] = *src;
+ }
+ } else {
+ // SCATTER
+ for (Index i = 0; i < vectorized_size; i += PacketSize) {
+ Packet p = internal::pload1<Packet>(src);
+ internal::pscatter<Scalar, Packet>(dst + i * dst_stride, p, dst_stride);
+ }
+ for (Index i = vectorized_size; i < num_coeff_to_copy; ++i) {
+ dst[i * dst_stride] = *src;
+ }
+ }
+ } else {
+ if (dst_stride == 1) {
+ // GATHER
+ const Index vectorized_size = (num_coeff_to_copy / PacketSize) * PacketSize;
+ for (Index i = 0; i < vectorized_size; i += PacketSize) {
+ Packet p = internal::pgather<Scalar, Packet>(src + i * src_stride, src_stride);
+ internal::pstoreu<Scalar, Packet>(dst + i, p);
+ }
+ for (Index i = vectorized_size; i < num_coeff_to_copy; ++i) {
+ dst[i] = src[i * src_stride];
+ }
+ } else {
+ // RANDOM
+ for (Index i = 0; i < num_coeff_to_copy; ++i) {
+ dst[i * dst_stride] = src[i * src_stride];
+ }
+ }
+ }
}
};