diff options
author | 2018-09-04 11:17:30 -0700 | |
---|---|---|
committer | 2018-09-04 11:41:30 -0700 | |
commit | 5d183ab7fc7b82f1dea0b9fa9c6412c39ade15a1 (patch) | |
tree | 79a4f6fcf270617fc56082702b0209240425ae8c /tensorflow/compiler/xla/service/cpu | |
parent | 9ae8214229960c634c9f82c00f2c0df287c27a9d (diff) |
[XLA] Make kConvolution, kDot HLO attributes mandatory
HLO transformations would sometimes forget to propagate the feature-depth (feature_group_count) and precision-config attributes.
Making these attributes mandatory, while slightly less convenient for tests,
makes HLO transformations more robust.
PiperOrigin-RevId: 211490160
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu')
3 files changed, 19 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 098ce17a56..2d9978404c 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -130,9 +130,9 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) { // change the dimension mapping but not the dimension sizes. For // example, input height and width are the same as before the reshapes. HloInstruction* new_conv = module->entry_computation()->AddInstruction( - HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel, - hlo->window(), new_dnums)); - new_conv->set_precision_config(hlo->precision_config()); + HloInstruction::CreateConvolve( + new_conv_shape, new_input, new_kernel, hlo->feature_group_count(), + hlo->window(), new_dnums, hlo->precision_config())); // Reshape the output back to the shape of the original convolution. TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 547d4c696d..616c453750 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -56,6 +56,13 @@ class ConvCanonicalizationTest : public HloTestBase { static constexpr int kOutputFeatureCount = 64; }; +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} + TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { auto builder = HloComputation::Builder(TestName()); // The input dimensions are in CNHW order. 
@@ -84,7 +91,8 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}), - input, kernel, conv_window_, dnums)); + input, kernel, /*feature_group_count=*/1, conv_window_, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -146,7 +154,8 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}), - input, kernel, conv_window_, dnums)); + input, kernel, /*feature_group_count=*/1, conv_window_, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 284929ca07..6bd0a2dd90 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -38,7 +38,11 @@ std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs, DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfigProto::DEFAULT); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, + precision_config); } TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { |