From 03b4161326897453fa6b2803b873954607f7623b Mon Sep 17 00:00:00 2001
From: Bixia Zheng
Date: Fri, 5 Oct 2018 11:49:19 -0700
Subject: [XLA] Extend the HLO verifier to check that non-layout-changing
 instructions preserve operand layouts.

Add a std::function member to the HloVerifier so that a backend can specify
the function object used to determine whether an instruction can change
layouts.

Use the function object to identify non-layout-changing instructions and
check that such instructions produce results with the same layouts as their
operands.

Add test cases.

PiperOrigin-RevId: 215941282
---
 tensorflow/compiler/xla/service/BUILD              |  1 +
 .../compiler/xla/service/cpu/cpu_compiler.cc       |  9 ++-
 .../compiler/xla/service/gpu/nvptx_compiler.cc     | 21 ++++---
 tensorflow/compiler/xla/service/hlo_verifier.cc    | 34 ++++++++++-
 tensorflow/compiler/xla/service/hlo_verifier.h     | 14 ++++-
 .../compiler/xla/service/hlo_verifier_test.cc      | 67 ++++++++++++++++++++++
 tensorflow/compiler/xla/tests/hlo_test_base.cc     | 14 +++--
 tensorflow/compiler/xla/tests/hlo_test_base.h      |  8 ++-
 8 files changed, 149 insertions(+), 19 deletions(-)

diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 4797cf3330..2b292ed053 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2450,6 +2450,7 @@ tf_cc_test(
         ":hlo",
         ":hlo_parser",
         ":hlo_verifier",
+        ":layout_assignment",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:types",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 5834f67285..68c715a086 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -327,8 +327,13 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn(
   {
     auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
         "simplification after layout assignement");
-    pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
-                                          /*allow_mixed_precision=*/false);
+    // TODO(b/117156505): When the bug is fixed, the CPU backend should not
+    // produce layout changing elementwise operations. We will then pass
+    // LayoutAssignment::InstructionCanChangeLayout to the HLO verifier to
+    // enable stricter verification.
+    pass.AddInvariantChecker<HloVerifier>(
+        /*layout_sensitive=*/true,
+        /*allow_mixed_precision=*/false);
     pass.AddPass<HloPassFix<AlgebraicSimplifier>>(
         /*is_layout_sensitive=*/true,
         [](const Shape&, const Shape&) { return true; },
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 50e47542c4..ac6c2c5565 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -239,8 +239,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
 
   {
     HloPassPipeline pipeline("post-layout_assignment");
-    pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
-                                              /*allow_mixed_precision=*/false);
+    pipeline.AddInvariantChecker<HloVerifier>(
+        /*layout_sensitive=*/true,
+        /*allow_mixed_precision=*/false,
+        LayoutAssignment::InstructionCanChangeLayout);
 
     // The LayoutAssignment pass may leave behind kCopy instructions which are
     // duplicate or NOPs, so remove them with algebraic simplification and CSE.
@@ -286,8 +288,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
 
   {
     HloPassFix<HloPassPipeline> fusion("fusion");
-    fusion.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
-                                            /*allow_mixed_precision=*/false);
+    fusion.AddInvariantChecker<HloVerifier>(
+        /*layout_sensitive=*/true,
+        /*allow_mixed_precision=*/false,
+        LayoutAssignment::InstructionCanChangeLayout);
     fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
     fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
     fusion.AddPass<FusionMerger>();
@@ -299,7 +303,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
 
     HloPassPipeline reduce_pipeline("reduce-precision");
     reduce_pipeline.AddInvariantChecker<HloVerifier>(
-        /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false);
+        /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false,
+        LayoutAssignment::InstructionCanChangeLayout);
     ReducePrecisionInsertion::AddPasses(
         &reduce_pipeline, hlo_module->config().debug_options(),
         ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
@@ -325,8 +330,10 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
   // (b/27180329). Therefore, in that case, we set the output to be a copy of
   // the parameter.
   HloPassPipeline pipeline("GPU-ir-emit-prepare");
-  pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
-                                            /*allow_mixed_precision=*/false);
+  pipeline.AddInvariantChecker<HloVerifier>(
+      /*layout_sensitive=*/true,
+      /*allow_mixed_precision=*/false,
+      LayoutAssignment::InstructionCanChangeLayout);
 
   // Copy insertion should be performed immediately before IR emission to avoid
   // inserting unnecessary copies (later pass adds an instruction which
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index c22ee03388..fad3b14ec2 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1042,7 +1042,10 @@ Status CheckElementwiseInstruction(HloInstruction* instruction) {
 // not check result shape as that is checked in the ShapeVerifier.
 class InstructionVerifier : public DfsHloVisitorWithDefault {
  public:
-  InstructionVerifier() {}
+  explicit InstructionVerifier(std::function<bool(const HloInstruction*)>
+                                   instruction_can_change_layout_func)
+      : instruction_can_change_layout_func_(
+            instruction_can_change_layout_func) {}
 
   Status DefaultAction(HloInstruction*) override { return Status::OK(); }
 
@@ -1143,8 +1146,34 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
     return Status::OK();
   }
 
+  Status Postprocess(HloInstruction* instruction) override {
+    if (instruction_can_change_layout_func_ &&
+        LayoutUtil::IsDenseArray(instruction->shape()) &&
+        !instruction_can_change_layout_func_(instruction)) {
+      const Shape& result_shape = instruction->shape();
+      const Layout& result_layout = result_shape.layout();
+      for (HloInstruction* operand : instruction->operands()) {
+        const Shape& operand_shape = operand->shape();
+        if (LayoutUtil::IsDenseArray(operand_shape) &&
+            ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) {
+          const Layout& operand_layout = operand_shape.layout();
+          TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout))
+              << "Instruction shouldn't change layouts "
+              << instruction->ToString() << " From "
+              << ShapeUtil::HumanString(result_shape) << " To "
+              << ShapeUtil::HumanString(operand_shape);
+        }
+      }
+    }
+
+    return Status::OK();
+  }
+
  private:
   absl::flat_hash_map<string, const HloInstruction*> instructions_by_name_;
+  // Determines whether an instruction can change layouts.
+  std::function<bool(const HloInstruction*)>
+      instruction_can_change_layout_func_;
 };
 
 }  // namespace
@@ -1158,7 +1187,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
     std::unique_ptr<ShapeVerifier> shape_verifier = shape_verifier_factory_();
     TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
 
-    InstructionVerifier instruction_verifier;
+    InstructionVerifier instruction_verifier(
+        instruction_can_change_layout_func_);
     TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
   }
 
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 6d16586c2c..cb49cb95ba 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -155,11 +155,17 @@ class HloVerifier : public HloModulePass {
  public:
   using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
 
-  explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision)
+  explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision,
+                       std::function<bool(const HloInstruction*)>
+                           instruction_can_change_layout_func = {})
       : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] {
           return absl::make_unique<ShapeVerifier>(layout_sensitive,
                                                   allow_mixed_precision);
-        }) {}
+        }),
+        instruction_can_change_layout_func_(
+            std::move(instruction_can_change_layout_func)) {
+    CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive);
+  }
 
   // Uses custom shape verification.
   explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory)
@@ -177,6 +183,10 @@ class HloVerifier : public HloModulePass {
   // being a DfsHloVisitor, is stateful. We want a clean object
   // for each run of the verifier.
   ShapeVerifierFactory shape_verifier_factory_;
+
+  // Determines whether an instruction can change layouts.
+  std::function<bool(const HloInstruction*)>
+      instruction_can_change_layout_func_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 8f0423bb1c..afe01e5487 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/layout_assignment.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -50,6 +51,14 @@ class HloVerifierTestAllowMixedPrecision : public HloTestBase {
             /*allow_mixed_precision_in_hlo_verifier=*/true) {}
 };
 
+class HloVerifierTestLayoutSensitive : public HloTestBase {
+ public:
+  HloVerifierTestLayoutSensitive()
+      : HloTestBase(/*verifier_layout_sensitive=*/true,
+                    /*allow_mixed_precision_in_hlo_verifier=*/false,
+                    LayoutAssignment::InstructionCanChangeLayout) {}
+};
+
 TEST_F(HloVerifierTest, NullInstructionParent) {
   HloComputation::Builder builder(TestName());
   const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -358,5 +367,63 @@ TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
               HasSubstr("non-positive base area dilation factor"));
 }
 
+static const char* const kAddWithLayoutChangeHlo = R"(
+  HloModule AddWithLayoutChange
+  ENTRY AddWithLayoutChange {
+    par0 = f32[3,4]{1,0} parameter(0)
+    par1 = f32[3,4]{0,1} parameter(1)
+    ROOT add0 = f32[3,4]{1,0} add(par0,par1)
+  }
+  )";
+
+TEST_F(HloVerifierTest, AddWithLayoutChange) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo));
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_TRUE(status.ok());
+}
+
+TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo));
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("Instruction shouldn't change layouts"));
+}
+
+TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) {
+  const char* const kSliceWithLayoutChangeHlo = R"(
+  HloModule SliceWithLayoutChange
+  ENTRY SliceWithLayoutChange {
+    par0 = f32[4,5]{0,1} parameter(0)
+    par1 = s32[2] parameter(1)
+    ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1),
+      dynamic_slice_sizes={3,4}
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseHloString(kSliceWithLayoutChangeHlo));
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("Instruction shouldn't change layouts"));
+}
+
+TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) {
+  const char* const kConcatWithLayoutChangeHlo = R"(
+  HloModule ConcatWithLayoutChange
+  ENTRY ConcatWithLayoutChange {
+    par0 = f32[3,5]{0,1} parameter(0)
+    par1 = f32[3,3]{1,0} parameter(1)
+    ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1),
+      dimensions={1}
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseHloString(kConcatWithLayoutChangeHlo));
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("Instruction shouldn't change layouts"));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index bdd4fd7e3d..7ab2ecda58 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -86,19 +86,25 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
 
 }  // namespace
 
 HloTestBase::HloTestBase(bool verifier_layout_sensitive,
-                         bool allow_mixed_precision_in_hlo_verifier)
+                         bool allow_mixed_precision_in_hlo_verifier,
+                         std::function<bool(const HloInstruction*)>
+                             instruction_can_change_layout_func)
     : HloTestBase(GetTestPlatform(), GetReferencePlatform(),
                   verifier_layout_sensitive,
-                  allow_mixed_precision_in_hlo_verifier) {}
+                  allow_mixed_precision_in_hlo_verifier,
+                  instruction_can_change_layout_func) {}
 
 HloTestBase::HloTestBase(se::Platform* test_platform,
                          se::Platform* reference_platform,
                          bool verifier_layout_sensitive,
-                         bool allow_mixed_precision_in_hlo_verifier)
+                         bool allow_mixed_precision_in_hlo_verifier,
+                         std::function<bool(const HloInstruction*)>
+                             instruction_can_change_layout_func)
     : test_runner_(test_platform), reference_runner_(reference_platform) {
   hlo_verifier_ = absl::make_unique<HloVerifier>(
       /*layout_sensitive=*/verifier_layout_sensitive,
-      /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier);
+      /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier,
+      instruction_can_change_layout_func);
 }
 
 std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 0ae4bdc104..217428befa 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -88,14 +88,18 @@ class HloTestBase : public ::testing::Test {
   // interpreter is the only supported backend, it will be both the test backend
   // and the reference backend.
   HloTestBase(bool verifier_layout_sensitive = false,
-              bool allow_mixed_precision_in_hlo_verifier = true);
+              bool allow_mixed_precision_in_hlo_verifier = true,
+              std::function<bool(const HloInstruction*)>
+                  instruction_can_change_layout_func = {});
 
   // If your test doesn't use interpreter as the reference backend, you can use
   // this constructor. Note that your test target is responsible for linking in
   // both needed backends.
   HloTestBase(se::Platform* test_platform, se::Platform* reference_platform,
               bool verifier_layout_sensitive = false,
-              bool allow_mixed_precision_in_hlo_verifier = true);
+              bool allow_mixed_precision_in_hlo_verifier = true,
+              std::function<bool(const HloInstruction*)>
+                  instruction_can_change_layout_func = {});
 
   ~HloTestBase() override {}
 
-- 
cgit v1.2.3
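
Usage sketch (not part of the patch): how a layout-sensitive pipeline might enable the
stricter verification added above. It relies only on the three-argument HloVerifier
constructor and the LayoutAssignment::InstructionCanChangeLayout predicate introduced
in this change; the helper function AddLayoutPreservationChecker and its placement are
illustrative assumptions, not code from the commit.

// Sketch only: wires the new verifier hook into a layout-sensitive pass
// pipeline, mirroring the nvptx_compiler.cc hunks above. The helper name is
// hypothetical.
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"

namespace xla {

// Adds an invariant checker that fails verification if an instruction that is
// not allowed to change layouts produces a result whose layout differs from
// that of its same-rank dense-array operands.
void AddLayoutPreservationChecker(HloPassPipeline* pipeline) {
  pipeline->AddInvariantChecker<HloVerifier>(
      /*layout_sensitive=*/true,
      /*allow_mixed_precision=*/false,
      LayoutAssignment::InstructionCanChangeLayout);
}

}  // namespace xla

A backend would call this once when building its post-layout-assignment pipeline;
passing an empty std::function instead (the default) keeps the verifier's previous,
less strict behavior.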