diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/layout_assignment_test.cc | 37 |
1 files changed, 19 insertions, 18 deletions
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 67e2cf6c77..a16fa75e30 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -141,9 +141,9 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}}; for (auto& minor_to_major : minor_to_majors) { auto builder = HloComputation::Builder(TestName()); - auto constant_literal1 = Literal::CreateR2WithLayout<float>( + auto constant_literal1 = LiteralUtil::CreateR2WithLayout<float>( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major)); - auto constant_literal2 = Literal::CreateR2WithLayout<float>( + auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>( {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major)); Shape ashape = constant_literal1->shape(); @@ -192,10 +192,10 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { // match their source). auto builder = HloComputation::Builder(TestName()); auto constant0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant0, constant1})); @@ -229,10 +229,10 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { // Verify layouts of a select with tuple operands is assigned properly. auto builder = HloComputation::Builder(TestName()); auto constant0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto tuple0 = builder.AddInstruction( HloInstruction::CreateTuple({constant0, constant1})); @@ -240,7 +240,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { HloInstruction::CreateTuple({constant0, constant1})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); @@ -274,7 +274,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { // tuple and assigning the layouts of the copied arrays as needed. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); auto inner_tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant})); auto nested_tuple = builder.AddInstruction( @@ -584,7 +584,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { auto builder = HloComputation::Builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7}); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f))); auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(input_shape, constant, {})); auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -770,8 +770,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { false_builder.AddInstruction( HloInstruction::CreateParameter(0, tshape, "param")); // Using infeed as layout assignment does not mess up with it. - auto token = - false_builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto token = false_builder.AddInstruction(HloInstruction::CreateToken()); auto infeed = false_builder.AddInstruction( HloInstruction::CreateInfeed(xshape, token, "")); auto infeed_data = false_builder.AddInstruction( @@ -803,7 +802,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { auto builder = HloComputation::Builder(TestName()); auto constant0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); builder.AddInstruction(HloInstruction::CreateUnary( constant0->shape(), HloOpcode::kBitcast, constant0)); @@ -829,12 +828,14 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { ENTRY entry_computation { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 - recv = (f32[2,2], u32[]) recv(), channel_id=1, sharding={maximal device=1} - ROOT recv-done = f32[2,2] recv-done(recv), channel_id=1, + token = token[] after-all() + recv = (f32[2,2], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=1} + recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1, sharding={maximal device=1} - send = (f32[2,2], u32[]) send(gte), channel_id=1, + ROOT root = f32[2,2] get-tuple-element(recv-done), index=0 + send = (f32[2,2], u32[], token[]) send(gte, token), channel_id=1, sharding={maximal device=0} - send-done = () send-done(send), channel_id=1, sharding={maximal device=0} + send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0} } )"; @@ -853,7 +854,7 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { AssignLayouts(module.get(), &computation_layout, &channel_constraints); EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "recv-done"), ElementsAre(1, 0)); + EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0)); EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::GetSubshape( FindInstruction(module.get(), "send")->shape(), {0}), |