aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_verifier_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_verifier_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier_test.cc67
1 files changed, 67 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 8f0423bb1c..afe01e5487 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -50,6 +51,14 @@ class HloVerifierTestAllowMixedPrecision : public HloTestBase {
/*allow_mixed_precision_in_hlo_verifier=*/true) {}
};
+class HloVerifierTestLayoutSensitive : public HloTestBase {
+ public:
+ HloVerifierTestLayoutSensitive()
+ : HloTestBase(/*verifier_layout_sensitive=*/true,
+ /*allow_mixed_precision_in_hlo_verifier=*/false,
+ LayoutAssignment::InstructionCanChangeLayout) {}
+};
+
TEST_F(HloVerifierTest, NullInstructionParent) {
HloComputation::Builder builder(TestName());
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -358,5 +367,63 @@ TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
HasSubstr("non-positive base area dilation factor"));
}
+static const char* const kAddWithLayoutChangeHlo = R"(
+ HloModule AddWithLayoutChange
+ ENTRY AddWithLayoutChange {
+ par0 = f32[3,4]{1,0} parameter(0)
+ par1 = f32[3,4]{0,1} parameter(1)
+ ROOT add0 = f32[3,4]{1,0} add(par0,par1)
+ }
+ )";
+
+TEST_F(HloVerifierTest, AddWithLayoutChange) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo));
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_TRUE(status.ok());
+}
+
+TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo));
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Instruction shouldn't change layouts"));
+}
+
+TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) {
+ const char* const kSliceWithLayoutChangeHlo = R"(
+ HloModule SliceWithLayoutChange
+ ENTRY SliceWithLayoutChange {
+ par0 = f32[4,5]{0,1} parameter(0)
+ par1 = s32[2] parameter(1)
+ ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1),
+ dynamic_slice_sizes={3,4}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseHloString(kSliceWithLayoutChangeHlo));
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Instruction shouldn't change layouts"));
+}
+
+TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) {
+ const char* const kConcatWithLayoutChangeHlo = R"(
+ HloModule ConcatWithLayoutChange
+ ENTRY ConcatWithLayoutChange {
+ par0 = f32[3,5]{0,1} parameter(0)
+ par1 = f32[3,3]{1,0} parameter(1)
+ ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1),
+ dimensions={1}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseHloString(kConcatWithLayoutChangeHlo));
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Instruction shouldn't change layouts"));
+}
} // namespace
} // namespace xla