diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index b2d12c94b8..a450dc6ff5 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -2613,8 +2613,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector<int64> base_index(rank); bool out_of_bound = false; for (int64 i = 0; i < rank; ++i) { - base_index[i] = window_count_index[i] * window.dimensions(i).stride() + - window_index[i] - window.dimensions(i).padding_low(); + base_index[i] = + window_count_index[i] * window.dimensions(i).stride() + + window_index[i] * window.dimensions(i).window_dilation() - + window.dimensions(i).padding_low(); + // We are not in the base area if the dilation placed us out of bounds. + if (base_index[i] % window.dimensions(i).base_dilation() != 0) { + out_of_bound = true; + break; + } + // Apply the dilation to the base area. + base_index[i] /= window.dimensions(i).base_dilation(); if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { out_of_bound = true; break; |