about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/layout_assignment_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment_test.cc')
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment_test.cc  190
1 file changed, 190 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 2c549cd872..ff6fdb5e4a 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -65,6 +65,27 @@ class LayoutAssignmentTest : public HloVerifiedTestBase {
FindInstruction(module, name)->shape().layout().minor_to_major();
return std::vector<int64>(minor_to_major.begin(), minor_to_major.end());
}
+
+ // Asserts (non-fatally, via EXPECT_TRUE) that `shape` carries exactly the
+ // layout described by `minor_to_major`, comparing with LayoutUtil::Equal.
+ void ExpectLayoutIs(const Shape& shape,
+ absl::Span<const int64> minor_to_major) {
+ const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
+ EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected))
+ << "Expected layout " << expected << ", actual " << shape.layout();
+ }
+
+ // Asserts that the i-th element of tuple-shaped `shape` has the i-th layout
+ // in `minor_to_majors`. Only the first minor_to_majors.size() elements are
+ // checked; each mismatch is reported with its tuple index.
+ void ExpectTupleLayoutIs(
+ const Shape& shape,
+ std::initializer_list<absl::Span<const int64>> minor_to_majors) {
+ int i = 0;
+ for (const absl::Span<const int64> minor_to_major : minor_to_majors) {
+ const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
+ const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout();
+ EXPECT_TRUE(LayoutUtil::Equal(actual, expected))
+ << "Expected tuple element " << i << " layout " << expected
+ << ", actual " << actual;
+ ++i;
+ }
+ }
};
TEST_F(LayoutAssignmentTest, ComputationLayout) {
@@ -1102,5 +1123,174 @@ TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0));
}
+// A custom call with no operand_layout_constraints should simply take on the
+// layouts chosen for the entry computation: no copies appear around it, and
+// its operand/result layouts equal the parameter/result layouts we set below.
+TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) {
+ const char* module_str = R"(
+HloModule CustomCallNotLayoutConstrained
+
+ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] {
+ %p = f32[42,2,3] parameter(0)
+ ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz"
+}
+)";
+ // Try with a couple of different layouts. In each case the custom call's
+ // operand and result layout should match that of the computation.
+ {
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1}));
+ *computation_layout.mutable_result_layout() = ShapeLayout(
+ ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ // Root is still the bare custom call (no copies inserted).
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, op::CustomCall(op::Parameter()));
+ ExpectLayoutIs(root->shape(), {3, 2, 0, 1});
+ ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1});
+ }
+ // Second try with a different parameter/result layout pair.
+ {
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2}));
+ *computation_layout.mutable_result_layout() = ShapeLayout(
+ ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, op::CustomCall(op::Parameter()));
+ ExpectLayoutIs(root->shape(), {0, 2, 3, 1});
+ ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2});
+ }
+}
+
+// A layout-constrained custom call must keep its declared layouts; where they
+// disagree with the computation layout, layout assignment inserts copies:
+// operand 0 is constrained to {0,1} but the parameter is {1,0} (copy on the
+// operand), and the result is constrained to {3,2,0,1} but the computation
+// result wants {2,1,0,3} (copy on the result). Operand 1's constraint {1,0}
+// matches the parameter, so it needs no copy.
+TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) {
+ const char* module_str = R"(
+HloModule CustomCallLayoutConstrained
+
+ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
+ %p0 = f32[4,4] parameter(0)
+ %p1 = f32[2,3] parameter(1)
+ ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
+ *computation_layout.mutable_parameter_layout(1) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
+ *computation_layout.mutable_result_layout() = ShapeLayout(
+ ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ // The custom call should be partially encapsulated in kCopy instructions
+ // because of the layout mismatches.
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ op::Copy(op::CustomCall(op::Copy(), op::Parameter())));
+
+ // The custom call itself retains exactly its constrained layouts.
+ const HloInstruction* custom_call =
+ module->entry_computation()->root_instruction()->operand(0);
+ ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
+ ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1});
+ ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
+}
+
+// Degenerate case: a layout-constrained custom call with an empty
+// operand_layout_constraints list. The result constraint {3,2,0,1} still
+// holds, so a copy to the requested computation result layout {2,1,0,3} is
+// inserted above it.
+TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) {
+ const char* module_str = R"(
+HloModule CustomCallLayoutConstrainedZeroOperands
+
+ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] {
+ ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_result_layout() = ShapeLayout(
+ ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ op::Copy(op::CustomCall()));
+
+ const HloInstruction* custom_call =
+ module->entry_computation()->root_instruction()->operand(0);
+ ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
+}
+
+// Operand layout constraints can apply to tuple-shaped operands: here the
+// constraint (f32[4,4]{1,0}, f32[2,3]{0,1}) must be honored on the tuple fed
+// to the custom call, element by element.
+TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) {
+ const char* module_str = R"(
+HloModule CustomCallLayoutConstrainedTupleOperand
+
+ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
+ %p0 = f32[4,4] parameter(0)
+ %p1 = f32[2,3] parameter(1)
+ %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1)
+ ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
+ *computation_layout.mutable_parameter_layout(1) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
+ *computation_layout.mutable_result_layout() = ShapeLayout(
+ ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ // The root gets the computation's requested result layout...
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ExpectLayoutIs(root->shape(), {2, 1, 0, 3});
+
+ // ...which means a copy sits between the custom call and the root.
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ op::Copy(op::CustomCall(op::Tuple())));
+
+ // The custom call keeps its constrained result layout and its tuple
+ // operand carries the per-element constrained layouts.
+ const HloInstruction* custom_call =
+ module->entry_computation()->root_instruction()->operand(0);
+ ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
+ ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}});
+}
+
+// A layout-constrained custom call with a tuple-shaped result: the custom
+// call must keep its declared per-element result layouts ({1,0}, {0,1}) even
+// though the computation result layout asks for ({1,0}, {1,0}); layout
+// assignment reconciles the mismatch in the second element between the
+// custom call and the computation root.
+TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) {
+ const char* module_str = R"(
+HloModule CustomCallLayoutConstrainedTupleResult
+
+ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) {
+ %p0 = f32[4,4] parameter(0)
+ ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
+ *computation_layout.mutable_result_layout() =
+ ShapeLayout(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}),
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ // Entry root honors the computation result layout for both elements.
+ ExpectTupleLayoutIs(module->entry_computation()->root_instruction()->shape(),
+ {{1, 0}, {1, 0}});
+
+ // The custom call itself retains its constrained tuple layouts.
+ const HloInstruction* custom_call =
+ FindInstruction(module.get(), "custom-call");
+ ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}});
+}
+
} // namespace
} // namespace xla