diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 47 |
1 file changed, 18 insertions, 29 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 6a09bb08f4..63303aef1e 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1052,7 +1052,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, rhs_literal_data, - feature_group_count](absl::Span<const int64> out_index) { + feature_group_count](const absl::Span<const int64> out_index) { // Dimension number applicable for input (lhs). const int64 input_batch_dim = dnums.input_batch_dimension(); const int64 input_z_dim = dnums.input_feature_dimension(); @@ -1063,9 +1063,22 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 output_batch_dim = dnums.output_batch_dimension(); const int64 output_z_dim = dnums.output_feature_dimension(); - const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + const int64 input_z_size = + ShapeUtil::GetDimension(lhs_shape, input_z_dim); + // The size of an input feature group. + const int64 input_feature_group_size = input_z_size / feature_group_count; + const int64 output_z_size = ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim); + // The output feature dimension is a concatenation of convolution results + // from the different groups. + const int64 output_feature_group_size = + output_z_size / feature_group_count; + + // Calculate the group index to which the current output index + // belongs. 
+ const int64 feature_group_index = + out_index[output_z_dim] / output_feature_group_size; ElementwiseT result_val = static_cast<ElementwiseT>(0); DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), @@ -1073,33 +1086,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Convolve input feature with kernel. do { - for (int64 iz = 0; iz < z_size; ++iz) { - int64 rhs_iz = iz; - // Handle grouped convolutions. - if (feature_group_count > 1) { - // The size of a feature group. - int64 feature_group_size = z_size / feature_group_count; - rhs_iz = iz % feature_group_size; - - // The output feature dimension is a concatenation of convolution - // results from the different groups. - int64 output_feature_group_size = - output_z_size / feature_group_count; - - // Calculate the group index to which the current input feature - // index belongs. - int64 input_group_index = iz / feature_group_size; - - // Calculate the group index to which the current output index - // belongs. - int64 output_group_index = - out_index[output_z_dim] / output_feature_group_size; - if (input_group_index != output_group_index) { - // If the current output index does not belong to the current - // feature group, skip it. - continue; - } - } + for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) { + const int64 iz = + feature_group_index * input_feature_group_size + rhs_iz; int64 lhs_linear_index = 0; lhs_linear_index += out_index[output_batch_dim] * |