diff options
author | 2017-07-05 07:37:35 -0700 | |
---|---|---|
committer | 2017-07-05 07:41:40 -0700 | |
commit | dd9549709072a489254a72d190d9e36632d58d95 (patch) | |
tree | 892187f1cac14ff4da5d6b0191bcf9ef3a4a184b /tensorflow/compiler/xla/service/shape_inference.cc | |
parent | af23ae65db2585f4a18d0bc5f21f15e94805aa4f (diff) |
[XLA] Check for the correct number of strides in shape inference.
Add backward compatibility for serialized computations without strides.
PiperOrigin-RevId: 160956181
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 27 |
1 files changed, 14 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 5e4df9ddd6..f02df232d8 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1149,6 +1149,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( starts.size(), limits.size()); } + if (starts.size() != strides.size()) { + return InvalidArgument("slice start and strides sizes differ: %zu vs %zu", + starts.size(), strides.size()); + } + if (starts.size() != ShapeUtil::Rank(arg)) { return InvalidArgument( "slice index count does not match argument rank: %zu vs %lld", @@ -1164,9 +1169,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument("negative start index to slice: %lld", start_index); } - if (stride == 0) { - return InvalidArgument("Zero stride"); - } if (limit_index > arg.dimensions(dimension)) { return InvalidArgument( "limit index (%lld) must be less than or equal to dimension " @@ -1177,17 +1179,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( start_index); VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension, limit_index); - if (stride > 0) { - if (start_index > limit_index) { - return InvalidArgument( - "limit index (%lld) must be greater or equal to " - "start index (%lld) in slice with positive stride", - limit_index, start_index); - } - sizes.push_back((limit_index - start_index + stride - 1) / stride); - } else { - return InvalidArgument("Negative strides not supported"); + if (start_index > limit_index) { + return InvalidArgument( + "limit index (%lld) must be greater or equal to " + "start index (%lld) in slice with positive stride", + limit_index, start_index); + } + if (stride <= 0) { + return InvalidArgument("stride (%lld) must be positive", stride); } + sizes.push_back((limit_index - start_index + stride - 1) / stride); } return ShapeUtil::MakeShape(arg.element_type(), sizes); |