aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-09-06 19:55:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 20:00:51 -0700
commit1cc48be8da90c2d5d3a2ebdf6ed46be623fa0c03 (patch)
treebd81ccee5d4722fb98e279af262ba6ae67789fe8 /tensorflow/compiler/xla/service/hlo_cost_analysis.cc
parentb0cd701121d63cacec498c2b097b0489fd529068 (diff)
[XLA] Add support for convolution feature groups to HloCostAnalysis
While there, tweak the implementation of convolution in the HLO evaluator to be a little simpler. PiperOrigin-RevId: 211911253
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_cost_analysis.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 8b4eaad82e..a502fff9a0 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -515,8 +515,9 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
valid_position_counts.push_back(valid_position_count);
}
- const int64 fma_count =
- input_feature * output_feature * batch * Product(valid_position_counts);
+ const int64 fma_count = (input_feature / convolution->feature_group_count()) *
+ output_feature * batch *
+ Product(valid_position_counts);
current_properties_[kFlopsKey] = fma_count * kFmaFlops;
return Status::OK();
}