Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment_test.cc')
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment_test.cc  37
1 file changed, 19 insertions(+), 18 deletions(-)
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 67e2cf6c77..a16fa75e30 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -141,9 +141,9 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}};
for (auto& minor_to_major : minor_to_majors) {
auto builder = HloComputation::Builder(TestName());
- auto constant_literal1 = Literal::CreateR2WithLayout<float>(
+ auto constant_literal1 = LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
- auto constant_literal2 = Literal::CreateR2WithLayout<float>(
+ auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
{{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
Shape ashape = constant_literal1->shape();
@@ -192,10 +192,10 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
// match their source).
auto builder = HloComputation::Builder(TestName());
auto constant0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant0, constant1}));
@@ -229,10 +229,10 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
// Verify layouts of a select with tuple operands is assigned properly.
auto builder = HloComputation::Builder(TestName());
auto constant0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto tuple0 = builder.AddInstruction(
HloInstruction::CreateTuple({constant0, constant1}));
@@ -240,7 +240,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
HloInstruction::CreateTuple({constant0, constant1}));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1));
@@ -274,7 +274,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
// tuple and assigning the layouts of the copied arrays as needed.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
auto inner_tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant}));
auto nested_tuple = builder.AddInstruction(
@@ -584,7 +584,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
auto builder = HloComputation::Builder(TestName());
Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7});
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(input_shape, constant, {}));
auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
@@ -770,8 +770,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
false_builder.AddInstruction(
HloInstruction::CreateParameter(0, tshape, "param"));
// Using infeed as layout assignment does not mess up with it.
- auto token =
- false_builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = false_builder.AddInstruction(HloInstruction::CreateToken());
auto infeed = false_builder.AddInstruction(
HloInstruction::CreateInfeed(xshape, token, ""));
auto infeed_data = false_builder.AddInstruction(
@@ -803,7 +802,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
auto builder = HloComputation::Builder(TestName());
auto constant0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
builder.AddInstruction(HloInstruction::CreateUnary(
constant0->shape(), HloOpcode::kBitcast, constant0));
@@ -829,12 +828,14 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
ENTRY entry_computation {
param = (f32[2,2]) parameter(0)
gte = f32[2,2] get-tuple-element(param), index=0
- recv = (f32[2,2], u32[]) recv(), channel_id=1, sharding={maximal device=1}
- ROOT recv-done = f32[2,2] recv-done(recv), channel_id=1,
+ token = token[] after-all()
+ recv = (f32[2,2], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=1}
+ recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1,
sharding={maximal device=1}
- send = (f32[2,2], u32[]) send(gte), channel_id=1,
+ ROOT root = f32[2,2] get-tuple-element(recv-done), index=0
+ send = (f32[2,2], u32[], token[]) send(gte, token), channel_id=1,
sharding={maximal device=0}
- send-done = () send-done(send), channel_id=1, sharding={maximal device=0}
+ send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0}
}
)";
@@ -853,7 +854,7 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
AssignLayouts(module.get(), &computation_layout, &channel_constraints);
EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1));
- EXPECT_THAT(LayoutOf(module.get(), "recv-done"), ElementsAre(1, 0));
+ EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0));
EXPECT_TRUE(
ShapeUtil::Equal(ShapeUtil::GetSubshape(
FindInstruction(module.get(), "send")->shape(), {0}),