diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_verifier_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_verifier_test.cc | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 8f0423bb1c..afe01e5487 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -50,6 +51,14 @@ class HloVerifierTestAllowMixedPrecision : public HloTestBase { /*allow_mixed_precision_in_hlo_verifier=*/true) {} }; +class HloVerifierTestLayoutSensitive : public HloTestBase { + public: + HloVerifierTestLayoutSensitive() + : HloTestBase(/*verifier_layout_sensitive=*/true, + /*allow_mixed_precision_in_hlo_verifier=*/false, + LayoutAssignment::InstructionCanChangeLayout) {} +}; + TEST_F(HloVerifierTest, NullInstructionParent) { HloComputation::Builder builder(TestName()); const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -358,5 +367,63 @@ TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) { HasSubstr("non-positive base area dilation factor")); } +static const char* const kAddWithLayoutChangeHlo = R"( + HloModule AddWithLayoutChange + ENTRY AddWithLayoutChange { + par0 = f32[3,4]{1,0} parameter(0) + par1 = f32[3,4]{0,1} parameter(1) + ROOT add0 = f32[3,4]{1,0} add(par0,par1) + } + )"; + +TEST_F(HloVerifierTest, AddWithLayoutChange) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} + +TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) { + const char* const kSliceWithLayoutChangeHlo = R"( + HloModule SliceWithLayoutChange + ENTRY SliceWithLayoutChange { + par0 = f32[4,5]{0,1} parameter(0) + par1 = s32[2] parameter(1) + ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1), + dynamic_slice_sizes={3,4} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kSliceWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} + +TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) { + const char* const kConcatWithLayoutChangeHlo = R"( + HloModule ConcatWithLayoutChange + ENTRY ConcatWithLayoutChange { + par0 = f32[3,5]{0,1} parameter(0) + par1 = f32[3,3]{1,0} parameter(1) + ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1), + dimensions={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kConcatWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} } // namespace } // namespace xla |