aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc28
1 files changed, 16 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index b332709995..5e4df9ddd6 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1135,7 +1135,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> limits) {
+ tensorflow::gtl::ArraySlice<int64> limits,
+ tensorflow::gtl::ArraySlice<int64> strides) {
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice"));
VLOG(2) << tensorflow::strings::Printf(
"slicing shape %s starts={%s} limits={%s}",
@@ -1158,13 +1159,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
for (int64 dimension = 0; dimension < starts.size(); ++dimension) {
int64 start_index = starts[dimension];
int64 limit_index = limits[dimension];
+ int64 stride = strides[dimension];
if (start_index < 0) {
return InvalidArgument("negative start index to slice: %lld",
start_index);
}
- if (limit_index < 0) {
- return InvalidArgument("negative limit index to slice: %lld",
- limit_index);
+ if (stride == 0) {
+ return InvalidArgument("Zero stride");
}
if (limit_index > arg.dimensions(dimension)) {
return InvalidArgument(
@@ -1172,18 +1173,21 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
"size (%lld)",
limit_index, arg.dimensions(dimension));
}
- if (start_index > limit_index) {
- return InvalidArgument(
- "limit index (%lld) must be greater or equal to "
- "start index (%lld) in slice",
- limit_index, start_index);
- }
VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension,
start_index);
VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension,
limit_index);
-
- sizes.push_back(limits[dimension] - starts[dimension]);
+ if (stride > 0) {
+ if (start_index > limit_index) {
+ return InvalidArgument(
+ "limit index (%lld) must be greater or equal to "
+ "start index (%lld) in slice with positive stride",
+ limit_index, start_index);
+ }
+ sizes.push_back((limit_index - start_index + stride - 1) / stride);
+ } else {
+ return InvalidArgument("Negative strides not supported");
+ }
}
return ShapeUtil::MakeShape(arg.element_type(), sizes);