diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 28 |
1 files changed, 12 insertions, 16 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 5e4df9ddd6..b332709995 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1135,8 +1135,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr<Shape> ShapeInference::InferSliceShape( const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts, - tensorflow::gtl::ArraySlice<int64> limits, - tensorflow::gtl::ArraySlice<int64> strides) { + tensorflow::gtl::ArraySlice<int64> limits) { TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice")); VLOG(2) << tensorflow::strings::Printf( "slicing shape %s starts={%s} limits={%s}", @@ -1159,13 +1158,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (int64 dimension = 0; dimension < starts.size(); ++dimension) { int64 start_index = starts[dimension]; int64 limit_index = limits[dimension]; - int64 stride = strides[dimension]; if (start_index < 0) { return InvalidArgument("negative start index to slice: %lld", start_index); } - if (stride == 0) { - return InvalidArgument("Zero stride"); + if (limit_index < 0) { + return InvalidArgument("negative limit index to slice: %lld", + limit_index); } if (limit_index > arg.dimensions(dimension)) { return InvalidArgument( @@ -1173,21 +1172,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "size (%lld)", limit_index, arg.dimensions(dimension)); } + if (start_index > limit_index) { + return InvalidArgument( + "limit index (%lld) must be greater or equal to " + "start index (%lld) in slice", + limit_index, start_index); + } VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension, start_index); VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension, limit_index); - if (stride > 0) { - if (start_index > limit_index) { - return InvalidArgument( - "limit index (%lld) must be greater or equal to " - "start index (%lld) in slice with positive stride", - limit_index, start_index); - } - sizes.push_back((limit_index - start_index + stride - 1) / stride); - } else { - return InvalidArgument("Negative strides not supported"); - } + + sizes.push_back(limits[dimension] - starts[dimension]); } return ShapeUtil::MakeShape(arg.element_type(), sizes); |