aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/multioutput_fusion_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/multioutput_fusion_test.cc120
1 files changed, 67 insertions, 53 deletions
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index 6597748c8d..eb06b115da 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -60,7 +60,7 @@ class MultiOutputFusionTest : public HloTestBase {
const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size});
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(8.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(8.0f)));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, elem_shape0, "0"));
@@ -105,8 +105,9 @@ class MultiOutputFusionTest : public HloTestBase {
Literal expect(ShapeUtil::MakeShape(F32, {size, size}));
expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
- auto actual = ExecuteAndTransfer(
- std::move(hlo_module), {Literal::CreateR0<float>(-9.0f).get(), &arg1});
+ auto actual =
+ ExecuteAndTransfer(std::move(hlo_module),
+ {LiteralUtil::CreateR0<float>(-9.0f).get(), &arg1});
EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
}
@@ -165,7 +166,8 @@ class MultiOutputFusionTest : public HloTestBase {
Literal input1(ShapeUtil::MakeShape(F64, {size}));
input1.PopulateWithValue(1.);
- Literal expect = std::move(*Literal::CreateR1<float>({size * 1.5f * 3.5f}));
+ Literal expect =
+ std::move(*LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f}));
auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
}
@@ -198,16 +200,16 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) {
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::MakeTupleOwned(
- Literal::MakeTupleOwned(
- Literal::MakeTupleOwned(Literal::CreateR0<int32>(42)),
- Literal::CreateR0<float>(1.0)),
- Literal::MakeTupleOwned(Literal::CreateR0<float>(3.0),
- Literal::CreateR0<int32>(4)));
+ auto param = LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)),
+ LiteralUtil::CreateR0<float>(1.0)),
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<float>(3.0),
+ LiteralUtil::CreateR0<int32>(4)));
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(Literal::CreateR0<int32>(42)), *result));
+ *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), *result));
}
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
@@ -232,7 +234,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
+ auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, *result);
@@ -265,7 +267,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR1<float>({1.0, 2.0, 3.0});
+ auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, *result);
@@ -308,12 +310,14 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(Literal::CreateR2<float>({{3, 7}, {11, 15}}),
- Literal::CreateR2<float>({{5, 16}, {36, 64}})),
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
+ LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
*result));
}
@@ -338,12 +342,14 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(Literal::CreateR2<float>({{6, 8}, {10, 12}}),
- Literal::CreateR2<float>({{25, 36}, {49, 64}})),
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
+ LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
*result));
}
@@ -369,13 +375,14 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(Literal::CreateR1<float>({14, 22}),
- Literal::CreateR1<float>({36, 64}),
- Literal::CreateR1<float>({66, 138})),
+ *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
+ LiteralUtil::CreateR1<float>({36, 64}),
+ LiteralUtil::CreateR1<float>({66, 138})),
*result));
}
@@ -401,14 +408,15 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(
- Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}),
- Literal::CreateR2<float>({{3, 7}, {11, 15}}),
- Literal::CreateR2<float>({{5, 16}, {36, 64}})),
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}),
+ LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
+ LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
*result));
}
@@ -434,14 +442,16 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(
- Literal::CreateR2<float>({{6, 8}, {10, 12}}),
- Literal::CreateR3<float>({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
- Literal::CreateR2<float>({{25, 36}, {49, 64}})),
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
+ LiteralUtil::CreateR3<float>(
+ {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
+ LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
*result));
}
@@ -468,14 +478,16 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(
- Literal::CreateR1<float>({14, 22}),
- Literal::CreateR3<float>({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
- Literal::CreateR3<float>(
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR1<float>({14, 22}),
+ LiteralUtil::CreateR3<float>(
+ {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
+ LiteralUtil::CreateR3<float>(
{{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})),
*result));
}
@@ -502,15 +514,16 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- auto init1 = Literal::CreateR0<float>(5);
- auto init2 = Literal::CreateR0<float>(6);
+ auto param =
+ LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ auto init1 = LiteralUtil::CreateR0<float>(5);
+ auto init2 = LiteralUtil::CreateR0<float>(6);
std::unique_ptr<Literal> result = ExecuteNoHloPasses(
std::move(module), {param.get(), init1.get(), init2.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(
- Literal::CreateR2<float>({{167, 172}, {176, 180}}),
- Literal::CreateR2<float>({{6, 6}, {6, 8}})),
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR2<float>({{167, 172}, {176, 180}}),
+ LiteralUtil::CreateR2<float>({{6, 6}, {6, 8}})),
*result));
}
@@ -537,19 +550,20 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<Eigen::half>(
+ auto param = LiteralUtil::CreateR3<Eigen::half>(
{{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
{{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(
- Literal::CreateR2<float>({{3, 7}, {11, 15}}),
- Literal::CreateR2<float>({{5, 16}, {36, 64}}),
- Literal::CreateR3<Eigen::half>({{{Eigen::half(1), Eigen::half(2)},
- {Eigen::half(3), Eigen::half(4)}},
- {{Eigen::half(5), Eigen::half(6)},
- {Eigen::half(7), Eigen::half(8)}}})),
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
+ LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}}),
+ LiteralUtil::CreateR3<Eigen::half>(
+ {{{Eigen::half(1), Eigen::half(2)},
+ {Eigen::half(3), Eigen::half(4)}},
+ {{Eigen::half(5), Eigen::half(6)},
+ {Eigen::half(7), Eigen::half(8)}}})),
*result));
}