diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc | 64 |
1 files changed, 22 insertions, 42 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc index f1bff6037b..a29496dec4 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -207,6 +207,28 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { xla::ErrorSpec(1e-2, 1e-2)); } +XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp a, b; + auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/3); + + xla::Array2D<float> expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { xla::XlaBuilder builder(TestName()); @@ -307,47 +329,5 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { xla::ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b; - auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); - auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); - TriangularSolveLeftLooking(a, b, - /*transpose_a=*/false, - /*conjugate_a=*/false); - - xla::Array2D<float> expected({ - {0.5, 1.0, 1.5}, - {0.41666667, 0.33333333, 0.25}, - {0.23148148, 0.18518519, 0.13888889}, - {0.16835017, 0.13468013, 0.1010101}, - }); - - ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - -XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b; - auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a); - auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); - TriangularSolveLeftLooking(a, b, - /*transpose_a=*/false, - /*conjugate_a=*/false); - - xla::Array2D<float> expected({ - {0.5, 1.0, 1.5}, - {0.41666667, 0.33333333, 0.25}, - {0.23148148, 0.18518519, 0.13888889}, - {0.16835017, 0.13468013, 0.1010101}, - }); - - ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - } // namespace } // namespace tensorflow |