aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc')
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc158
1 files changed, 63 insertions, 95 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
index 87ea4763f7..a29496dec4 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -85,11 +85,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/false, /*lower=*/true,
- /*transpose_a=*/true, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/false, /*lower=*/true,
+ /*transpose_a=*/true, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{0.5, 0.08333334, 0.04629629, 0.03367003},
@@ -107,11 +106,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/false, /*lower=*/true,
- /*transpose_a=*/false, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/false, /*lower=*/true,
+ /*transpose_a=*/false, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{-0.16414141, -0.06902357, -0.07070707, 0.36363636},
@@ -129,11 +127,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/false, /*lower=*/false,
- /*transpose_a=*/true, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/false, /*lower=*/false,
+ /*transpose_a=*/true, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{-0.16414141, -0.06902357, -0.07070707, 0.36363636},
@@ -151,11 +148,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/false, /*lower=*/false,
- /*transpose_a=*/false, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/false, /*lower=*/false,
+ /*transpose_a=*/false, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{0.5, 0.08333334, 0.04629629, 0.03367003},
@@ -173,11 +169,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/true, /*lower=*/true,
- /*transpose_a=*/true, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/true,
+ /*transpose_a=*/true, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{-0.89646465, -0.69444444, -0.49242424},
@@ -196,11 +191,32 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/true, /*lower=*/true,
- /*transpose_a=*/false, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/true,
+ /*transpose_a=*/false, /*conjugate_a=*/false,
+ /*block_size=*/2);
+
+ xla::Array2D<float> expected({
+ {0.5, 1.0, 1.5},
+ {0.41666667, 0.33333333, 0.25},
+ {0.23148148, 0.18518519, 0.13888889},
+ {0.16835017, 0.13468013, 0.1010101},
+ });
+
+ ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
+ xla::ErrorSpec(1e-2, 1e-2));
+}
+
+XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) {
+ xla::XlaBuilder builder(TestName());
+
+ xla::XlaOp a, b;
+ auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
+ auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/true,
+ /*transpose_a=*/false, /*conjugate_a=*/false,
+ /*block_size=*/3);
xla::Array2D<float> expected({
{0.5, 1.0, 1.5},
@@ -219,11 +235,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/true, /*lower=*/false,
- /*transpose_a=*/true, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/false,
+ /*transpose_a=*/true, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{0.5, 1.0, 1.5},
@@ -242,11 +257,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/true, /*lower=*/false,
- /*transpose_a=*/false, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/false,
+ /*transpose_a=*/false, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{-0.89646465, -0.69444444, -0.49242424},
@@ -267,11 +281,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a", &builder, &a);
auto b_data =
CreateR2Parameter<complex64>(BValsRightComplex(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/false, /*lower=*/true,
- /*transpose_a=*/true, /*conjugate_a=*/true,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/false, /*lower=*/true,
+ /*transpose_a=*/true, /*conjugate_a=*/true,
+ /*block_size=*/2);
xla::Array2D<complex64> expected({
{0.5, complex64(0.08333333, 0.08333333),
@@ -295,11 +308,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
CreateR2Parameter<complex64>(AValsUpperComplex(), 0, "a", &builder, &a);
auto b_data =
CreateR2Parameter<complex64>(BValsLeftComplex(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/true, /*lower=*/false,
- /*transpose_a=*/true, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/false,
+ /*transpose_a=*/true, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<complex64> expected({
{0.5, 1., 1.5},
@@ -317,49 +329,5 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
xla::ErrorSpec(1e-2, 1e-2));
}
-XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) {
- xla::XlaBuilder builder(TestName());
-
- xla::XlaOp a, b;
- auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
- auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolveLeftLooking(&builder, a, b,
- /*transpose_a=*/false,
- /*conjugate_a=*/false);
- TF_ASSERT_OK(result.status());
-
- xla::Array2D<float> expected({
- {0.5, 1.0, 1.5},
- {0.41666667, 0.33333333, 0.25},
- {0.23148148, 0.18518519, 0.13888889},
- {0.16835017, 0.13468013, 0.1010101},
- });
-
- ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
- xla::ErrorSpec(1e-2, 1e-2));
-}
-
-XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) {
- xla::XlaBuilder builder(TestName());
-
- xla::XlaOp a, b;
- auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a);
- auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolveLeftLooking(&builder, a, b,
- /*transpose_a=*/false,
- /*conjugate_a=*/false);
- TF_ASSERT_OK(result.status());
-
- xla::Array2D<float> expected({
- {0.5, 1.0, 1.5},
- {0.41666667, 0.33333333, 0.25},
- {0.23148148, 0.18518519, 0.13888889},
- {0.16835017, 0.13468013, 0.1010101},
- });
-
- ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
- xla::ErrorSpec(1e-2, 1e-2));
-}
-
} // namespace
} // namespace tensorflow