aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc')
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc64
1 files changed, 22 insertions, 42 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
index f1bff6037b..a29496dec4 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
@@ -207,6 +207,28 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
xla::ErrorSpec(1e-2, 1e-2));
}
+XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) {
+ xla::XlaBuilder builder(TestName());
+
+ xla::XlaOp a, b;
+ auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
+ auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/true,
+ /*transpose_a=*/false, /*conjugate_a=*/false,
+ /*block_size=*/3);
+
+ xla::Array2D<float> expected({
+ {0.5, 1.0, 1.5},
+ {0.41666667, 0.33333333, 0.25},
+ {0.23148148, 0.18518519, 0.13888889},
+ {0.16835017, 0.13468013, 0.1010101},
+ });
+
+ ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
+ xla::ErrorSpec(1e-2, 1e-2));
+}
+
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
xla::XlaBuilder builder(TestName());
@@ -307,47 +329,5 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
xla::ErrorSpec(1e-2, 1e-2));
}
-XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) {
- xla::XlaBuilder builder(TestName());
-
- xla::XlaOp a, b;
- auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
- auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- TriangularSolveLeftLooking(a, b,
- /*transpose_a=*/false,
- /*conjugate_a=*/false);
-
- xla::Array2D<float> expected({
- {0.5, 1.0, 1.5},
- {0.41666667, 0.33333333, 0.25},
- {0.23148148, 0.18518519, 0.13888889},
- {0.16835017, 0.13468013, 0.1010101},
- });
-
- ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
- xla::ErrorSpec(1e-2, 1e-2));
-}
-
-XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) {
- xla::XlaBuilder builder(TestName());
-
- xla::XlaOp a, b;
- auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a);
- auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- TriangularSolveLeftLooking(a, b,
- /*transpose_a=*/false,
- /*conjugate_a=*/false);
-
- xla::Array2D<float> expected({
- {0.5, 1.0, 1.5},
- {0.41666667, 0.33333333, 0.25},
- {0.23148148, 0.18518519, 0.13888889},
- {0.16835017, 0.13468013, 0.1010101},
- });
-
- ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
- xla::ErrorSpec(1e-2, 1e-2));
-}
-
} // namespace
} // namespace tensorflow