aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/fusion_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/fusion_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/fusion_test.cc225
1 files changed, 150 insertions, 75 deletions
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index ab470f16a3..607bcdd51e 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -26,13 +26,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -89,7 +90,7 @@ class FusionTest : public HloTestBase {
HloInstruction* hlos[4];
for (int i = 0; i < Arity; ++i) {
hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2FromArray2D(operand_data[i])));
+ LiteralUtil::CreateR2FromArray2D(operand_data[i])));
}
auto answer_shape =
ShapeUtil::MakeShape(prim_type, {test_width, test_height});
@@ -115,7 +116,7 @@ class FusionTest : public HloTestBase {
ArraySlice<HloInstruction*>(hlos, 0, Arity + 1),
HloInstruction::FusionKind::kLoop);
- auto expected = Literal::CreateR2FromArray2D(answer_data);
+ auto expected = LiteralUtil::CreateR2FromArray2D(answer_data);
auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
if (primitive_util::IsFloatingPointType(prim_type)) {
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4)));
@@ -186,27 +187,28 @@ XLA_TEST_F(FusionTest, Test) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
+ LiteralUtil::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}})));
+ LiteralUtil::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}})));
auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1));
auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0}));
auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.62, 2.72, 3.14}})));
+ LiteralUtil::CreateR2<float>({{1.62, 2.72, 3.14}})));
auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate(
ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0));
auto const6 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}})));
auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6));
auto add8 = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7));
auto const9 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})));
- auto const10 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<bool>({{true, false, true}, {false, true, false}})));
+ LiteralUtil::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})));
+ auto const10 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<bool>(
+ {{true, false, true}, {false, true, false}})));
auto select11 = builder.AddInstruction(
HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}),
HloOpcode::kSelect, const10, add8, const9));
@@ -222,7 +224,7 @@ XLA_TEST_F(FusionTest, Test) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>({{0.5}, {2.72}}),
+ *LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
*ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
@@ -233,11 +235,11 @@ XLA_TEST_F(FusionTest, Parameter) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0, 3.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}})));
auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0));
auto const2 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{-2.0, -2.0, -2.0}})));
+ LiteralUtil::CreateR2<float>({{-2.0, -2.0, -2.0}})));
// add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1}
auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2));
@@ -248,7 +250,7 @@ XLA_TEST_F(FusionTest, Parameter) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}),
+ *LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
*ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
@@ -269,7 +271,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
auto hlo_module = CreateNewModule();
auto two = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto x =
builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {}));
auto y = builder.AddInstruction(
@@ -292,9 +294,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
auto const_array = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}})));
+ LiteralUtil::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}})));
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1}));
// add2 = broadcast(const_vector) + const_array
@@ -308,7 +310,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
+ *LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
*ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
@@ -316,14 +318,14 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto single_element_array = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<int32>({{5}})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<int32>({{5}})));
auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {}), single_element_array));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR0<int32>(5),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(5),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -331,14 +333,14 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1, 2, 3}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
+ *LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -346,14 +348,14 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
+ LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
+ *LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -361,14 +363,14 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR3<int32>({{{7}}})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR3<int32>({{{7}}})));
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR0<int32>(7),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(7),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -376,14 +378,14 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(7)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1, 1, 1}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR3<int32>({{{7}}}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{7}}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -391,14 +393,14 @@ XLA_TEST_F(FusionTest, Reshape__) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(7)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR0<int32>(7),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(7),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -406,14 +408,14 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
+ *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -421,14 +423,14 @@ XLA_TEST_F(FusionTest, Transpose_2by3) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0}));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
+ *LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -436,14 +438,14 @@ XLA_TEST_F(FusionTest, Transpose_3by3) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0}));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
+ *LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -451,7 +453,7 @@ XLA_TEST_F(FusionTest, Reverse) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
ShapeUtil::MakeShape(S32, {3}), const0, {0}));
hlo_module->AddEntryComputation(builder.Build())
@@ -459,7 +461,7 @@ XLA_TEST_F(FusionTest, Reverse) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({3, 2, 1}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({3, 2, 1}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -467,7 +469,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
ShapeUtil::MakeShape(S32, {3}), const0, {0}));
auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -477,7 +479,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-3, -2, -1}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-3, -2, -1}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -485,7 +487,7 @@ XLA_TEST_F(FusionTest, BroadcastNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(S32, {2}), const0, {}));
auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -495,15 +497,15 @@ XLA_TEST_F(FusionTest, BroadcastNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-1, -1}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-1, -1}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, SliceNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4})));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(S32, {2}), const0, {0}, {4}, {2}));
auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -513,17 +515,17 @@ XLA_TEST_F(FusionTest, SliceNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-1, -3}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-1, -3}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DynamicSliceNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4})));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1})));
auto dynamic_slice2 =
builder.AddInstruction(HloInstruction::CreateDynamicSlice(
ShapeUtil::MakeShape(S32, {2}), const0, const1, {2}));
@@ -535,15 +537,15 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-2, -3}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-2, -3}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, ReshapeNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4})));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {2, 2}), const0));
auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -552,16 +554,16 @@ XLA_TEST_F(FusionTest, ReshapeNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
HloInstruction::FusionKind::kLoop);
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, -2}, {-3, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, TransposeNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{1, 2}, {3, 4}})));
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}})));
auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {2, 2}), const0, {1, 0}));
auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -570,9 +572,9 @@ XLA_TEST_F(FusionTest, TransposeNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
HloInstruction::FusionKind::kLoop);
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, -3}, {-2, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
std::unique_ptr<HloComputation> MakeReduceTestComputation() {
@@ -590,10 +592,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
auto hlo_module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 4, 8})));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
@@ -602,7 +604,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR0<int32>(15),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(15),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -610,10 +612,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
auto hlo_module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 4, 8})));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
@@ -624,7 +626,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR0<int32>(-15),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(-15),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -632,9 +634,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
+ LiteralUtil::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
Window window;
ASSERT_TRUE(
tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n"
@@ -674,7 +676,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
+ *LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -686,9 +688,9 @@ XLA_TEST_F(FusionTest, SharedConstant) {
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({0})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({0})));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0));
auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -710,7 +712,7 @@ XLA_TEST_F(FusionTest, SharedConstant) {
EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({8}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({8}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -764,6 +766,79 @@ XLA_TEST_F(FusionTest, Clamp2D) {
TestElementwise2D<float, 3>(HloOpcode::kClamp);
}
+// TODO(b/73903144): Enable on interpreter once interpreter supports bitcast.
+XLA_TEST_F(FusionTest, DISABLED_ON_INTERPRETER(FusionWithLayout)) {
+ const string hlo_text = R"(
+HloModule Cluster
+
+fusion_c {
+ fusion.arg = f32[2,2]{1,0} parameter(0)
+ bitcast.0 = f32[2,2,1]{2,1,0} bitcast(fusion.arg)
+ tanh.0 = f32[2,2,1]{0,2,1} tanh(bitcast.0)
+ ROOT bitcast.2 = f32[2,2,1]{1,2,0} bitcast(tanh.0)
+}
+
+ENTRY main {
+ arg = f32[2,2]{1,0} parameter(0)
+ ROOT fusion = f32[2,2,1]{1,2,0} fusion(arg), kind=kLoop, calls=fusion_c
+}
+)";
+
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<float>({{0., 0.}, {1., 0.}});
+ HloModuleConfig config;
+ config.set_debug_options(GetDebugOptionsForTest());
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_text, config));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ test_runner_.Execute(std::move(module), {operand.get()},
+ /*run_hlo_passes=*/false));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *LiteralUtil::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}),
+ *result));
+}
+
+class FusionClientLibraryTest : public ClientLibraryTestBase {};
+
+XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
+ // On the GPU backend, it's possible to have too many transposes within one
+ // fusion, causing the kernel to run out shared memory and thus not compile.
+ // We want to check that doesn't happen.
+ //
+ // To do this, we create a computation that computes
+ //
+ // P0 + P0*P1*P1 + P0*P2*P2 ...
+ //
+ // where even parameters have layout 1 and odd parameters have layout 2.
+ //
+ // Our goal is to tempt the backend into creating one giant multi-output
+ // fusion for the whole computation, including the transposes. Currently
+ // multi-output fusion only fuses fusions, so each of the terms in the sum
+ // needs to be a fusion itself, thus the contortions above.
+ constexpr int kNumParams = 25;
+ XlaBuilder b("ManyLayoutTransformations");
+
+ // This test produces values that overflow int32, which is UB, so use uint32,
+ // where overflow is OK.
+ Array2D<uint32> arr(32, 32);
+ arr.FillUnique();
+ std::unique_ptr<Literal> l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
+ LayoutUtil::MakeLayout({0, 1}));
+
+ std::unique_ptr<Literal> l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
+ LayoutUtil::MakeLayout({1, 0}));
+
+ XlaOp p0 = AddParam(*l1, &b);
+ XlaOp sum = p0;
+ for (int i = 1; i < kNumParams; ++i) {
+ auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b);
+ sum = sum + p0 * pN * pN;
+ }
+
+ ComputeAndCompare(&b, {});
+}
+
void BM_ParallelFusion(int num_iters) {
// Simple element-wise computation to benchmark parallel task partitioning.
tensorflow::testing::StopTiming();
@@ -804,19 +879,19 @@ void BM_ParallelFusion(int num_iters) {
// Transfer literals to device.
auto param0_literal =
- Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
+ LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
ScopedShapedBuffer buffer0 =
client->LiteralToShapedBuffer(*param0_literal, device_ordinal)
.ConsumeValueOrDie();
auto param1_literal =
- Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
+ LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
ScopedShapedBuffer buffer1 =
client->LiteralToShapedBuffer(*param1_literal, device_ordinal)
.ConsumeValueOrDie();
auto param2_literal =
- Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
+ LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
ScopedShapedBuffer buffer2 =
client->LiteralToShapedBuffer(*param2_literal, device_ordinal)
.ConsumeValueOrDie();