diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/dot_operation_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/dot_operation_test.cc | 245 |
1 files changed, 245 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 6b3efba4f8..efa5aed2d1 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -798,5 +798,250 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, this->error_spec_); } +TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) { + std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr<Array2D<float>> constant_rhs_array( + new Array2D<float>({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1<int32>({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D<float> expected({{96.0, 105.0, 114.0}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { + std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr<Array2D<float>> constant_rhs_array( + new Array2D<float>({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1<int32>({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D<float> expected({{105.0}, {105.0}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstRHSReverseMM)))) { + std::unique_ptr<Array2D<float>> constant_lhs_array( + new Array2D<float>({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + std::unique_ptr<Array2D<float>> constant_rhs_array(new Array2D<float>( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1<int32>({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D<float> expected({{105.0, 105.0}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstLHSReverseMM)))) { + std::unique_ptr<Array2D<float>> constant_lhs_array( + new Array2D<float>({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + std::unique_ptr<Array2D<float>> constant_rhs_array(new Array2D<float>( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1<int32>({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D<float> expected({{96.0}, {105.0}, {114.0}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) { + std::unique_ptr<Array2D<float>> constant_lhs_array( + new Array2D<float>({{1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0}, + {6.0, 5.0}, + {4.0, 3.0}, + {2.0, 1.0}})); + std::unique_ptr<Array2D<float>> constant_rhs_array( + new Array2D<float>({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1<int32>({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D<float> expected({{126.0, 129.0, 132.0}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) { + std::unique_ptr<Array2D<float>> constant_lhs_array( + new Array2D<float>({{1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0}, + {6.0, 5.0}, + {4.0, 3.0}, + {2.0, 1.0}})); + std::unique_ptr<Array2D<float>> constant_rhs_array( + new Array2D<float>({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1<int32>({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D<float> expected({{129.0}, {129.0}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) { + std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr<Array2D<float>> constant_rhs_array( + new Array2D<float>({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0, 9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1<int32>({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D<float> expected({{56.0, 168.0, 91.0}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) { + std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr<Array2D<float>> constant_rhs_array( + new Array2D<float>({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0, 9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1<int32>({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D<float> expected({{168.0}, {168.0}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} } // namespace } // namespace xla |