aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-07-05 07:37:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-05 07:41:40 -0700
commitdd9549709072a489254a72d190d9e36632d58d95 (patch)
tree892187f1cac14ff4da5d6b0191bcf9ef3a4a184b /tensorflow/compiler/xla/service/shape_inference.cc
parentaf23ae65db2585f4a18d0bc5f21f15e94805aa4f (diff)
[XLA] Check for the correct number of strides in shape inference.
Add backward compatibility for serialized computations without strides. PiperOrigin-RevId: 160956181
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc27
1 files changed, 14 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 5e4df9ddd6..f02df232d8 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1149,6 +1149,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
starts.size(), limits.size());
}
+ if (starts.size() != strides.size()) {
+ return InvalidArgument("slice start and strides sizes differ: %zu vs %zu",
+ starts.size(), strides.size());
+ }
+
if (starts.size() != ShapeUtil::Rank(arg)) {
return InvalidArgument(
"slice index count does not match argument rank: %zu vs %lld",
@@ -1164,9 +1169,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
return InvalidArgument("negative start index to slice: %lld",
start_index);
}
- if (stride == 0) {
- return InvalidArgument("Zero stride");
- }
if (limit_index > arg.dimensions(dimension)) {
return InvalidArgument(
"limit index (%lld) must be less than or equal to dimension "
@@ -1177,17 +1179,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
start_index);
VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension,
limit_index);
- if (stride > 0) {
- if (start_index > limit_index) {
- return InvalidArgument(
- "limit index (%lld) must be greater or equal to "
- "start index (%lld) in slice with positive stride",
- limit_index, start_index);
- }
- sizes.push_back((limit_index - start_index + stride - 1) / stride);
- } else {
- return InvalidArgument("Negative strides not supported");
+ if (start_index > limit_index) {
+ return InvalidArgument(
+ "limit index (%lld) must be greater or equal to "
+ "start index (%lld) in slice with positive stride",
+ limit_index, start_index);
+ }
+ if (stride <= 0) {
+ return InvalidArgument("stride (%lld) must be positive", stride);
}
+ sizes.push_back((limit_index - start_index + stride - 1) / stride);
}
return ShapeUtil::MakeShape(arg.element_type(), sizes);