about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 47
1 files changed, 18 insertions, 29 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 6a09bb08f4..63303aef1e 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1052,7 +1052,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window,
&lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data,
rhs_literal_data,
- feature_group_count](absl::Span<const int64> out_index) {
+ feature_group_count](const absl::Span<const int64> out_index) {
// Dimension number applicable for input (lhs).
const int64 input_batch_dim = dnums.input_batch_dimension();
const int64 input_z_dim = dnums.input_feature_dimension();
@@ -1063,9 +1063,22 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const int64 output_batch_dim = dnums.output_batch_dimension();
const int64 output_z_dim = dnums.output_feature_dimension();
- const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+ const int64 input_z_size =
+ ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+ // The size of an input feature group.
+ const int64 input_feature_group_size = input_z_size / feature_group_count;
+
const int64 output_z_size =
ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim);
+ // The output feature dimension is a concatenation of convolution results
+ // from the different groups.
+ const int64 output_feature_group_size =
+ output_z_size / feature_group_count;
+
+ // Calculate the group index to which the current output index
+ // belongs.
+ const int64 feature_group_index =
+ out_index[output_z_dim] / output_feature_group_size;
ElementwiseT result_val = static_cast<ElementwiseT>(0);
DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
@@ -1073,33 +1086,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Convolve input feature with kernel.
do {
- for (int64 iz = 0; iz < z_size; ++iz) {
- int64 rhs_iz = iz;
- // Handle grouped convolutions.
- if (feature_group_count > 1) {
- // The size of a feature group.
- int64 feature_group_size = z_size / feature_group_count;
- rhs_iz = iz % feature_group_size;
-
- // The output feature dimension is a concatenation of convolution
- // results from the different groups.
- int64 output_feature_group_size =
- output_z_size / feature_group_count;
-
- // Calculate the group index to which the current input feature
- // index belongs.
- int64 input_group_index = iz / feature_group_size;
-
- // Calculate the group index to which the current output index
- // belongs.
- int64 output_group_index =
- out_index[output_z_dim] / output_feature_group_size;
- if (input_group_index != output_group_index) {
- // If the current output index does not belong to the current
- // feature group, skip it.
- continue;
- }
- }
+ for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) {
+ const int64 iz =
+ feature_group_index * input_feature_group_size + rhs_iz;
int64 lhs_linear_index = 0;
lhs_linear_index += out_index[output_batch_dim] *