diff options
author | Adrian Kuegel <akuegel@google.com> | 2018-09-03 04:52:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-03 04:57:16 -0700 |
commit | 0f4ad7ff0e5ce38bc09ddd008e4f32d2af321495 (patch) | |
tree | cc76401cd5057d454fc0dc82178c3747a122cb3f /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | 44c884dc5d02abc7c50abea24c8caee6dcadda9a (diff) |
Call Cudnn also for grouped convolutions.
Cudnn supports grouped convolutions, so we don't need the
ConvolutionFeatureGroupConverter pass and can instead set the group_count
parameter on the cudnn custom calls.
PiperOrigin-RevId: 211339551
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 6871953755..e46afa764f 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1660,6 +1660,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const { *proto.mutable_window() = window_; *proto.mutable_convolution_dimension_numbers() = convolution_dimension_numbers_; + proto.set_feature_group_count(feature_group_count_); return proto; } @@ -1681,6 +1682,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath( eq_computations) const { const auto& casted_other = static_cast<const HloConvolutionInstruction&>(other); + if (feature_group_count_ != other.feature_group_count()) { + return false; + } return protobuf_util::ProtobufEquals(window(), casted_other.window()) && protobuf_util::ProtobufEquals( convolution_dimension_numbers(), @@ -1793,8 +1797,8 @@ HloCustomCallInstruction::HloCustomCallInstruction( const Shape& shape, absl::Span<HloInstruction* const> operands, absl::string_view custom_call_target) : HloInstruction(HloOpcode::kCustomCall, shape), - custom_call_target_(custom_call_target.begin(), - custom_call_target.end()) { + custom_call_target_(custom_call_target.begin(), custom_call_target.end()), + feature_group_count_(1) { for (auto operand : operands) { AppendOperand(operand); } @@ -1810,6 +1814,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { *convolution_dimension_numbers_; } proto.set_custom_call_target(custom_call_target_); + proto.set_feature_group_count(feature_group_count_); return proto; } @@ -1824,6 +1829,9 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl( "dim_labels=", ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); } + if (feature_group_count_ != 1) { + extra.push_back(StrCat("feature_group_count=", feature_group_count_)); + } // By contract, we print the custom call target even if // options.print_subcomputation_mode() == kOff, because the call target is not // an HloComputation. @@ -1851,6 +1859,9 @@ bool HloCustomCallInstruction::IdenticalSlowPath( casted_other.convolution_dimension_numbers()))) { return false; } + if (feature_group_count_ != casted_other.feature_group_count_) { + return false; + } return custom_call_target_ == casted_other.custom_call_target_; } @@ -1866,6 +1877,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( if (convolution_dimension_numbers_ != nullptr) { cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_); } + cloned->set_feature_group_count(feature_group_count_); return std::move(cloned); } |