diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/dot_operation_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/dot_operation_test.cc | 121 |
1 files changed, 66 insertions, 55 deletions
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 33d79aebb1..d86fd7cc2d 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -67,15 +67,16 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { XlaOp param; auto param_data = CreateParameterAndTransferLiteral( 0, - *Literal::MakeTuple({Literal::CreateR2<float>({{1, 2}, {3, 4}}).get(), - Literal::CreateR2<float>({{5, 6}, {7, 8}}).get()}), + *LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}).get(), + LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}}).get()}), "arg0", &builder, ¶m); auto lhs = GetTupleElement(param, 0); auto rhs = GetTupleElement(param, 1); Dot(lhs, rhs); ComputeAndCompareLiteral(&builder, - *Literal::CreateR2<float>({{19, 22}, {43, 50}}), + *LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}), {param_data.get()}); } @@ -194,11 +195,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) { auto lhs_handle = this->client_ - ->TransferToServer(*Literal::CreateR2FromArray2D<T>( + ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>( {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}})) .ConsumeValueOrDie(); auto rhs_handle = this->client_ - ->TransferToServer(*Literal::CreateR2FromArray2D<T>( + ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>( {{1.0f}, {2.0f}, {3.0f}, {4.0f}})) .ConsumeValueOrDie(); @@ -217,14 +218,14 @@ class SquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateFromArrayWithLayout<T>( + ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>( {{1.0f, 2.0f}, {3.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateFromArrayWithLayout<T>( + ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>( {{1.0f, 6.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -286,9 +287,10 @@ void ParametricDotTest::TestImpl() { std::unique_ptr<Array2D<NativeT>> dot_lhs_data = MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k); - std::unique_ptr<Literal> dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout( - *dot_lhs_data, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); + std::unique_ptr<Literal> dot_lhs_lit = + LiteralUtil::CreateR2FromArray2DWithLayout( + *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor( + param.dot_lhs_row_major))); std::unique_ptr<GlobalData> dot_lhs_handle = client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); @@ -297,7 +299,7 @@ void ParametricDotTest::TestImpl() { Layout rhs_layout = LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.dot_rhs_row_major)); std::unique_ptr<Literal> dot_rhs_lit = - Literal::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); + LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); std::unique_ptr<GlobalData> dot_rhs_handle = client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); @@ -307,7 +309,7 @@ void ParametricDotTest::TestImpl() { if (param.has_addend) { addend_data = MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.n); - addend_lit = Literal::CreateR2FromArray2DWithLayout( + addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout( *addend_data, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.addend_row_major))); addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); @@ -476,14 +478,14 @@ class NonsquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateFromArrayWithLayout<T>( + ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>( {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateFromArrayWithLayout<T>( + ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>( {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -510,12 +512,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); } XLA_TEST_F(DotOperationTest, MatrixVectorC64) { auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateR2WithLayout<complex64>( + ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>( {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateR2WithLayout<complex64>( + ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>( {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); @@ -583,7 +585,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); auto x_data = this->client_ - ->TransferToServer(*Literal::CreateR4FromArray4D<T>( + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>( {{{{1000.0f, 100.0f}, {10.0f, 1.0f}}, {{2000.0f, 200.0f}, {20.0f, 2.0f}}}, {{{3000.0f, 300.0f}, {30.0f, 3.0f}}, @@ -591,7 +593,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*Literal::CreateR4FromArray4D<T>( + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>( {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, {{{11.0f, 22.0f}, {33.0f, 44.0f}}, {{55.0f, 66.0f}, {77.0f, 88.0f}}}})) @@ -629,13 +631,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { auto x_data = this->client_ - ->TransferToServer(*Literal::CreateR3FromArray3D<T>( + ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*Literal::CreateR3FromArray3D<T>( + ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>( {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}})) .ConsumeValueOrDie(); @@ -664,15 +666,17 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { } auto lhs_handle = this->client_ - ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout<T>( - *lhs, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(row_major)))) + ->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout<T>( + *lhs, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); auto rhs_handle = this->client_ - ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout<T>( - *rhs, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(row_major)))) + ->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout<T>( + *rhs, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); XlaBuilder builder(this->TestName()); @@ -733,15 +737,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D<T>(*arg_0_value_array))); + *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D<T>(*arg_1_value_array))); + *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D<T>(*arg_2_value_array))); + *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array))); Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}}); this->template ComputeAndCompareR2<T>( @@ -782,15 +786,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D<T>(*arg_0_value_array))); + *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D<T>(*arg_1_value_array))); + *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D<T>(*arg_2_value_array))); + *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array))); Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}}); this->template ComputeAndCompareR2<T>( @@ -853,10 +857,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstRHSReverseMM)))) { + + DotOfGatherOptimizationWithConstRHSReverseMM) { std::unique_ptr<Array2D<float>> constant_lhs_array( new Array2D<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, @@ -883,10 +886,7 @@ XLA_TEST_F(DotOperationTest, ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstLHSReverseMM)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) { std::unique_ptr<Array2D<float>> constant_lhs_array( new Array2D<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, @@ -913,10 +913,7 @@ XLA_TEST_F(DotOperationTest, ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstRHSRows)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) { std::unique_ptr<Array2D<float>> constant_lhs_array( new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}, @@ -948,10 +945,7 @@ XLA_TEST_F(DotOperationTest, ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstLHSRows)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) { std::unique_ptr<Array2D<float>> constant_lhs_array( new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}, @@ -983,10 +977,7 @@ XLA_TEST_F(DotOperationTest, ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstRHSCols)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) { std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); std::unique_ptr<Array2D<float>> constant_rhs_array( @@ -1010,10 +1001,7 @@ XLA_TEST_F(DotOperationTest, ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstLHSCols)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) { std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); std::unique_ptr<Array2D<float>> constant_rhs_array( @@ -1036,5 +1024,28 @@ XLA_TEST_F(DotOperationTest, Array2D<float> expected({{168.0}, {168.0}}); ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); } + +XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) { + XlaBuilder builder(TestName()); + + Array2D<float> lhs_array({{1.0f, 2.0f}, {3.0f, 4.0f}}); + auto lhs_constant = ConstantR2FromArray2D(&builder, lhs_array); + + Array2D<float> rhs_array({{5.0f, 6.0f}, {7.0f, 8.0f}}); + auto rhs_constant = ConstantR2FromArray2D(&builder, rhs_array); + + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + DotGeneral(lhs_constant, rhs_constant, dot_dnums); + + Array2D<float> expected({ + {26.f, 30.f}, + {38.f, 44.f}, + }); + + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} } // namespace } // namespace xla |