diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/custom_call_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/custom_call_test.cc | 50 |
1 files changed, 48 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index a693fa3595..001490c6a8 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -105,8 +105,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_); } -XLA_TEST_F(CustomCallTest, - DISABLED_ON_GPU(CustomCall_UsedInOtherComputations)) { +XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { auto module = CreateNewModule(); auto b = HloComputation::Builder(TestName()); @@ -130,6 +129,53 @@ XLA_TEST_F(CustomCallTest, Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result); } +XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { + auto module = CreateNewModule(); + auto b = HloComputation::Builder(TestName()); + + auto input = + b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p")); + b.AddInstruction( + HloInstruction::CreateCustomCall(r2f32_, {input}, "Add1ToValues")); + + module->AddEntryComputation(b.Build()); + ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0})); + ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1})); + + Literal argument = LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}}); + + // Note, the expected result is transposed! This is because the input and + // output layouts of the custom call differ and the called function just + // blindly adds one to each element. + Literal result = ExecuteAndTransfer(std::move(module), {&argument}); + LiteralTestUtil::ExpectR2Equal<float>({{2.f, 4.f}, {3.f, 5.f}}, result); +} + +XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { + // The argument and result of the computation are set to different layouts, + // but the custom call is layout constrained to a fixed operand and result + // layout, so the correct result should be produced. + auto module = CreateNewModule(); + auto b = HloComputation::Builder(TestName()); + + auto input = + b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p")); + + const Shape& r2f32_dim0_major = + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); + b.AddInstruction(HloInstruction::CreateCustomCall( + r2f32_dim0_major, {input}, "Add1ToValues", {r2f32_dim0_major})); + + module->AddEntryComputation(b.Build()); + ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0})); + ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1})); + + Literal argument = LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}}); + + Literal result = ExecuteAndTransfer(std::move(module), {&argument}); + LiteralTestUtil::ExpectR2Equal<float>({{2.f, 3.f}, {4.f, 5.f}}, result); +} + class CustomCallClientAPITest : public ClientLibraryTestBase {}; // When using the client API, CustomCall targets can't begin with '$' -- these |