aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_parser_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc67
1 files changed, 67 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 255123d331..17538c05bc 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -804,6 +804,43 @@ ENTRY %ConstantUnsignedNoOverflow () -> u64[] {
)"
},
+// CustomCallWithLayoutConstraints
+{
+"CustomCallWithLayoutConstraints",
+R"(HloModule CustomCallWithLayoutConstraints
+
+ENTRY %CustomCallWithLayoutConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
+ %p0 = f32[42,2,3]{0,1,2} parameter(0)
+ %p1 = f32[123,4]{0,1} parameter(1)
+ ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}}
+}
+
+)"
+},
+// CustomCallWithLayoutConstraintsNoOperands
+{
+"CustomCallWithLayoutConstraintsNoOperands",
+R"(HloModule CustomCallWithLayoutConstraintsNoOperands
+
+ENTRY %CustomCallWithLayoutConstraints () -> f32[1,2,3] {
+ ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
+}
+
+)"
+},
+// CustomCallWithLayoutConstraintsTupleShapes
+{
+"CustomCallWithLayoutConstraintsTupleShapes",
+R"(HloModule CustomCallWithLayoutConstraintsTupleShapes
+
+ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) {
+ %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
+ %p1 = f32[123,4]{0,1} parameter(1)
+ ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}}
+}
+
+)"
+},
});
// clang-format on
}
@@ -2069,5 +2106,35 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
op::Broadcast(), op::Multiply(), op::Add()));
}
+TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) {
+ const string original = R"(HloModule CustomCallWrongNumberofOperandConstraints
+
+ENTRY %CustomCallWrongNumberofOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
+ %p0 = f32[42,2,3]{0,1,2} parameter(0)
+ %p1 = f32[123,4]{0,1} parameter(1)
+ ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}}
+}
+
+)";
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
+ "Expected 2 operand layout constraints, 1 given");
+}
+
+TEST_F(HloParserTest, CustomCallIncompatibleOperandConstraints) {
+ const string original = R"(HloModule CustomCallIncompatibleOperandConstraints
+
+ENTRY %CustomCallIncompatibleOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
+ %p0 = f32[42,2,3]{0,1,2} parameter(0)
+ %p1 = f32[123,4]{0,1} parameter(1)
+ ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[555,5]{1,0}}
+}
+
+)";
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
+ "operand 1 is not compatible with operand shape");
+}
+
+// custom call incompatible shape.
+
} // namespace
} // namespace xla