aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h96
1 files changed, 48 insertions, 48 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index a450dc6ff5..84fbbd3e0c 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1072,66 +1072,66 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Convolve input feature with kernel.
do {
+ // Find corresponding spatial dimension index for input (lhs).
+ int64 lhs_linear_spatial_index = 0;
+ int64 rhs_linear_spatial_index = 0;
+ for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
+ // Spatial dimension number for input (lhs) and output.
+ const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki);
+ const int64 output_spatial_dim = dnums.output_spatial_dimensions(ki);
+
+ // Calculate lhs (input) index without taking base dilation into
+ // account.
+ const auto& window_dim = window.dimensions(ki);
+ const int64 undilated_index =
+ out_index[output_spatial_dim] * window_dim.stride() -
+ window_dim.padding_low() +
+ rhs_spatial_index[ki] * window_dim.window_dilation();
+ // Skip if the lhs (input) index is to be dilated. As an
+ // optimization, skip this mod if there's no dilation.
+ if (window_dim.base_dilation() > 1 &&
+ undilated_index % window_dim.base_dilation() != 0) {
+ goto cnt;
+ }
+
+ // Calculate the actual lhs (input) index after dilation. As an
+ // optimization, skip this integer divide if there's no dilation.
+ int64 lhs_spatial_index;
+ if (window_dim.base_dilation() > 1) {
+ lhs_spatial_index = undilated_index / window_dim.base_dilation();
+ } else {
+ lhs_spatial_index = undilated_index;
+ }
+
+ // Skip if input index is not in bounds.
+ if (!(lhs_spatial_index >= 0 &&
+ lhs_spatial_index < lhs_shape.dimensions(input_spatial_dim))) {
+ goto cnt;
+ }
+
+ lhs_linear_spatial_index +=
+ lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim];
+ rhs_linear_spatial_index +=
+ (window_dim.window_reversal()
+ ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
+ : rhs_spatial_index[ki]) *
+ rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)];
+ }
+
for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) {
const int64 iz =
feature_group_index * input_feature_group_size + rhs_iz;
- int64 lhs_linear_index = 0;
+ int64 lhs_linear_index = lhs_linear_spatial_index;
lhs_linear_index += out_index[output_batch_dim] *
lhs_dim_multipliers[input_batch_dim];
lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim];
- int64 rhs_linear_index = 0;
+ int64 rhs_linear_index = rhs_linear_spatial_index;
rhs_linear_index += out_index[output_z_dim] *
rhs_dim_multipliers[kernel_output_z_dim];
rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim];
- // Find corresponding spatial dimension index for input (lhs).
- for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
- // Spatial dimension number for input (lhs) and output.
- const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki);
- const int64 output_spatial_dim =
- dnums.output_spatial_dimensions(ki);
-
- // Calculate lhs (input) index without taking base dilation into
- // account.
- const auto& window_dim = window.dimensions(ki);
- const int64 undilated_index =
- out_index[output_spatial_dim] * window_dim.stride() -
- window_dim.padding_low() +
- rhs_spatial_index[ki] * window_dim.window_dilation();
- // Skip if the lhs (input) index is to be dilated. As an
- // optimization, skip this mod if there's no dilation.
- if (window_dim.base_dilation() > 1 &&
- undilated_index % window_dim.base_dilation() != 0) {
- goto cnt;
- }
-
- // Calculate the actual lhs (input) index after dilation. As an
- // optimization, skip this integer divide if there's no dilation.
- int64 lhs_spatial_index;
- if (window_dim.base_dilation() > 1) {
- lhs_spatial_index = undilated_index / window_dim.base_dilation();
- } else {
- lhs_spatial_index = undilated_index;
- }
- lhs_linear_index +=
- lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim];
-
- // Skip if input index is not in bounds.
- if (!(lhs_spatial_index >= 0 &&
- lhs_spatial_index <
- lhs_shape.dimensions(input_spatial_dim))) {
- goto cnt;
- }
-
- rhs_linear_index +=
- (window_dim.window_reversal()
- ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
- : rhs_spatial_index[ki]) *
- rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)];
- }
-
result_val +=
static_cast<ElementwiseT>(lhs_literal_data[lhs_linear_index]) *
static_cast<ElementwiseT>(rhs_literal_data[rhs_linear_index]);