about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/tests/custom_call_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/custom_call_test.cc')
-rw-r--r-- tensorflow/compiler/xla/tests/custom_call_test.cc | 50
1 files changed, 48 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
index a693fa3595..001490c6a8 100644
--- a/tensorflow/compiler/xla/tests/custom_call_test.cc
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -105,8 +105,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
}
-XLA_TEST_F(CustomCallTest,
- DISABLED_ON_GPU(CustomCall_UsedInOtherComputations)) {
+XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) {
auto module = CreateNewModule();
auto b = HloComputation::Builder(TestName());
@@ -130,6 +129,53 @@ XLA_TEST_F(CustomCallTest,
Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
}
+XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) {
+ auto module = CreateNewModule();
+ auto b = HloComputation::Builder(TestName());
+ // Entry computation: parameter "p" fed directly into the "Add1ToValues" call.
+ auto input =
+ b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p"));
+ b.AddInstruction(
+ HloInstruction::CreateCustomCall(r2f32_, {input}, "Add1ToValues"));
+ // Install the computation, then force differing parameter/result layouts.
+ module->AddEntryComputation(b.Build());
+ ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0}));
+ ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1}));
+
+ Literal argument = LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}});
+ // The expected result is the transpose of element-wise "argument + 1"
+ // (i.e. of {{2, 3}, {4, 5}}). This is because the input and output layouts
+ // of the custom call differ and the called function just blindly adds one
+ // to each element.
+ Literal result = ExecuteAndTransfer(std::move(module), {&argument});
+ LiteralTestUtil::ExpectR2Equal<float>({{2.f, 4.f}, {3.f, 5.f}}, result);
+}
+
+XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) {
+ // The argument and result of the computation are forced to different
+ // layouts, but the custom call is layout constrained to a fixed operand and
+ // result layout, so the untransposed "argument + 1" result is expected.
+ auto module = CreateNewModule();
+ auto b = HloComputation::Builder(TestName());
+
+ auto input =
+ b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p"));
+ // The same {1, 0}-layout shape is passed for the result and the one operand.
+ const Shape& r2f32_dim0_major =
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
+ b.AddInstruction(HloInstruction::CreateCustomCall(
+ r2f32_dim0_major, {input}, "Add1ToValues", {r2f32_dim0_major}));
+ // Install the computation, then force differing parameter/result layouts.
+ module->AddEntryComputation(b.Build());
+ ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0}));
+ ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1}));
+
+ Literal argument = LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}});
+ // Expect exactly element-wise "argument + 1", with no transposition.
+ Literal result = ExecuteAndTransfer(std::move(module), {&argument});
+ LiteralTestUtil::ExpectR2Equal<float>({{2.f, 3.f}, {4.f, 5.f}}, result);
+}
+
class CustomCallClientAPITest : public ClientLibraryTestBase {};
// When using the client API, CustomCall targets can't begin with '$' -- these