From 396a8a4105edd409d0821c4d5d0b920b315ffb72 Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Mon, 8 Oct 2018 14:26:43 -0700 Subject: Add custom call with layout constraints. Add a variant of CustomCall which specifies arbitrary layout constraints on the operands and result. The existing non-layout-constrained CustomCall is changed to have no layout preference and can now be assigned arbitrary layouts by layout assignment. PiperOrigin-RevId: 216249615 --- tensorflow/compiler/xla/service/hlo_parser_test.cc | 67 ++++++++++++++++++++++ 1 file changed, 67 insertions(+) (limited to 'tensorflow/compiler/xla/service/hlo_parser_test.cc') diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 255123d331..17538c05bc 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -802,6 +802,43 @@ ENTRY %ConstantUnsignedNoOverflow () -> u64[] { ROOT %constant = u64[] constant(9223372036854775807) } +)" +}, +// CustomCallWithLayoutConstraints +{ +"CustomCallWithLayoutConstraints", +R"(HloModule CustomCallWithLayoutConstraints + +ENTRY %CustomCallWithLayoutConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}} +} + +)" +}, +// CustomCallWithLayoutConstraintsNoOperands +{ +"CustomCallWithLayoutConstraintsNoOperands", +R"(HloModule CustomCallWithLayoutConstraintsNoOperands + +ENTRY %CustomCallWithLayoutConstraints () -> f32[1,2,3] { + ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(), custom_call_target="baz", operand_layout_constraints={} +} + +)" +}, +// CustomCallWithLayoutConstraintsTupleShapes +{ +"CustomCallWithLayoutConstraintsTupleShapes", +R"(HloModule CustomCallWithLayoutConstraintsTupleShapes + +ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) { + %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}} +} + )" }, }); @@ -2069,5 +2106,35 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { op::Broadcast(), op::Multiply(), op::Add())); } +TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) { + const string original = R"(HloModule CustomCallWrongNumberofOperandConstraints + +ENTRY %CustomCallWrongNumberofOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}} +} + +)"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Expected 2 operand layout constraints, 1 given"); +} + +TEST_F(HloParserTest, CustomCallIncompatibleOperandConstraints) { + const string original = R"(HloModule CustomCallIncompatibleOperandConstraints + +ENTRY %CustomCallIncompatibleOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[555,5]{1,0}} +} + +)"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "operand 1 is not compatible with operand shape"); +} + +// custom call incompatible shape. + } // namespace } // namespace xla -- cgit v1.2.3