about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/layout_assignment_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment_test.cc')
-rw-r--r-- tensorflow/compiler/xla/service/layout_assignment_test.cc | 107
1 file changed, 53 insertions, 54 deletions
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 69c7e42601..752a61476d 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -35,7 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -49,7 +49,7 @@ namespace {
using ::testing::ElementsAre;
-class LayoutAssignmentTest : public HloTestBase {
+class LayoutAssignmentTest : public HloVerifiedTestBase {
protected:
void AssignLayouts(HloModule* module,
ComputationLayout* entry_computation_layout,
@@ -91,7 +91,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) {
*computation_layout.mutable_parameter_layout(0) = shape_layout;
*computation_layout.mutable_parameter_layout(1) = shape_layout;
*computation_layout.mutable_result_layout() = shape_layout;
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
@@ -127,7 +127,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
*computation_layout.mutable_parameter_layout(1) = row_major;
*computation_layout.mutable_result_layout() = col_major;
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(
@@ -145,7 +145,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
{{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
- Shape ashape = constant_literal1->shape();
+ Shape ashape = constant_literal1.shape();
auto constant1 = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(constant_literal1)));
@@ -172,7 +172,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
ComputationLayout computation_layout(computation->ComputeProgramShape());
*computation_layout.mutable_result_layout() = shape_layout;
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(
layout, fusion->fused_parameter(0)->shape().layout()));
@@ -213,7 +213,7 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(
LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));
@@ -243,7 +243,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1));
+ tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -255,7 +255,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
result_shape));
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape()));
}
@@ -294,7 +294,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
result_shape));
LayoutAssignment layout_assignment(&computation_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
// Layout assignment should have deep copied the result of the computation to
// address the layout conflict. This results in several Tuple() and
@@ -310,7 +310,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
EXPECT_TRUE(
AlgebraicSimplifier(/*is_layout_sensitive=*/true,
[](const Shape&, const Shape&) { return false; })
- .Run(module.get())
+ .Run(module)
.ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
// Verify layout of the root and the root's operands.
@@ -352,7 +352,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
*computation_layout.mutable_parameter_layout(0) =
ShapeLayout(ashape_with_layout);
*computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
auto log_minor_to_major =
AsInt64Slice(log->shape().layout().minor_to_major());
@@ -393,7 +393,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
*computation_layout.mutable_parameter_layout(0) =
ShapeLayout(ashape_with_layout);
*computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(
LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
@@ -432,7 +432,7 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
ShapeLayout(input_shape_with_layout);
*computation_layout.mutable_result_layout() =
ShapeLayout(output_shape_with_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_THAT(broadcast->shape().layout().minor_to_major(),
ElementsAre(0, 1, 2));
@@ -457,13 +457,13 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32_4, "param"));
auto broadcast = builder.AddInstruction(
- HloInstruction::CreateBroadcast(f32_34, param, {3}));
+ HloInstruction::CreateBroadcast(f32_34, param, {1}));
auto transpose = builder.AddInstruction(
HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0}));
auto tanh = builder.AddInstruction(
HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast));
auto broadcast2 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(f32_234, tanh, {2}));
+ HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2}));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({transpose, broadcast2}));
auto module = CreateNewModule();
@@ -485,7 +485,7 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
*computation_layout.mutable_result_layout() =
ShapeLayout(ShapeUtil::MakeTupleShape(
{transpose_shape_with_layout, broadcast2_shape_with_layout}));
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1));
EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0));
@@ -551,7 +551,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
*computation_layout.mutable_parameter_layout(1) =
ShapeLayout(param1_shape_with_layout);
OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout);
- EXPECT_IS_OK(layout_assignment.Run(module.get()).status());
+ EXPECT_IS_OK(layout_assignment.Run(module).status());
EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode());
EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(),
@@ -575,7 +575,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
HloComputation* computation =
module->AddEntryComputation(builder.Build(transpose));
ComputationLayout computation_layout(computation->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
transpose->shape(), {2, 3, 0, 1}));
}
@@ -593,7 +593,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
HloComputation* computation =
module->AddEntryComputation(builder.Build(transpose));
ComputationLayout computation_layout(computation->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
transpose->shape(), {2, 3, 0, 1}));
}
@@ -659,18 +659,18 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
- module =
+ std::unique_ptr<HloModule> compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
EXPECT_EQ(Status::OK(), backend()
.compiler()
- ->RunBackend(std::move(module),
+ ->RunBackend(std::move(compiled_module),
backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.status());
@@ -699,9 +699,9 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape());
+ module().entry_computation()->ComputeProgramShape());
Shape param_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
ShapeUtil::MakeTupleShape({
@@ -713,19 +713,19 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
param_shape));
computation_layout.mutable_result_layout()->ResetLayout(
LayoutUtil::MakeLayout({2, 1, 0}));
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(&module(), &computation_layout);
- EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2));
- EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0));
- EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1));
- EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0));
- EXPECT_THAT(FindInstruction(module.get(), "gte1")
+ EXPECT_THAT(LayoutOf(&module(), "gte0"), ElementsAre(0, 1, 2));
+ EXPECT_THAT(LayoutOf(&module(), "gte1a"), ElementsAre(1, 2, 0));
+ EXPECT_THAT(LayoutOf(&module(), "gte1b"), ElementsAre(2, 0, 1));
+ EXPECT_THAT(LayoutOf(&module(), "fresult"), ElementsAre(2, 1, 0));
+ EXPECT_THAT(FindInstruction(&module(), "gte1")
->shape()
.tuple_shapes(0)
.layout()
.minor_to_major(),
ElementsAre(1, 2, 0));
- EXPECT_THAT(FindInstruction(module.get(), "gte1")
+ EXPECT_THAT(FindInstruction(&module(), "gte1")
->shape()
.tuple_shapes(1)
.layout()
@@ -785,7 +785,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
HloComputation* computation = module->AddEntryComputation(builder.Build());
ComputationLayout computation_layout(computation->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
const HloInstruction* true_root = true_computation->root_instruction();
const HloInstruction* false_root = false_computation->root_instruction();
@@ -812,7 +812,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
LayoutAssignment layout_assignment(&computation_layout);
- Status error_status = layout_assignment.Run(module.get()).status();
+ Status error_status = layout_assignment.Run(module).status();
EXPECT_FALSE(error_status.ok());
EXPECT_THAT(
error_status.error_message(),
@@ -839,9 +839,9 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape());
+ module().entry_computation()->ComputeProgramShape());
Shape param_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
TF_ASSERT_OK(
@@ -851,14 +851,13 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
LayoutUtil::MakeLayout({1, 0}));
ChannelLayoutConstraints channel_constraints;
- AssignLayouts(module.get(), &computation_layout, &channel_constraints);
+ AssignLayouts(&module(), &computation_layout, &channel_constraints);
- EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1));
- EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0));
- EXPECT_TRUE(
- ShapeUtil::Equal(ShapeUtil::GetSubshape(
- FindInstruction(module.get(), "send")->shape(), {0}),
- ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
+ EXPECT_THAT(LayoutOf(&module(), "gte"), ElementsAre(0, 1));
+ EXPECT_THAT(LayoutOf(&module(), "root"), ElementsAre(1, 0));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ ShapeUtil::GetSubshape(FindInstruction(&module(), "send")->shape(), {0}),
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
}
TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
@@ -873,11 +872,11 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
@@ -901,11 +900,11 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
@@ -932,11 +931,11 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
@@ -963,11 +962,11 @@ TEST_F(LayoutAssignmentTest,
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
@@ -985,11 +984,11 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =