diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/layout_assignment_test.cc | 190 |
1 files changed, 190 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 2c549cd872..ff6fdb5e4a 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -65,6 +65,27 @@ class LayoutAssignmentTest : public HloVerifiedTestBase { FindInstruction(module, name)->shape().layout().minor_to_major(); return std::vector<int64>(minor_to_major.begin(), minor_to_major.end()); } + + void ExpectLayoutIs(const Shape& shape, + absl::Span<const int64> minor_to_major) { + const Layout expected = LayoutUtil::MakeLayout(minor_to_major); + EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected)) + << "Expected layout " << expected << ", actual " << shape.layout(); + } + + void ExpectTupleLayoutIs( + const Shape& shape, + std::initializer_list<absl::Span<const int64>> minor_to_majors) { + int i = 0; + for (const absl::Span<const int64> minor_to_major : minor_to_majors) { + const Layout expected = LayoutUtil::MakeLayout(minor_to_major); + const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout(); + EXPECT_TRUE(LayoutUtil::Equal(actual, expected)) + << "Expected tuple element " << i << " layout " << expected + << ", actual " << actual; + ++i; + } + } }; TEST_F(LayoutAssignmentTest, ComputationLayout) { @@ -1102,5 +1123,174 @@ TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) { EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); } +TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) { + const char* module_str = R"( +HloModule CustomCallNotLayoutConstrained + +ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] { + %p = f32[42,2,3] parameter(0) + ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz" +} +)"; + // Try with a couple different layouts. In each case the custom calls operand + // and result layout should match that of the computation. + { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<VerifiedHloModule> module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1})); + AssignLayouts(module.get(), &computation_layout); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::CustomCall(op::Parameter())); + ExpectLayoutIs(root->shape(), {3, 2, 0, 1}); + ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1}); + } + { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<VerifiedHloModule> module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1})); + AssignLayouts(module.get(), &computation_layout); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::CustomCall(op::Parameter())); + ExpectLayoutIs(root->shape(), {0, 2, 3, 1}); + ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2}); + } +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrained + +ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] { + %p0 = f32[4,4] parameter(0) + %p1 = f32[2,3] parameter(1) + ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<VerifiedHloModule> module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(module.get(), &computation_layout); + + // The custom call should be partially encapsulated in kCopy instructions + // because of the layout mismatches. + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::Copy(op::CustomCall(op::Copy(), op::Parameter()))); + + const HloInstruction* custom_call = + module->entry_computation()->root_instruction()->operand(0); + ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); + ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1}); + ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0}); +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrainedZeroOperands + +ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] { + ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<VerifiedHloModule> module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(module.get(), &computation_layout); + + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::Copy(op::CustomCall())); + + const HloInstruction* custom_call = + module->entry_computation()->root_instruction()->operand(0); + ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrainedTupleOperand + +ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] { + %p0 = f32[4,4] parameter(0) + %p1 = f32[2,3] parameter(1) + %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1) + ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<VerifiedHloModule> module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(module.get(), &computation_layout); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ExpectLayoutIs(root->shape(), {2, 1, 0, 3}); + + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::Copy(op::CustomCall(op::Tuple()))); + + const HloInstruction* custom_call = + module->entry_computation()->root_instruction()->operand(0); + ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); + ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}}); +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrainedTupleResult + +ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) { + %p0 = f32[4,4] parameter(0) + ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}} +} +)"; + // Try with a couple different layouts. In each case the custom calls operand + // and result layout should match that of the computation. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<VerifiedHloModule> module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_result_layout() = + ShapeLayout(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})})); + AssignLayouts(module.get(), &computation_layout); + + ExpectTupleLayoutIs(module->entry_computation()->root_instruction()->shape(), + {{1, 0}, {1, 0}}); + + const HloInstruction* custom_call = + FindInstruction(module.get(), "custom-call"); + ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}}); +} + } // namespace } // namespace xla |