diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/dot_operation_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/dot_operation_test.cc | 97 |
1 files changed, 8 insertions, 89 deletions
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index b72dd2707c..cf089d748d 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -177,15 +177,15 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateR2WithLayout<Element>( + ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>( {{1.0, 2.0}, {3.0, -4.0}}, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) + MinorToMajorForIsRowMajor(lhs_row_major))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateR2WithLayout<Element>( + ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>( {{1.0, 6.0}, {7.0, -4.0}}, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) + MinorToMajorForIsRowMajor(rhs_row_major))) .ConsumeValueOrDie(); ComputationBuilder builder(client_, TestName()); @@ -277,62 +277,6 @@ XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFF) { TestMatrixDot(260, 3, 520, false, false); } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x8) { - TestMatrixDot(1, 8, 8, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x130x8) { - TestMatrixDot(1, 130, 8, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x130) { - TestMatrixDot(1, 8, 130, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x290x130) { - TestMatrixDot(1, 290, 130, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_2x1x1) { - TestMatrixDot(2, 1, 1, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_8x8x1) { - TestMatrixDot(8, 8, 1, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x1x1) { - TestMatrixDot(16, 1, 1, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x3x1) { - TestMatrixDot(16, 3, 1, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_3x3x1) { - TestMatrixDot(3, 3, 1, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_29x29x1) { - TestMatrixDot(29, 29, 1, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x2) { - TestMatrixDot(1, 8, 2, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x2x8) { - TestMatrixDot(1, 2, 8, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1) { - TestMatrixDot(259, 258, 1, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1_FT) { - TestMatrixDot(259, 258, 1, false, true); -} - XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { constexpr bool kLhsRowMajor = false; constexpr bool kRhsRowMajor = false; @@ -362,15 +306,15 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateR2WithLayout<Element>( + ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>( {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}}, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) + MinorToMajorForIsRowMajor(lhs_row_major))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateR2WithLayout<Element>( + ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>( {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}}, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) + MinorToMajorForIsRowMajor(rhs_row_major))) .ConsumeValueOrDie(); ComputationBuilder builder(client_, TestName()); @@ -417,31 +361,6 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) { TestNonsquareMatrixDot<complex64>(); } -XLA_TEST_F(DotOperationTest, MatrixVectorC64) { - auto lhs_handle = - client_ - ->TransferToServer(*Literal::CreateR2WithLayout<complex64>( - {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0}))) - .ConsumeValueOrDie(); - auto rhs_handle = - client_ - ->TransferToServer(*Literal::CreateR2WithLayout<complex64>( - {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, - LayoutUtil::MakeLayout({1, 0}))) - .ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto prim_type = primitive_util::NativeToPrimitiveType<complex64>(); - auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs")); - - Array2D<complex64> expected({{30.0, -2.0}}); - - ComputeAndCompareR2<complex64>( - &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); -} - XLA_TEST_F(DotOperationTest, ConcurrentMatMul) { ComputationBuilder builder(client_, TestName()); auto matrix1 = builder.ConstantR2<float>({{1.0, 2.0}, {3.0, 4.0}}); |