diff options
author | 2018-09-05 01:00:25 -0700 | |
---|---|---|
committer | 2018-09-05 01:03:45 -0700 | |
commit | c7bd1589d08e84ca215b3c8c4dc3023986522ef7 (patch) | |
tree | 175a6f18ada9b911a482700e592728246db25ece /tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | |
parent | 568b763776b7890570d9f6ab9568153329079958 (diff) |
Add support for grouped convolutions to the HloEvaluator.
Add a missing check to InferConvolveShape(), the output feature dimension needs to be divisible by feature_group_count.
Also fix some tests which took a const reference to the return value of
a function which doesn't return a reference.
PiperOrigin-RevId: 211592011
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 36 |
1 files changed, 34 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index dc16a84246..6a09bb08f4 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1047,9 +1047,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto lhs_literal_data = lhs_literal.data<ReturnT>(); auto rhs_literal_data = rhs_literal.data<ReturnT>(); + int64 feature_group_count = conv->feature_group_count(); + auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, - rhs_literal_data](absl::Span<const int64> out_index) { + rhs_literal_data, + feature_group_count](absl::Span<const int64> out_index) { // Dimension number applicable for input (lhs). const int64 input_batch_dim = dnums.input_batch_dimension(); const int64 input_z_dim = dnums.input_feature_dimension(); @@ -1061,6 +1064,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 output_z_dim = dnums.output_feature_dimension(); const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + const int64 output_z_size = + ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim); ElementwiseT result_val = static_cast<ElementwiseT>(0); DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), @@ -1069,6 +1074,33 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Convolve input feature with kernel. do { for (int64 iz = 0; iz < z_size; ++iz) { + int64 rhs_iz = iz; + // Handle grouped convolutions. + if (feature_group_count > 1) { + // The size of a feature group. + int64 feature_group_size = z_size / feature_group_count; + rhs_iz = iz % feature_group_size; + + // The output feature dimension is a concatenation of convolution + // results from the different groups. + int64 output_feature_group_size = + output_z_size / feature_group_count; + + // Calculate the group index to which the current input feature + // index belongs. + int64 input_group_index = iz / feature_group_size; + + // Calculate the group index to which the current output index + // belongs. + int64 output_group_index = + out_index[output_z_dim] / output_feature_group_size; + if (input_group_index != output_group_index) { + // If the current output index does not belong to the current + // feature group, skip it. + continue; + } + } + int64 lhs_linear_index = 0; lhs_linear_index += out_index[output_batch_dim] * lhs_dim_multipliers[input_batch_dim]; @@ -1077,7 +1109,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { int64 rhs_linear_index = 0; rhs_linear_index += out_index[output_z_dim] * rhs_dim_multipliers[kernel_output_z_dim]; - rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim]; + rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim]; // Find corresponding spatial dimension index for input (lhs). for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { |