aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-10-08 15:00:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 15:04:52 -0700
commit220c0f90af05ed1ca86831258888cc80757654fd (patch)
tree864eebf69d5f500150fe0df9bcde97320588a806 /tensorflow/compiler
parentb3bd7b378d00190fef831092836a5df62e39e7ed (diff)
[XLA] Simplify loop nesting in HandleConvolution
The calculation of a spatial coordinate in the kernel and activations does not depend on which part of the contracted dimension (input feature) we are in. Rather than nesting the loops, the loops can be siblings: - One loop over spatial dimensions - One loop over the input feature group This reduces the nesting depth, which makes the code a little more readable and might be slightly faster due to work invariant in the spatial loop getting hoisted out. PiperOrigin-RevId: 216255839
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h96
1 files changed, 48 insertions, 48 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index a450dc6ff5..84fbbd3e0c 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1072,66 +1072,66 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Convolve input feature with kernel.
do {
+ // Find corresponding spatial dimension index for input (lhs).
+ int64 lhs_linear_spatial_index = 0;
+ int64 rhs_linear_spatial_index = 0;
+ for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
+ // Spatial dimension number for input (lhs) and output.
+ const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki);
+ const int64 output_spatial_dim = dnums.output_spatial_dimensions(ki);
+
+ // Calculate lhs (input) index without taking base dilation into
+ // account.
+ const auto& window_dim = window.dimensions(ki);
+ const int64 undilated_index =
+ out_index[output_spatial_dim] * window_dim.stride() -
+ window_dim.padding_low() +
+ rhs_spatial_index[ki] * window_dim.window_dilation();
+ // Skip if the lhs (input) index is to be dilated. As an
+ // optimization, skip this mod if there's no dilation.
+ if (window_dim.base_dilation() > 1 &&
+ undilated_index % window_dim.base_dilation() != 0) {
+ goto cnt;
+ }
+
+ // Calculate the actual lhs (input) index after dilation. As an
+ // optimization, skip this integer divide if there's no dilation.
+ int64 lhs_spatial_index;
+ if (window_dim.base_dilation() > 1) {
+ lhs_spatial_index = undilated_index / window_dim.base_dilation();
+ } else {
+ lhs_spatial_index = undilated_index;
+ }
+
+ // Skip if input index is not in bounds.
+ if (!(lhs_spatial_index >= 0 &&
+ lhs_spatial_index < lhs_shape.dimensions(input_spatial_dim))) {
+ goto cnt;
+ }
+
+ lhs_linear_spatial_index +=
+ lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim];
+ rhs_linear_spatial_index +=
+ (window_dim.window_reversal()
+ ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
+ : rhs_spatial_index[ki]) *
+ rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)];
+ }
+
for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) {
const int64 iz =
feature_group_index * input_feature_group_size + rhs_iz;
- int64 lhs_linear_index = 0;
+ int64 lhs_linear_index = lhs_linear_spatial_index;
lhs_linear_index += out_index[output_batch_dim] *
lhs_dim_multipliers[input_batch_dim];
lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim];
- int64 rhs_linear_index = 0;
+ int64 rhs_linear_index = rhs_linear_spatial_index;
rhs_linear_index += out_index[output_z_dim] *
rhs_dim_multipliers[kernel_output_z_dim];
rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim];
- // Find corresponding spatial dimension index for input (lhs).
- for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
- // Spatial dimension number for input (lhs) and output.
- const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki);
- const int64 output_spatial_dim =
- dnums.output_spatial_dimensions(ki);
-
- // Calculate lhs (input) index without taking base dilation into
- // account.
- const auto& window_dim = window.dimensions(ki);
- const int64 undilated_index =
- out_index[output_spatial_dim] * window_dim.stride() -
- window_dim.padding_low() +
- rhs_spatial_index[ki] * window_dim.window_dilation();
- // Skip if the lhs (input) index is to be dilated. As an
- // optimization, skip this mod if there's no dilation.
- if (window_dim.base_dilation() > 1 &&
- undilated_index % window_dim.base_dilation() != 0) {
- goto cnt;
- }
-
- // Calculate the actual lhs (input) index after dilation. As an
- // optimization, skip this integer divide if there's no dilation.
- int64 lhs_spatial_index;
- if (window_dim.base_dilation() > 1) {
- lhs_spatial_index = undilated_index / window_dim.base_dilation();
- } else {
- lhs_spatial_index = undilated_index;
- }
- lhs_linear_index +=
- lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim];
-
- // Skip if input index is not in bounds.
- if (!(lhs_spatial_index >= 0 &&
- lhs_spatial_index <
- lhs_shape.dimensions(input_spatial_dim))) {
- goto cnt;
- }
-
- rhs_linear_index +=
- (window_dim.window_reversal()
- ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
- : rhs_spatial_index[ki]) *
- rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)];
- }
-
result_val +=
static_cast<ElementwiseT>(lhs_literal_data[lhs_linear_index]) *
static_cast<ElementwiseT>(rhs_literal_data[rhs_linear_index]);