aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.h
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-09-03 04:52:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-03 04:57:16 -0700
commit0f4ad7ff0e5ce38bc09ddd008e4f32d2af321495 (patch)
treecc76401cd5057d454fc0dc82178c3747a122cb3f /tensorflow/compiler/xla/service/hlo_instructions.h
parent44c884dc5d02abc7c50abea24c8caee6dcadda9a (diff)
Call Cudnn also for grouped convolutions.
Cudnn supports grouped convolutions, so we don't need the ConvolutionFeatureGroupConverter pass and can instead set the group_count parameter on the cudnn custom calls. PiperOrigin-RevId: 211339551
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h6
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 45a648bbe4..3230383579 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1079,6 +1079,10 @@ class HloCustomCallInstruction : public HloInstruction {
absl::make_unique<ConvolutionDimensionNumbers>(dnums);
}
const string& custom_call_target() const { return custom_call_target_; }
+ void set_feature_group_count(int64 feature_group_count) {
+ feature_group_count_ = feature_group_count;
+ }
+ int64 feature_group_count() const { return feature_group_count_; }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -1099,6 +1103,8 @@ class HloCustomCallInstruction : public HloInstruction {
std::unique_ptr<Window> window_;
// Describes the dimension numbers used for a convolution.
std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
+ // The number of feature groups. This is used for grouped convolutions.
+ int64 feature_group_count_;
};
class HloPadInstruction : public HloInstruction {