about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Eugene Zhulenev <ezhulenev@google.com> 2019-06-28 15:50:39 -0700
committer: Eugene Zhulenev <ezhulenev@google.com> 2019-06-28 15:50:39 -0700
commit: 81a03bec75aac90aa343fccf6a7daf735e28c20d (patch)
tree: 487d4d27c6cd796b888246400c91540d87ba58dc
parent: 8053eeb51e1735f3956f49555ac3901388c2ccca (diff)
Fix TensorReverse on GPU with m_strides[i]==0
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h 10
1 file changed, 5 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h b/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h
index 6faff87bf..42205db31 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h
@@ -122,6 +122,8 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
RawAccess = false
};
+ typedef internal::TensorIntDivisor<Index> IndexDivisor;
+
typedef typename internal::remove_const<Scalar>::type ScalarNoConst;
typedef internal::TensorBlock<ScalarNoConst, Index, NumDims, Layout>
OutputTensorBlock;
@@ -141,17 +143,15 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
m_strides[0] = 1;
for (int i = 1; i < NumDims; ++i) {
m_strides[i] = m_strides[i-1] * m_dimensions[i-1];
+ if (m_strides[i] > 0) m_fastStrides[i] = IndexDivisor(m_strides[i]);
}
} else {
m_strides[NumDims-1] = 1;
for (int i = NumDims - 2; i >= 0; --i) {
m_strides[i] = m_strides[i+1] * m_dimensions[i+1];
+ if (m_strides[i] > 0) m_fastStrides[i] = IndexDivisor(m_strides[i]);
}
}
- // Remember the strides for fast division.
- for (int i = 0; i < NumDims; ++i) {
- m_fastStrides[i] = internal::TensorIntDivisor<Index>(m_strides[i]);
- }
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -377,7 +377,7 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
protected:
Dimensions m_dimensions;
array<Index, NumDims> m_strides;
- array<internal::TensorIntDivisor<Index>, NumDims> m_fastStrides;
+ array<IndexDivisor, NumDims> m_fastStrides;
TensorEvaluator<ArgType, Device> m_impl;
ReverseDimensions m_reverse;
const Device& m_device;