aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2018-09-28 14:34:42 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2018-09-28 14:34:42 -0700
commitbb13d5d917d8b4fadec24ab0f3465bbad0e6635f (patch)
tree543bdd70bc9f62be5cb299fb8e6a48036242eb91 /unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
parent104e8fa0747c4b53e0fbc4aacdd5de54cc861192 (diff)
Fix bug in copy optimization in Tensor slicing.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h53
1 files changed, 30 insertions, 23 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
index d19aba3b3..213379dbd 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
@@ -979,44 +979,51 @@ struct TensorEvaluator<const TensorStridingSlicingOp<StartIndices, StopIndices,
};
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
- : m_impl(op.expression(), device), m_device(device), m_strides(op.strides()), m_exprStartIndices(op.startIndices()), m_exprStopIndices(op.stopIndices())
+ : m_impl(op.expression(), device),
+ m_device(device),
+ m_strides(op.strides()), m_exprStartIndices(op.startIndices()),
+ m_exprStopIndices(op.stopIndices())
{
// Handle degenerate intervals by gracefully clamping and allowing m_dimensions to be zero
- DSizes<Index,NumDims> startIndicesClamped, stopIndicesClamped;
- m_is_identity = true;
- for (Index i = 0; i < internal::array_size<Dimensions>::value; ++i) {
- if (m_strides[i] != 1 || op.startIndices()[i] != 0 ||
- op.stopIndices()[i] != (m_impl.dimensions()[i] - 1)) {
- m_is_identity = false;
- }
-
+ DSizes<Index, NumDims> startIndicesClamped, stopIndicesClamped;
+ for (ptrdiff_t i = 0; i < internal::array_size<Dimensions>::value; ++i) {
eigen_assert(m_strides[i] != 0 && "0 stride is invalid");
- if(m_strides[i]>0){
- startIndicesClamped[i] = clamp(op.startIndices()[i], 0, m_impl.dimensions()[i]);
- stopIndicesClamped[i] = clamp(op.stopIndices()[i], 0, m_impl.dimensions()[i]);
- }else{
- /* implies m_strides[i]<0 by assert */
- startIndicesClamped[i] = clamp(op.startIndices()[i], -1, m_impl.dimensions()[i] - 1);
- stopIndicesClamped[i] = clamp(op.stopIndices()[i], -1, m_impl.dimensions()[i] - 1);
+ if (m_strides[i] > 0) {
+ startIndicesClamped[i] =
+ clamp(op.startIndices()[i], 0, m_impl.dimensions()[i]);
+ stopIndicesClamped[i] =
+ clamp(op.stopIndices()[i], 0, m_impl.dimensions()[i]);
+ } else {
+ /* implies m_strides[i] < 0 by assert */
+ startIndicesClamped[i] =
+ clamp(op.startIndices()[i], -1, m_impl.dimensions()[i] - 1);
+ stopIndicesClamped[i] =
+ clamp(op.stopIndices()[i], -1, m_impl.dimensions()[i] - 1);
}
m_startIndices[i] = startIndicesClamped[i];
}
- const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
+ typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
+ const InputDimensions& input_dims = m_impl.dimensions();
// check for degenerate intervals and compute output tensor shape
- bool degenerate = false;;
- for(int i = 0; i < NumDims; i++){
+ bool degenerate = false;
+ m_is_identity = true;
+ for (int i = 0; i < NumDims; i++) {
Index interval = stopIndicesClamped[i] - startIndicesClamped[i];
- if(interval == 0 || ((interval<0) != (m_strides[i]<0))){
+ if (interval == 0 || ((interval < 0) != (m_strides[i] < 0))) {
m_dimensions[i] = 0;
degenerate = true;
- }else{
- m_dimensions[i] = interval / m_strides[i]
- + (interval % m_strides[i] != 0 ? 1 : 0);
+ } else {
+ m_dimensions[i] =
+ (interval / m_strides[i]) + (interval % m_strides[i] != 0 ? 1 : 0);
eigen_assert(m_dimensions[i] >= 0);
}
+ if (m_strides[i] != 1 || interval != m_impl.dimensions()[i]) {
+ m_is_identity = false;
+ }
}
+
Strides output_dims = m_dimensions;
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {