diff options
author | 2018-09-03 04:52:24 -0700 | |
---|---|---|
committer | 2018-09-03 04:57:16 -0700 | |
commit | 0f4ad7ff0e5ce38bc09ddd008e4f32d2af321495 (patch) | |
tree | cc76401cd5057d454fc0dc82178c3747a122cb3f /tensorflow/compiler/xla/service/hlo_instruction.cc | |
parent | 44c884dc5d02abc7c50abea24c8caee6dcadda9a (diff) |
Call Cudnn also for grouped convolutions.
Cudnn supports grouped convolutions, so we don't need the
ConvolutionFeatureGroupConverter pass and can instead set the group_count
parameter on the cudnn custom calls.
PiperOrigin-RevId: 211339551
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index bd0b6af10d..6d13f85cbb 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -385,6 +385,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( ->set_convolution_dimension_numbers( proto.convolution_dimension_numbers()); } + static_cast<HloCustomCallInstruction*>(instruction.get()) + ->set_feature_group_count( + std::max(static_cast<int64>(proto.feature_group_count()), 1LL)); break; case HloOpcode::kPad: TF_RET_CHECK(proto.operand_ids_size() == 2) @@ -3269,7 +3272,15 @@ void HloInstruction::set_convolution_dimension_numbers( } int64 HloInstruction::feature_group_count() const { - return Cast<HloConvolutionInstruction>(this)->feature_group_count(); + if (auto convolution = DynCast<HloConvolutionInstruction>(this)) { + return convolution->feature_group_count(); + } + return Cast<HloCustomCallInstruction>(this)->feature_group_count(); +} + +void HloInstruction::set_feature_group_count(int64 feature_group_count) { + Cast<HloCustomCallInstruction>(this)->set_feature_group_count( + feature_group_count); } HloComputation* HloInstruction::select() const { |