about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/tests/dot_operation_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/dot_operation_test.cc')
-rw-r--r--    tensorflow/compiler/xla/tests/dot_operation_test.cc    97
1 file changed, 8 insertions(+), 89 deletions(-)
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index b72dd2707c..cf089d748d 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -177,15 +177,15 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major,
bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
+ ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
{{1.0, 2.0}, {3.0, -4.0}},
- LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
+ MinorToMajorForIsRowMajor(lhs_row_major)))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
+ ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
{{1.0, 6.0}, {7.0, -4.0}},
- LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
+ MinorToMajorForIsRowMajor(rhs_row_major)))
.ConsumeValueOrDie();
ComputationBuilder builder(client_, TestName());
@@ -277,62 +277,6 @@ XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFF) {
TestMatrixDot(260, 3, 520, false, false);
}
-XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x8) {
- TestMatrixDot(1, 8, 8, true, true);
-}
-
-XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x130x8) {
- TestMatrixDot(1, 130, 8, true, true);
-}
-
-XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x130) {
- TestMatrixDot(1, 8, 130, true, true);
-}
-
-XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x290x130) {
- TestMatrixDot(1, 290, 130, true, true);
-}
-
-XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_2x1x1) {
- TestMatrixDot(2, 1, 1, true, true);
-}
-
-XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_8x8x1) {
- TestMatrixDot(8, 8, 1, true, true);
-}
-
-XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x1x1) {
- TestMatrixDot(16, 1, 1, true, true);
-}
-
-XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x3x1) {
- TestMatrixDot(16, 3, 1, true, true);
-}
-
-XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_3x3x1) {
- TestMatrixDot(3, 3, 1, true, true);
-}
-
-XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_29x29x1) {
- TestMatrixDot(29, 29, 1, true, true);
-}
-
-XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x2) {
- TestMatrixDot(1, 8, 2, true, true);
-}
-
-XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x2x8) {
- TestMatrixDot(1, 2, 8, true, true);
-}
-
-XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1) {
- TestMatrixDot(259, 258, 1, true, true);
-}
-
-XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1_FT) {
- TestMatrixDot(259, 258, 1, false, true);
-}
-
XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) {
constexpr bool kLhsRowMajor = false;
constexpr bool kRhsRowMajor = false;
@@ -362,15 +306,15 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major,
bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
+ ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
{{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}},
- LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
+ MinorToMajorForIsRowMajor(lhs_row_major)))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
+ ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
{{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}},
- LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
+ MinorToMajorForIsRowMajor(rhs_row_major)))
.ConsumeValueOrDie();
ComputationBuilder builder(client_, TestName());
@@ -417,31 +361,6 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) {
TestNonsquareMatrixDot<complex64>();
}
-XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
- auto lhs_handle =
- client_
- ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
- {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
- .ConsumeValueOrDie();
- auto rhs_handle =
- client_
- ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
- {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
- LayoutUtil::MakeLayout({1, 0})))
- .ConsumeValueOrDie();
-
- ComputationBuilder builder(client_, TestName());
- auto prim_type = primitive_util::NativeToPrimitiveType<complex64>();
- auto result = builder.Dot(
- builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"),
- builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs"));
-
- Array2D<complex64> expected({{30.0, -2.0}});
-
- ComputeAndCompareR2<complex64>(
- &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
-}
-
XLA_TEST_F(DotOperationTest, ConcurrentMatMul) {
ComputationBuilder builder(client_, TestName());
auto matrix1 = builder.ConstantR2<float>({{1.0, 2.0}, {3.0, 4.0}});