diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/fusion_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/fusion_test.cc | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index dc64477935..607bcdd51e 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -799,6 +799,46 @@ ENTRY main { *result)); } +class FusionClientLibraryTest : public ClientLibraryTestBase {}; + +XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) { + // On the GPU backend, it's possible to have too many transposes within one + // fusion, causing the kernel to run out shared memory and thus not compile. + // We want to check that doesn't happen. + // + // To do this, we create a computation that computes + // + // P0 + P0*P1*P1 + P0*P2*P2 ... + // + // where even parameters have layout 1 and odd parameters have layout 2. + // + // Our goal is to tempt the backend into creating one giant multi-output + // fusion for the whole computation, including the transposes. Currently + // multi-output fusion only fuses fusions, so each of the terms in the sum + // needs to be a fusion itself, thus the contortions above. + constexpr int kNumParams = 25; + XlaBuilder b("ManyLayoutTransformations"); + + // This test produces values that overflow int32, which is UB, so use uint32, + // where overflow is OK. + Array2D<uint32> arr(32, 32); + arr.FillUnique(); + std::unique_ptr<Literal> l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout( + LayoutUtil::MakeLayout({0, 1})); + + std::unique_ptr<Literal> l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout( + LayoutUtil::MakeLayout({1, 0})); + + XlaOp p0 = AddParam(*l1, &b); + XlaOp sum = p0; + for (int i = 1; i < kNumParams; ++i) { + auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b); + sum = sum + p0 * pN * pN; + } + + ComputeAndCompare(&b, {}); +} + void BM_ParallelFusion(int num_iters) { // Simple element-wise computation to benchmark parallel task partitioning. tensorflow::testing::StopTiming(); |