aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc28
1 files changed, 12 insertions, 16 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 5e4df9ddd6..b332709995 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1135,8 +1135,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> limits,
- tensorflow::gtl::ArraySlice<int64> strides) {
+ tensorflow::gtl::ArraySlice<int64> limits) {
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice"));
VLOG(2) << tensorflow::strings::Printf(
"slicing shape %s starts={%s} limits={%s}",
@@ -1159,13 +1158,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
for (int64 dimension = 0; dimension < starts.size(); ++dimension) {
int64 start_index = starts[dimension];
int64 limit_index = limits[dimension];
- int64 stride = strides[dimension];
if (start_index < 0) {
return InvalidArgument("negative start index to slice: %lld",
start_index);
}
- if (stride == 0) {
- return InvalidArgument("Zero stride");
+ if (limit_index < 0) {
+ return InvalidArgument("negative limit index to slice: %lld",
+ limit_index);
}
if (limit_index > arg.dimensions(dimension)) {
return InvalidArgument(
@@ -1173,21 +1172,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
"size (%lld)",
limit_index, arg.dimensions(dimension));
}
+ if (start_index > limit_index) {
+ return InvalidArgument(
+ "limit index (%lld) must be greater or equal to "
+ "start index (%lld) in slice",
+ limit_index, start_index);
+ }
VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension,
start_index);
VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension,
limit_index);
- if (stride > 0) {
- if (start_index > limit_index) {
- return InvalidArgument(
- "limit index (%lld) must be greater or equal to "
- "start index (%lld) in slice with positive stride",
- limit_index, start_index);
- }
- sizes.push_back((limit_index - start_index + stride - 1) / stride);
- } else {
- return InvalidArgument("Negative strides not supported");
- }
+
+ sizes.push_back(limits[dimension] - starts[dimension]);
}
return ShapeUtil::MakeShape(arg.element_type(), sizes);