aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/fusion_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/fusion_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/fusion_test.cc40
1 files changed, 40 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index dc64477935..607bcdd51e 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -799,6 +799,46 @@ ENTRY main {
*result));
}
+class FusionClientLibraryTest : public ClientLibraryTestBase {};  // Fixture.
+
+XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
+ // On the GPU backend, it's possible to have too many transposes within one
+ // fusion, causing the kernel to run out of shared memory and thus not
+ // compile. We want to check that doesn't happen.
+ //
+ // To do this, we create a computation that computes
+ //
+ // P0 + P0*P1*P1 + P0*P2*P2 ...
+ //
+ // where even parameters use l1's layout ({0, 1}) and odd ones l2's ({1, 0}).
+ //
+ // Our goal is to tempt the backend into creating one giant multi-output
+ // fusion for the whole computation, including the transposes. Currently
+ // multi-output fusion only fuses fusions, so each of the terms in the sum
+ // needs to be a fusion itself, thus the contortions above.
+ constexpr int kNumParams = 25;
+ XlaBuilder b("ManyLayoutTransformations");
+
+ // This test produces values that overflow int32, which is UB, so use uint32,
+ // where overflow is well-defined (it wraps around).
+ Array2D<uint32> arr(32, 32);
+ arr.FillUnique();
+ std::unique_ptr<Literal> l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
+ LayoutUtil::MakeLayout({0, 1}));
+
+ std::unique_ptr<Literal> l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
+ LayoutUtil::MakeLayout({1, 0}));
+
+ XlaOp p0 = AddParam(*l1, &b);
+ XlaOp sum = p0;
+ for (int i = 1; i < kNumParams; ++i) {
+ auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b);
+ sum = sum + p0 * pN * pN;
+ }
+
+ ComputeAndCompare(&b, {});
+}
+
void BM_ParallelFusion(int num_iters) {
// Simple element-wise computation to benchmark parallel task partitioning.
tensorflow::testing::StopTiming();