diff options
author | 2018-09-03 04:52:24 -0700 | |
---|---|---|
committer | 2018-09-03 04:57:16 -0700 | |
commit | 0f4ad7ff0e5ce38bc09ddd008e4f32d2af321495 (patch) | |
tree | cc76401cd5057d454fc0dc82178c3747a122cb3f /tensorflow/compiler/xla/service/hlo_instructions.h | |
parent | 44c884dc5d02abc7c50abea24c8caee6dcadda9a (diff) |
Call Cudnn also for grouped convolutions.
Cudnn supports grouped convolutions, so we don't need the
ConvolutionFeatureGroupConverter pass and can instead set the group_count
parameter on the cudnn custom calls.
PiperOrigin-RevId: 211339551
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.h | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 45a648bbe4..3230383579 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1079,6 +1079,10 @@ class HloCustomCallInstruction : public HloInstruction { absl::make_unique<ConvolutionDimensionNumbers>(dnums); } const string& custom_call_target() const { return custom_call_target_; } + void set_feature_group_count(int64 feature_group_count) { + feature_group_count_ = feature_group_count; + } + int64 feature_group_count() const { return feature_group_count_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1099,6 +1103,8 @@ class HloCustomCallInstruction : public HloInstruction { std::unique_ptr<Window> window_; // Describes the dimension numbers used for a convolution. std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_; + // The number of feature groups. This is used for grouped convolutions. + int64 feature_group_count_; }; class HloPadInstruction : public HloInstruction { |