aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/inliner_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/inliner_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/inliner_test.cc30
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index 7e967f035c..98e0f2cfd7 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -146,6 +146,36 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
+TEST_F(InlinerTest, MapParameter) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+
+ auto param_builder = HloComputation::Builder(TestName());
+ param_builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "p0"));
+ param_builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "p1"));
+ auto param_f32 = param_builder.Build();
+
+ auto builder = HloComputation::Builder("MapParamFunction");
+ auto lhs = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
+ auto rhs = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4)));
+ builder.AddInstruction(
+ HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, param_f32.get()));
+
+ auto computation = builder.Build();
+ auto hlo_module = CreateNewVerifiedModule();
+ hlo_module->AddEmbeddedComputation(std::move(param_f32));
+ hlo_module->AddEntryComputation(std::move(computation));
+
+ Inliner inliner;
+ EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs);
+
+ // Verify execution on CPU.
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
+ auto expected = LiteralUtil::CreateR0<float>(4);
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
+}
} // namespace
} // namespace xla