32 files changed, 387 insertions, 177 deletions
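The change threads a new layout-sensitivity flag through HloVerifier, ShapeVerifier, and the test bases. A minimal sketch of the call-site migration, using only the constructor declared in hlo_verifier.h below (variable names are illustrative):

    // Before: a single flag; layout was always ignored.
    //   HloVerifier verifier(/*allow_mixed_precision=*/true);
    // After: callers state explicitly whether layouts must match exactly.
    xla::HloVerifier lax_verifier(/*layout_sensitive=*/false,
                                  /*allow_mixed_precision=*/true);
    xla::HloVerifier strict_verifier(/*layout_sensitive=*/true,
                                     /*allow_mixed_precision=*/false);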
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index ebf21ac151..bb63ea26d4 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -52,7 +52,12 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } -class AlgebraicSimplifierTest : public HloVerifiedTestBase {}; +class AlgebraicSimplifierTest : public HloVerifiedTestBase { + public: + AlgebraicSimplifierTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; // Test that A + 0 is simplified to A TEST_F(AlgebraicSimplifierTest, AddZero) { @@ -2851,7 +2856,12 @@ struct DotOfConcatTestSpec { class DotOfConcatSimplificationTest : public HloVerifiedTestBase, - public ::testing::WithParamInterface<DotOfConcatTestSpec> {}; + public ::testing::WithParamInterface<DotOfConcatTestSpec> { + public: + DotOfConcatSimplificationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; // Test that we transform // dot(const, concat(A, B, C)) @@ -3024,7 +3034,12 @@ struct DotOfGatherTestSpec { class DotOfGatherSimplificationTest : public HloVerifiedTestBase, - public ::testing::WithParamInterface<DotOfGatherTestSpec> {}; + public ::testing::WithParamInterface<DotOfGatherTestSpec> { + public: + DotOfGatherSimplificationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; // input: dot(DS(ctA), ctB)) // where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}. diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc index 38f1a5d3a6..b342acb025 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -24,7 +24,12 @@ namespace { namespace op = xla::testing::opcode_matchers; -class BatchDotSimplificationTest : public HloVerifiedTestBase {}; +class BatchDotSimplificationTest : public HloVerifiedTestBase { + public: + BatchDotSimplificationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; TEST_F(BatchDotSimplificationTest, ElideSingleDegenerateBatchDotDim_VectorVector) { diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 49ae5320b0..b08705d4c2 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -76,7 +76,8 @@ class BFloat16NormalizationTest : public HloTestBase { StatusOr<bool> result = normalization.Run(module); EXPECT_IS_OK(result.status()); - HloVerifier verifier(/*allow_mixed_precision=*/true); + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true); EXPECT_IS_OK(verifier.Run(module).status()); return result.ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index c43a31b167..6c477da038 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -39,6 +39,10 @@ namespace op = xla::testing::opcode_matchers; class 
ConditionalSimplifierTest : public HloVerifiedTestBase { public: + ConditionalSimplifierTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + // Makes a computation that contains a conditional with constant predicate. HloComputation* MakeConditional(HloModule* module); }; diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index f8500e78b6..e01fecffd0 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -88,6 +88,7 @@ cc_library( ":simple_orc_jit", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + ":target_machine_features", "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index ef71376cfb..279aa42fe2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -234,15 +234,15 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx_; const std::unordered_map<const HloInstruction*, int64>& assigned_indices_; }; -} // namespace -Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, - llvm::TargetMachine* target_machine) { - LLVMTargetMachineFeatures target_machine_features(target_machine); +} // namespace - // Optimization pipeline. - HloPassPipeline pipeline("CPU"); - pipeline.AddInvariantChecker<HloVerifier>(); +Status CpuCompiler::RunHloPassesThroughLayoutAssn( + HloModule* module, bool /*is_aot_compile*/, + LLVMTargetMachineFeatures* target_machine_features) { + HloPassPipeline pipeline("HLO passes through layout assignment"); + pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pipeline.AddPass<CpuHloSupportChecker>(); ReducePrecisionInsertion::AddPasses( @@ -259,11 +259,12 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pipeline.AddPass<BatchDotSimplification>(); pipeline.AddPass<DotDecomposer>(); pipeline.AddPass<ConvolutionFeatureGroupConverter>(); - pipeline.AddPass<ConvCanonicalization>(&target_machine_features); + pipeline.AddPass<ConvCanonicalization>(target_machine_features); { auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification"); - pass.AddInvariantChecker<HloVerifier>(); + pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pass.AddPass<BatchNormExpander>( /*rewrite_training_op=*/true, @@ -290,10 +291,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, } pipeline.AddPass<IndexedArrayAnalysisPrinterPass>(); pipeline.AddPass<TransposeFolding>( - [&target_machine_features]( - const HloInstruction& dot, + [&](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { - return PotentiallyImplementedAsEigenDot(dot, target_machine_features) + return PotentiallyImplementedAsEigenDot(dot, *target_machine_features) ? 
candidate_operands : TransposeFolding::OperandIndices{}; }, @@ -308,12 +308,28 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass<CpuLayoutAssignment>( - module->mutable_entry_computation_layout(), &target_machine_features); + module->mutable_entry_computation_layout(), target_machine_features); + return pipeline.Run(module).status(); +} + +Status CpuCompiler::RunHloPassesAfterLayoutAssn( + HloModule* module, bool is_aot_compile, + LLVMTargetMachineFeatures* target_machine_features) { + HloPassPipeline pipeline("HLO passes after layout assignment"); + // After layout assignment, use a layout-sensitive verifier. + auto& after_layout_assn = + pipeline.AddPass<HloPassPipeline>("after layout assignment"); + after_layout_assn.AddInvariantChecker<HloVerifier>( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); + // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. { auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>( - "after layout assignement"); + "simplification after layout assignment"); + pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); pass.AddPass<HloPassFix<AlgebraicSimplifier>>( /*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return true; }, @@ -321,7 +337,9 @@ pass.AddPass<HloDCE>(); pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true); } + pipeline.AddPass<HloElementTypeConverter>(BF16, F32); + // Outline ops in the entry computation into calls to subcomputations. const int max_parallelism = module->config().intra_op_parallelism_threads() > 0 @@ -334,14 +352,14 @@ // binary size (and most AOT applications are single-threaded). // TODO(b/29630486) Support multi-threaded AOT. pipeline.AddPass<ParallelTaskAssigner>( - max_parallelism, ShapeSizeBytesFunction(), &target_machine_features); + max_parallelism, ShapeSizeBytesFunction(), target_machine_features); } - // Copy insertion should be performed immediately before IR emission to avoid - // inserting unnecessary copies (later pass adds an instruction which - // materializes the value) or missing a necessary copy (later pass removes an - // instruction which materializes a value). DCE must be run immediately before - // (and sometime after) copy insertion, to avoid dead code from interfering - // with the rewrites. + // Copy insertion should be performed immediately before IR emission to + // avoid inserting unnecessary copies (later pass adds an instruction which + // materializes the value) or missing a necessary copy (later pass removes + // an instruction which materializes a value). DCE must be run immediately + // before (and sometime after) copy insertion, to avoid dead code from + // interfering with the rewrites. 
pipeline.AddPass<HloDCE>(); pipeline.AddPass<FlattenCallGraph>(); pipeline.AddPass<CpuCopyInsertion>(); @@ -349,6 +367,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, return pipeline.Run(module).status(); } +Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine) { + LLVMTargetMachineFeatures target_machine_features(target_machine); + TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile, + &target_machine_features)); + return RunHloPassesAfterLayoutAssn(module, is_aot_compile, + &target_machine_features); +} + namespace { // Align buffers to 16-byte boundaries. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 04e1c48872..47b5edabff 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" @@ -157,6 +158,16 @@ class CpuCompiler : public LLVMCompiler { Status RunHloPasses(HloModule* module, bool is_aot_compile, llvm::TargetMachine* target_machine); + // Runs HLO passes up to and including layout assignment. + Status RunHloPassesThroughLayoutAssn( + HloModule* module, bool /*is_aot_compile*/, + LLVMTargetMachineFeatures* target_machine_features); + + // Runs HLO passes after layout assignment. + Status RunHloPassesAfterLayoutAssn( + HloModule* module, bool is_aot_compile, + LLVMTargetMachineFeatures* target_machine_features); + TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index 82c276d7ef..a84ee78b19 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -35,7 +35,9 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; ParallelTaskAssignmentTest() - : target_machine_features_([](int64 shape_size) { + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false), + target_machine_features_([](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }) {} diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc index e727ba49cb..37d1895d41 100644 --- a/tensorflow/compiler/xla/service/defuser_test.cc +++ b/tensorflow/compiler/xla/service/defuser_test.cc @@ -26,6 +26,11 @@ namespace xla { namespace { class DefuserTest : public HloVerifiedTestBase { + public: + DefuserTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: // Returns the number of fusion instructions in the module. 
int FusionCount() { diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index d9369f00cc..136b8e19aa 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -141,7 +141,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, Compiler* compiler) { { HloPassPipeline pipeline("optimization"); - pipeline.AddInvariantChecker<HloVerifier>(); + pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pipeline.AddPass<GpuHloSupportChecker>(); ReducePrecisionInsertion::AddPasses( &pipeline, hlo_module->config().debug_options(), @@ -157,7 +158,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification"); - pass.AddInvariantChecker<HloVerifier>(); + pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls // where possible. Not every batchnorm op can be implemented as a call to @@ -204,7 +206,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // Convert convolutions into CustomCalls to cudnn, then canonicalize them // (PadInsertion). HloPassPipeline pipeline("conv_canonicalization"); - pipeline.AddInvariantChecker<HloVerifier>(); + pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); // TODO(b/31709653): Directly use the grouped convolution support of Cudnn. pipeline.AddPass<ConvolutionFeatureGroupConverter>(); pipeline.AddPass<CudnnConvolutionRewriter>(); @@ -219,9 +222,22 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, } { - HloPassPipeline pipeline("layout_assignment"); + // Run layout assignment in a separate pipeline from + // "post-layout_assignment" because we want everything after layout + // assignment to have a layout-sensitive invariant checker, but + // HloPassPipeline also runs its invariant checker before any passes are + // run, meaning the pipeline that contains layout assignment cannot + // contain a layout-sensitive verifier. + HloPassPipeline pipeline("layout assignment"); pipeline.AddPass<GpuLayoutAssignment>( hlo_module->mutable_entry_computation_layout(), stream_exec); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + + { + HloPassPipeline pipeline("post-layout_assignment"); + pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. 
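The ordering constraint described in the comment above is worth spelling out: an HloPassPipeline runs its invariant checkers once before the first pass and again after each pass. A simplified sketch of that control flow (an illustration only, not the actual HloPassPipeline::Run implementation; RunInvariantCheckers is a hypothetical helper name):

    Status HloPassPipeline::Run(HloModule* module) {
      // Checkers fire before any pass runs. A layout-sensitive verifier
      // here would reject a module whose layouts are not yet assigned,
      // which is why layout assignment gets a pipeline of its own.
      TF_RETURN_IF_ERROR(RunInvariantCheckers(module));
      for (auto& pass : passes_) {
        TF_RETURN_IF_ERROR(pass->Run(module).status());
        TF_RETURN_IF_ERROR(RunInvariantCheckers(module));
      }
      return Status::OK();
    }

Hence the split: the "layout assignment" pipeline carries no layout-sensitive checker, and the "post-layout_assignment" pipeline that follows installs one.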
@@ -267,7 +283,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { HloPassFix<HloPassPipeline> fusion("fusion"); - fusion.AddInvariantChecker<HloVerifier>(); + fusion.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false); fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true); fusion.AddPass<FusionMerger>(); @@ -277,7 +294,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); - reduce_pipeline.AddInvariantChecker<HloVerifier>(); + reduce_pipeline.AddInvariantChecker<HloVerifier>( + /*layout_sensitive=*/true, /*allow_mixed_precision=*/false); ReducePrecisionInsertion::AddPasses( &reduce_pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -303,7 +321,8 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); - pipeline.AddInvariantChecker<HloVerifier>(); + pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc index 99e7580b82..104af48c82 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc @@ -29,7 +29,12 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -using PadForTensorCoresTest = HloVerifiedTestBase; +class PadForTensorCoresTest : public HloVerifiedTestBase { + public: + PadForTensorCoresTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { ParseAndVerifyModule(R"( diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index cca35316f0..15d1e269cc 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -27,13 +27,22 @@ namespace { class GpuKernelTilingTest : public GpuCodegenTest { protected: - GpuKernelTilingTest() { + GpuKernelTilingTest() {} + + // Most tests in this file want to skip layout assignment, but a few need it + // enabled. + HloModuleConfig ConfigWithLayoutAssignment() { + return GetModuleConfigForTest(); + } + + HloModuleConfig ConfigWithoutLayoutAssignment() { + HloModuleConfig config; auto debug_options = HloTestBase::GetDebugOptionsForTest(); - config_.set_debug_options(debug_options); // Disable layout_assignment to use the preassigned layouts. 
- debug_options.add_xla_disable_hlo_passes("layout_assignment"); + debug_options.add_xla_disable_hlo_passes("layout-assignment"); + config.set_debug_options(debug_options); + return config; } - HloModuleConfig config_; }; TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { @@ -46,7 +55,13 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + // + // We must enable layout assignment in order for this test to work correctly. + // AlgebraicSimplifier removes copy1; it's added back by layout assignment, + // which respects the module's entry computation layout. But if we don't run + // layout assignment...well, nobody else adds the copy back. + auto hlo_module = + ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy @@ -68,8 +83,11 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) { ROOT copy1 = f16[2,3,64]{1,0,2} copy(para0) })"; - // Check that a call to llvm.nvvm.barrier0 is not generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + // Check that a call to llvm.nvvm.barrier0 is not generated. As in + // UnnestedTransposeWithProperDimensionsTiled, we must run layout assignment + // here. + auto hlo_module = + ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy @@ -95,7 +113,8 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion @@ -128,7 +147,8 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion @@ -162,7 +182,8 @@ TEST_F(GpuKernelTilingTest, })"; // Check that a call to llvm.nvvm.barrier0 is not generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc index 9622936306..0f2d5568ca 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -138,6 +138,9 @@ TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) { HloModuleConfig config; auto debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_max_kernel_unroll_factor(2); + // Disable layout assignment for this test. Layout assignment does not expect + // fusions to be present, and so it does the wrong thing. 
+ debug_options.add_xla_disable_hlo_passes("layout-assignment"); config.set_debug_options(debug_options); const char *const kMultiOutputFusionModule = R"( diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index c5f3906356..40183de96e 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -118,7 +118,8 @@ class WhileTransformerTest : public HloTestBase { } void RunCopyInsertionPass() { - HloVerifier verifier; + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); TF_ASSERT_OK(verifier.Run(module_.get()).status()); CopyInsertion copy_insertion; TF_ASSERT_OK(copy_insertion.Run(module_.get()).status()); diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 648228b825..79e78ee2d0 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -29,6 +29,11 @@ namespace xla { namespace { class HloDomainTest : public HloVerifiedTestBase { + public: + HloDomainTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: bool FindUserViaDomainPath(HloInstruction* instruction, HloInstruction* operand) const { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 139d70374f..7e85df53a3 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -52,7 +52,10 @@ static std::array<bool, 2> use_bf16_params{true, false}; class HloEvaluatorTest : public ::testing::WithParamInterface<bool>, public HloVerifiedTestBase { protected: - HloEvaluatorTest() : use_bfloat16_(GetParam()) { + HloEvaluatorTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false), + use_bfloat16_(GetParam()) { evaluator_ = absl::make_unique<HloEvaluator>(); } @@ -1216,7 +1219,12 @@ TEST_P(HloEvaluatorTest, EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } -class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; +class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase { + public: + HloEvaluatorPreciseReduceTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; // Tests that Reduce doesn't lose precision when adding many numbers (because // it accumulates its result in a double). 
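A recurring pattern in the test changes above is disabling layout assignment so that layouts written directly in the test HLO survive compilation. The string passed to add_xla_disable_hlo_passes must match the pass's registered name, which for GPU layout assignment is the hyphenated "layout-assignment"; that explains the "layout_assignment" to "layout-assignment" updates in these tests. The pattern, condensed from the gpu_kernel_tiling_test helper above:

    // Build a module config whose compilation skips layout assignment.
    HloModuleConfig config;
    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
    debug_options.add_xla_disable_hlo_passes("layout-assignment");
    config.set_debug_options(debug_options);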
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 66dd23e73f..f60c4eab42 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -123,29 +123,26 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } -namespace { - -Status CheckIsTokenOperand(const HloInstruction* instruction, - int64 operand_no) { +Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction, + int64 operand_no) { const HloInstruction* token = instruction->operand(operand_no); if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) { return InternalError( "Expected operand %lld to be token-shaped, actual shape is " "%s:\n%s", - operand_no, ShapeUtil::HumanString(token->shape()).c_str(), + operand_no, StringifyShape(token->shape()).c_str(), instruction->ToString().c_str()); } return Status::OK(); } -Status CheckOperandAndParameter(const HloInstruction* instruction, - int64 operand_number, - const HloComputation* computation, - int64 parameter_number) { +Status ShapeVerifier::CheckOperandAndParameter( + const HloInstruction* instruction, int64 operand_number, + const HloComputation* computation, int64 parameter_number) { const HloInstruction* operand = instruction->operand(operand_number); const HloInstruction* parameter = computation->parameter_instruction(parameter_number); - if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) { + if (!ShapesSame(operand->shape(), parameter->shape())) { return InternalError("Operand %s shape does not match parameter's %s in %s", operand->ToString().c_str(), parameter->ToString().c_str(), @@ -154,8 +151,6 @@ Status CheckOperandAndParameter(const HloInstruction* instruction, return Status::OK(); } -} // namespace - Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); @@ -172,13 +167,12 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { // Outfeed has a separate shape field for the value which is outfed to the // host. The shape of the instruction itself is always a token. - if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), - outfeed->operand(0)->shape())) { + if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { return InternalError( - "Expected outfeed shape to be compatible with operand's shape %s, " + "Expected outfeed shape to be equal to operand's shape %s, " "actual shape is %s:\n%s", - ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), - ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(), + StringifyShape(outfeed->operand(0)->shape()).c_str(), + StringifyShape(outfeed->outfeed_shape()).c_str(), outfeed->ToString().c_str()); } return CheckShape(outfeed, ShapeUtil::MakeTokenShape()); @@ -259,8 +253,8 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) { return InternalError( "Expected sort to have to have the same dimensions for the keys and " "the values. 
Keys shape is: %s\n, Values shape is: %s", - ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(), - ShapeUtil::HumanString(sort->operand(1)->shape()).c_str()); + StringifyShape(sort->operand(0)->shape()).c_str(), + StringifyShape(sort->operand(1)->shape()).c_str()); } return CheckVariadicShape(sort); } @@ -334,7 +328,18 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { return Status::OK(); } -Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); } +Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { + for (HloInstruction* fused_param : fusion->fused_parameters()) { + int64 param_no = fused_param->parameter_number(); + if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) { + return InternalError( + "Shape mismatch between parameter number %lld and its operand in " + "%s.", + param_no, fusion->ToString().c_str()); + } + } + return Status::OK(); +} Status ShapeVerifier::HandleCall(HloInstruction* call) { for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) { @@ -416,12 +421,11 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0)); const Shape& conditional_shape = xla_while->while_condition()->root_instruction()->shape(); - if (!ShapeUtil::Compatible(conditional_shape, - ShapeUtil::MakeShape(PRED, {}))) { + if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) { return InternalError( "Conditional computation shape does not lead to a scalar predicate " "shape: %s", - ShapeUtil::HumanString(conditional_shape).c_str()); + StringifyShape(conditional_shape).c_str()); } // The shape of kWhile should match the shape of the body computation it // calls. @@ -599,52 +603,51 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } // Check if the output shape matches the expected shape. - bool compatible; + // // We treat BF16 and F32 as compatible types if mixed precision is allowed, // but only when the instruction defines the BF16/F32 buffer. - switch (instruction->opcode()) { - case HloOpcode::kTupleSelect: - // TupleSelect only defines the top-level buffer, which in this case is - // the tuple, so we cannot allow mixed precision. - compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); - break; - case HloOpcode::kGetTupleElement: - case HloOpcode::kTuple: - // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed - // precision is disallowed. - case HloOpcode::kConstant: - case HloOpcode::kBitcast: - case HloOpcode::kBitcastConvert: - case HloOpcode::kCall: - case HloOpcode::kConditional: - case HloOpcode::kConvert: - case HloOpcode::kCustomCall: - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: - case HloOpcode::kParameter: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kWhile: - // The above opcodes should match the expected shapes exactly. - compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); - break; - default: - if (allow_mixed_precision_) { - compatible = ShapeUtil::CompatibleIgnoringFpPrecision( - instruction->shape(), inferred_shape); - } else { - compatible = - ShapeUtil::Compatible(instruction->shape(), inferred_shape); - } - } - if (!compatible) { + bool equal = [&] { + switch (instruction->opcode()) { + // The opcodes below can't have implicit layout conversions, nor can they + // implicitly transform f32 -> bf16. 
Fundamentally these are either + // reinterpreting existing data (e.g. kBitcast) or shuffling data around + // without modifying it (e.g. kGetTupleElement, kTupleSelect). + case HloOpcode::kBitcast: + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConstant: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: + case HloOpcode::kWhile: + return ShapesSame(instruction->shape(), inferred_shape); + + // We allow arbitrary layout and f32->bf16 transformations on all other + // instructions, although this may be made more strict pending discussion + // in b/112709536. + default: + if (allow_mixed_precision_) { + return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(), + inferred_shape); + } else { + return ShapeUtil::Compatible(instruction->shape(), inferred_shape); + } + } + }(); + if (!equal) { return InternalError( - "Expected instruction to have shape compatible with %s, actual " + "Expected instruction to have shape equal to %s, actual " "shape is %s:\n%s", - ShapeUtil::HumanString(inferred_shape).c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), + StringifyShape(inferred_shape).c_str(), + StringifyShape(instruction->shape()).c_str(), instruction->ToString().c_str()); } return Status::OK(); @@ -828,7 +831,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { } // Fused parameter instructions must be numbered contiguously and match up - // (shapes compatible) with their respective operand. + // (shapes equal) with their respective operand. CHECK_EQ(fusion->operands().size(), fused_parameters.size()); std::vector<bool> parameter_numbers(fused_parameters.size(), false); for (auto fused_param : fused_parameters) { @@ -849,13 +852,6 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { param_no, fusion->ToString().c_str()); } parameter_numbers[param_no] = true; - if (!ShapeUtil::Compatible(fused_param->shape(), - fusion->operand(param_no)->shape())) { - return InternalError( - "Shape mismatch between parameter number %lld and its operand in " - "%s.", - param_no, fusion->ToString().c_str()); - } } // Make sure all the parameter_numbers entries were seen. for (int i = 0; i < parameter_numbers.size(); i++) { @@ -917,7 +913,7 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) { return FailedPrecondition( "Implicit broadcast is not allowed in HLO." - "Found non-compatible shapes for instruction %s.\n" + "Found different shapes for instruction %s.\n" "output: %s\noperand: %s\n", HloOpcodeString(instruction->opcode()).c_str(), ShapeUtil::HumanString(out_shape).c_str(), diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 83b77c84eb..b6093d667c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -28,9 +28,9 @@ namespace xla { // TODO(b/26024837): Check output shape for all instruction types. 
class ShapeVerifier : public DfsHloVisitor { public: - explicit ShapeVerifier() : allow_mixed_precision_(false) {} - explicit ShapeVerifier(bool allow_mixed_precision) - : allow_mixed_precision_(allow_mixed_precision) {} + explicit ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision) + : layout_sensitive_(layout_sensitive), + allow_mixed_precision_(allow_mixed_precision) {} Status HandleElementwiseUnary(HloInstruction* hlo) override; Status HandleElementwiseBinary(HloInstruction* hlo) override; @@ -106,13 +106,42 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckVariadicShape(const HloInstruction* instruction); private: - // Return true if the shapes of the two operands have the same element type, - // and the result shape either has the same element type as the operand - // shapes or mixed precision is allowed and the result shape and the operand - // shapes have floating point element types. + // Helpers that switch on layout_sensitive_. + bool ShapesSame(const Shape& a, const Shape& b) { + return layout_sensitive_ ? ShapeUtil::Equal(a, b) + : ShapeUtil::Compatible(a, b); + } + bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) { + return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b) + : ShapeUtil::CompatibleIgnoringFpPrecision(a, b); + } + string StringifyShape(const Shape& s) { + return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s) + : ShapeUtil::HumanString(s); + } + + // Checks that the given operand of the given instruction is of type TOKEN. + Status CheckIsTokenOperand(const HloInstruction* instruction, + int64 operand_no); + + // Checks that the shape of the given operand of the given instruction matches + // the given parameter of the given computation. + Status CheckOperandAndParameter(const HloInstruction* instruction, + int64 operand_number, + const HloComputation* computation, + int64 parameter_number); + + // Returns true if the shapes of the two operands have the same element type, + // and the result shape either has the same element type as the operand shapes + // or mixed precision is allowed and the result shape and the operand shapes + // have floating point element types. bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1, const Shape& result_shape); + // If the verifier is layout-sensitive, shapes must be equal to what's + // expected. Otherwise, the shapes must simply be compatible. + bool layout_sensitive_; + // Whether the inputs and output of an instruction can contain both F32s and // BF16s. Tuples that include both F32s and BF16s are allowed regardless of // this flag. @@ -125,14 +154,10 @@ class HloVerifier : public HloPassInterface { public: using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>; - // Uses standard shape inference. - explicit HloVerifier() - : shape_verifier_factory_( - [] { return absl::make_unique<ShapeVerifier>(false); }) {} - - explicit HloVerifier(bool allow_mixed_precision) - : shape_verifier_factory_([allow_mixed_precision] { - return absl::make_unique<ShapeVerifier>(allow_mixed_precision); + explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision) + : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] { + return absl::make_unique<ShapeVerifier>(layout_sensitive, + allow_mixed_precision); }) {} // Uses custom shape verification. 
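The new ShapesSame and ShapesSameIgnoringFpPrecision helpers above are the heart of the mode switch: layout-sensitive verification compares shapes with ShapeUtil::Equal, which includes layouts, while the default mode uses ShapeUtil::Compatible, which ignores them. A small illustration (shape-construction calls as used elsewhere in this change; the CHECKs merely assert the expected results):

    // Same element type and dimensions, different layouts.
    Shape a = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
    Shape b = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1});
    CHECK(ShapeUtil::Compatible(a, b));   // layout-insensitive: same
    CHECK(!ShapeUtil::Equal(a, b));       // layout-sensitive: different

    // bf16 vs f32 passes only the *IgnoringFpPrecision variants.
    Shape c = ShapeUtil::MakeShapeWithLayout(BF16, {2, 3}, {1, 0});
    CHECK(!ShapeUtil::Compatible(a, c));
    CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(a, c));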
@@ -142,8 +167,7 @@ class HloVerifier : public HloPassInterface { ~HloVerifier() override = default; absl::string_view name() const override { return "verifier"; } - // Note: always returns false (no instructions are ever modified by this - // pass). + // Never returns true; no instructions are ever modified by this pass. StatusOr<bool> Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index d764964f3c..70b741353d 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -37,13 +37,15 @@ using ::testing::HasSubstr; class HloVerifierTest : public HloTestBase { public: HloVerifierTest() - : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/false) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/false) {} }; class HloVerifierTestAllowMixedPrecision : public HloTestBase { public: HloVerifierTestAllowMixedPrecision() - : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} }; TEST_F(HloVerifierTest, NullInstructionParent) { diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc index f85d31d522..df88587492 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc @@ -26,6 +26,11 @@ namespace xla { namespace { class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase { + public: + ImplicitBroadcastRemoverTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: ImplicitBroadcastRemover remover_; }; diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 2d03aebc1a..c34c32f7d3 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -22,6 +22,11 @@ limitations under the License. 
namespace xla { namespace { class IndexedArrayAnalysisTest : public HloVerifiedTestBase { + public: + IndexedArrayAnalysisTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: void AssertArrayForRootExpressionIs(const string& hlo_text, const string& root_expression) { diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index fab4e797f3..a395dd5333 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -34,7 +34,12 @@ namespace { namespace op = xla::testing::opcode_matchers; -using ReshapeMoverTest = HloVerifiedTestBase; +class ReshapeMoverTest : public HloVerifiedTestBase { + public: + ReshapeMoverTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { HloComputation::Builder builder(TestName()); diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 32e69c335b..e14014b961 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -28,6 +28,10 @@ namespace op = xla::testing::opcode_matchers; class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase { public: + WhileLoopInvariantCodeMotionTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + // Makes a computation which has one parameter, of the given shape, and always // returns PRED[]{true}. This is useful as a dummy loop condition. HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape, diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 1c892ba179..cfe4104f6d 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -28,6 +28,11 @@ namespace { namespace op = xla::testing::opcode_matchers; class WhileLoopSimplifierTest : public HloVerifiedTestBase { + public: + WhileLoopSimplifierTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: // Makes an HloModule that contains a loop with `num_iters` iteration. 
void MakeModuleWithSimpleLoop(int num_iters); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 693454bd80..93ea144438 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -86,16 +86,20 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace -HloTestBase::HloTestBase(bool allow_mixed_precision_in_hlo_verifier) +HloTestBase::HloTestBase(bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier) : HloTestBase(GetTestPlatform(), GetReferencePlatform(), + verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier) {} HloTestBase::HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, + bool verifier_layout_sensitive, bool allow_mixed_precision_in_hlo_verifier) : test_runner_(test_platform), reference_runner_(reference_platform) { - hlo_verifier_ = - absl::make_unique<HloVerifier>(allow_mixed_precision_in_hlo_verifier); + hlo_verifier_ = absl::make_unique<HloVerifier>( + /*layout_sensitive=*/verifier_layout_sensitive, + /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier); } std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index ce39dd78d4..06bcc39741 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -85,12 +85,14 @@ class HloTestBase : public ::testing::Test { // automatically finds another supported backend as the test backend. If the // interpreter is the only supported backend, it will be both the test backend // and the reference backend. - HloTestBase(bool allow_mixed_precision_in_hlo_verifier = true); + HloTestBase(bool verifier_layout_sensitive = false, + bool allow_mixed_precision_in_hlo_verifier = true); // If your test doesn't use interpreter as the reference backend, you can use // this constructor. Note that your test target is responsible for linking in // both needed backends. HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, + bool verifier_layout_sensitive = false, bool allow_mixed_precision_in_hlo_verifier = true); ~HloTestBase() override {} diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index dd130557b3..8f86c528d0 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -25,8 +25,11 @@ limitations under the License. 
namespace xla { -HloVerifiedTestBase::HloVerifiedTestBase() - : shape_verifier_(absl::make_unique<ShapeVerifier>()) {} +HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive, + bool allow_mixed_precision) + : HloTestBase( + /*verifier_layout_sensitive=*/layout_sensitive, + /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {} HloVerifiedTestBase::~HloVerifiedTestBase() { // We can't call the ASSERT or EXPECT test macros in destructors, so we @@ -51,8 +54,7 @@ void HloVerifiedTestBase::TearDown() { } void HloVerifiedTestBase::VerifyModule(HloModule* module) { - HloVerifier verifier(/*allow_mixed_precision=*/true); - xla::StatusOr<bool> mutated = verifier.Run(module); + xla::StatusOr<bool> mutated = verifier().Run(module); if (!mutated.ok()) { ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); } else { diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index a2f1185d63..cc6967feed 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -29,7 +29,8 @@ namespace xla { // performs verification on that module on tear-down. class HloVerifiedTestBase : public HloTestBase { protected: - HloVerifiedTestBase(); + explicit HloVerifiedTestBase(bool layout_sensitive, + bool allow_mixed_precision); ~HloVerifiedTestBase() override; // Constructs a default shape verifier. @@ -47,29 +48,25 @@ class HloVerifiedTestBase : public HloTestBase { void ParseAndVerifyModule(absl::string_view hlo_text, const HloModuleConfig& config = HloModuleConfig()); - // Sets the shape-size function used during hlo verification. If this isn't - // called, a default ShapeVerifier is used instead. - void SetShapeVerifier(std::unique_ptr<ShapeVerifier> shape_verifier) { - shape_verifier_ = std::move(shape_verifier); - } - // Creates a new module for a test, and stores it in modules_ so it can be // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent // creation of unverified modules. HloModule* CreateNewModule(const string& name = TestName()); + private: + void VerifyModule(HloModule* module); + // It is confusing to store modules created by module() and CreateNewModule() // in different fields, but it allows us to migrate tests to // HloVerifiedTestBase more easily, so it's a win because we can verify more // modules. See b/80488902. - private: + // // Lazily populated. Access via module(). std::unique_ptr<HloModule> module_; // Populated by calls to CreateNewModule. std::vector<std::unique_ptr<HloModule>> modules_; - std::unique_ptr<ShapeVerifier> shape_verifier_; + bool tear_down_called_ = false; - static void VerifyModule(HloModule* module); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index adc0956164..16b77e965d 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -53,12 +53,22 @@ class MultiOutputFusionTest : public HloTestBase { protected: MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; } + // Layout assignment assumes that there are no fusions in the input graph. + // Since the purpose of this test is to send pre-fused graphs to XLA, we have + // to do layout assignment ourselves. 
+ DebugOptions GetDebugOptionsForTest() override { + auto opts = HloTestBase::GetDebugOptionsForTest(); + opts.add_xla_disable_hlo_passes("layout-assignment"); + return opts; + } + void RunTest2D(bool manual_fusion, int64 size) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - const Shape elem_shape0 = ShapeUtil::MakeShape(F32, {}); - const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size}); + const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {}); + const Shape elem_shape2 = + ShapeUtil::MakeShapeWithLayout(F32, {size, size}, {1, 0}); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(8.0f))); @@ -101,10 +111,10 @@ class MultiOutputFusionTest : public HloTestBase { nullptr); } - Literal arg1(ShapeUtil::MakeShape(F32, {size, size})); + Literal arg1(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size})); arg1.PopulateWithValue<float>(2.5f); - Literal expect(ShapeUtil::MakeShape(F32, {size, size})); + Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size})); expect.PopulateWithValue<float>(size * 1.5f * 3.5f); auto actual = ExecuteAndTransfer(std::move(hlo_module), @@ -116,8 +126,10 @@ class MultiOutputFusionTest : public HloTestBase { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - const Shape elem_shape_F32 = ShapeUtil::MakeShape(F32, {size}); - const Shape elem_shape_U8 = ShapeUtil::MakeShape(F64, {size}); + const Shape elem_shape_F32 = + ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}); + const Shape elem_shape_U8 = + ShapeUtil::MakeShapeWithDescendingLayout(F64, {size}); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, elem_shape_F32, "0")); auto param1 = builder.AddInstruction( @@ -137,12 +149,13 @@ class MultiOutputFusionTest : public HloTestBase { HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(F32, {size, 1}), add)); + ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, 1}), add)); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( - ShapeUtil::MakeShape(F32, {1}), sub, reshape, dot_dnums)); + ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape, + dot_dnums)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -162,9 +175,9 @@ class MultiOutputFusionTest : public HloTestBase { nullptr); } - Literal input0(ShapeUtil::MakeShape(F32, {size})); + Literal input0(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size})); input0.PopulateWithValue(2.5f); - Literal input1(ShapeUtil::MakeShape(F64, {size})); + Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size})); input1.PopulateWithValue(1.); Literal expect = diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 2f1d97b25d..21c58e075e 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -408,8 +408,12 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( return std::move(arguments); } -Status VerifyHloModule(HloModule* const module, bool allow_mixed_precision) { - return HloVerifier(allow_mixed_precision).Run(module).status(); +Status VerifyHloModule(HloModule* const module, bool layout_sensitive, + bool allow_mixed_precision) { + return 
HloVerifier(/*layout_sensitive=*/layout_sensitive, + /*allow_mixed_precision=*/allow_mixed_precision) + .Run(module) + .status(); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 1aca1d8ef7..277d53d423 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -95,8 +95,8 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( // Check that a given module satisfies various constraints before trying to // execute it. -Status VerifyHloModule(HloModule* const module, - bool allow_mixed_precision = false); +Status VerifyHloModule(HloModule* const module, bool layout_sensitive, + bool allow_mixed_precision); } // namespace xla diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index 870da6efa8..c7eb9e2dbe 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -66,7 +66,10 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42))); module->AddEntryComputation(builder.Build()); - Status status = HloVerifier().Run(module.get()).status(); + Status status = + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status(); ASSERT_IS_NOT_OK(status); EXPECT_THAT( status.error_message(), @@ -83,7 +86,10 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { "param")); module->AddEntryComputation(builder.Build()); - Status status = HloVerifier().Run(module.get()).status(); + Status status = + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status(); ASSERT_IS_NOT_OK(status); EXPECT_THAT( status.error_message(), @@ -100,7 +106,10 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(123))); module->AddEntryComputation(builder.Build()); - Status status = HloVerifier().Run(module.get()).status(); + Status status = + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status(); ASSERT_IS_NOT_OK(status); EXPECT_THAT(status.error_message(), ::testing::HasSubstr(