diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/multioutput_fusion_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/multioutput_fusion_test.cc | 120 |
1 files changed, 67 insertions, 53 deletions
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 6597748c8d..eb06b115da 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include <utility> #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -60,7 +60,7 @@ class MultiOutputFusionTest : public HloTestBase { const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size}); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(8.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(8.0f))); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, elem_shape0, "0")); @@ -105,8 +105,9 @@ class MultiOutputFusionTest : public HloTestBase { Literal expect(ShapeUtil::MakeShape(F32, {size, size})); expect.PopulateWithValue<float>(size * 1.5f * 3.5f); - auto actual = ExecuteAndTransfer( - std::move(hlo_module), {Literal::CreateR0<float>(-9.0f).get(), &arg1}); + auto actual = + ExecuteAndTransfer(std::move(hlo_module), + {LiteralUtil::CreateR0<float>(-9.0f).get(), &arg1}); EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); } @@ -165,7 +166,8 @@ class MultiOutputFusionTest : public HloTestBase { Literal input1(ShapeUtil::MakeShape(F64, {size})); input1.PopulateWithValue(1.); - Literal expect = std::move(*Literal::CreateR1<float>({size * 1.5f * 3.5f})); + Literal expect = + std::move(*LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f})); auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); } @@ -198,16 +200,16 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::MakeTupleOwned( - Literal::MakeTupleOwned( - Literal::MakeTupleOwned(Literal::CreateR0<int32>(42)), - Literal::CreateR0<float>(1.0)), - Literal::MakeTupleOwned(Literal::CreateR0<float>(3.0), - Literal::CreateR0<int32>(4))); + auto param = LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), + LiteralUtil::CreateR0<float>(1.0)), + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<float>(3.0), + LiteralUtil::CreateR0<int32>(4))); std::unique_ptr<Literal> result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned(Literal::CreateR0<int32>(42)), *result)); + *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), *result)); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { @@ -232,7 +234,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR1<float>({1.0, 2.0, 3.0, -1.0}); + auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0}); std::unique_ptr<Literal> result = ExecuteNoHloPasses(std::move(module), {param.get()}); LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, *result); @@ -265,7 +267,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR1<float>({1.0, 2.0, 3.0}); + auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}); std::unique_ptr<Literal> result = ExecuteNoHloPasses(std::move(module), {param.get()}); LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, *result); @@ -308,12 +310,14 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr<Literal> result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned(Literal::CreateR2<float>({{3, 7}, {11, 15}}), - Literal::CreateR2<float>({{5, 16}, {36, 64}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}), + LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})), *result)); } @@ -338,12 +342,14 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr<Literal> result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned(Literal::CreateR2<float>({{6, 8}, {10, 12}}), - Literal::CreateR2<float>({{25, 36}, {49, 64}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}), + LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})), *result)); } @@ -369,13 +375,14 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr<Literal> result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned(Literal::CreateR1<float>({14, 22}), - Literal::CreateR1<float>({36, 64}), - Literal::CreateR1<float>({66, 138})), + *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}), + LiteralUtil::CreateR1<float>({36, 64}), + LiteralUtil::CreateR1<float>({66, 138})), *result)); } @@ -401,14 +408,15 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr<Literal> result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned( - Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), - Literal::CreateR2<float>({{3, 7}, {11, 15}}), - Literal::CreateR2<float>({{5, 16}, {36, 64}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), + LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}), + LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})), *result)); } @@ -434,14 +442,16 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr<Literal> result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned( - Literal::CreateR2<float>({{6, 8}, {10, 12}}), - Literal::CreateR3<float>({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), - Literal::CreateR2<float>({{25, 36}, {49, 64}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}), + LiteralUtil::CreateR3<float>( + {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})), *result)); } @@ -468,14 +478,16 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr<Literal> result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned( - Literal::CreateR1<float>({14, 22}), - Literal::CreateR3<float>({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), - Literal::CreateR3<float>( + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR1<float>({14, 22}), + LiteralUtil::CreateR3<float>( + {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + LiteralUtil::CreateR3<float>( {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})), *result)); } @@ -502,15 +514,16 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - auto init1 = Literal::CreateR0<float>(5); - auto init2 = Literal::CreateR0<float>(6); + auto param = + LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto init1 = LiteralUtil::CreateR0<float>(5); + auto init2 = LiteralUtil::CreateR0<float>(6); std::unique_ptr<Literal> result = ExecuteNoHloPasses( std::move(module), {param.get(), init1.get(), init2.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned( - Literal::CreateR2<float>({{167, 172}, {176, 180}}), - Literal::CreateR2<float>({{6, 6}, {6, 8}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2<float>({{167, 172}, {176, 180}}), + LiteralUtil::CreateR2<float>({{6, 6}, {6, 8}})), *result)); } @@ -537,19 +550,20 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3<Eigen::half>( + auto param = LiteralUtil::CreateR3<Eigen::half>( {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); std::unique_ptr<Literal> result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned( - Literal::CreateR2<float>({{3, 7}, {11, 15}}), - Literal::CreateR2<float>({{5, 16}, {36, 64}}), - Literal::CreateR3<Eigen::half>({{{Eigen::half(1), Eigen::half(2)}, - {Eigen::half(3), Eigen::half(4)}}, - {{Eigen::half(5), Eigen::half(6)}, - {Eigen::half(7), Eigen::half(8)}}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}), + LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}}), + LiteralUtil::CreateR3<Eigen::half>( + {{{Eigen::half(1), Eigen::half(2)}, + {Eigen::half(3), Eigen::half(4)}}, + {{Eigen::half(5), Eigen::half(6)}, + {Eigen::half(7), Eigen::half(8)}}})), *result)); } |