aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-09-03 04:52:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-03 04:57:16 -0700
commit0f4ad7ff0e5ce38bc09ddd008e4f32d2af321495 (patch)
treecc76401cd5057d454fc0dc82178c3747a122cb3f /tensorflow/compiler/xla/service/hlo_instructions.cc
parent44c884dc5d02abc7c50abea24c8caee6dcadda9a (diff)
Call cuDNN also for grouped convolutions.
cuDNN supports grouped convolutions, so we don't need the ConvolutionFeatureGroupConverter pass and can instead set the group_count parameter on the cuDNN custom calls.
PiperOrigin-RevId: 211339551
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instructions.cc | 16
1 files changed, 14 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 6871953755..e46afa764f 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1660,6 +1660,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const {
*proto.mutable_window() = window_;
*proto.mutable_convolution_dimension_numbers() =
convolution_dimension_numbers_;
+ proto.set_feature_group_count(feature_group_count_);
return proto;
}
@@ -1681,6 +1682,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath(
eq_computations) const {
const auto& casted_other =
static_cast<const HloConvolutionInstruction&>(other);
+ if (feature_group_count_ != other.feature_group_count()) {
+ return false;
+ }
return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
protobuf_util::ProtobufEquals(
convolution_dimension_numbers(),
@@ -1793,8 +1797,8 @@ HloCustomCallInstruction::HloCustomCallInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target)
: HloInstruction(HloOpcode::kCustomCall, shape),
- custom_call_target_(custom_call_target.begin(),
- custom_call_target.end()) {
+ custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
+ feature_group_count_(1) {
for (auto operand : operands) {
AppendOperand(operand);
}
@@ -1810,6 +1814,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
*convolution_dimension_numbers_;
}
proto.set_custom_call_target(custom_call_target_);
+ proto.set_feature_group_count(feature_group_count_);
return proto;
}
@@ -1824,6 +1829,9 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
"dim_labels=",
ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
}
+ if (feature_group_count_ != 1) {
+ extra.push_back(StrCat("feature_group_count=", feature_group_count_));
+ }
// By contract, we print the custom call target even if
// options.print_subcomputation_mode() == kOff, because the call target is not
// an HloComputation.
@@ -1851,6 +1859,9 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
casted_other.convolution_dimension_numbers()))) {
return false;
}
+ if (feature_group_count_ != casted_other.feature_group_count_) {
+ return false;
+ }
return custom_call_target_ == casted_other.custom_call_target_;
}
@@ -1866,6 +1877,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
if (convolution_dimension_numbers_ != nullptr) {
cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
}
+ cloned->set_feature_group_count(feature_group_count_);
return std::move(cloned);
}