aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/dot_operation_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/dot_operation_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc245
1 files changed, 245 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 6b3efba4f8..efa5aed2d1 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -798,5 +798,250 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64,
this->error_spec_);
}
+TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) {
+ std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
+ {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+ std::unique_ptr<Array2D<float>> constant_rhs_array(
+ new Array2D<float>({{1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0},
+ {9.0, 8.0, 7.0},
+ {6.0, 5.0, 4.0},
+ {3.0, 2.0, 1.0}}));
+ // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}
+
+ XlaBuilder builder(TestName());
+ auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+ auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+ auto start_constant = builder.ConstantR1<int32>({1, 0});
+ auto dynamic_slice =
+ builder.DynamicSlice(lhs_constant, start_constant, {1, 6});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+
+ Array2D<float> expected({{96.0, 105.0, 114.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
+ std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
+ {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+ std::unique_ptr<Array2D<float>> constant_rhs_array(
+ new Array2D<float>({{1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0},
+ {9.0, 8.0, 7.0},
+ {6.0, 5.0, 4.0},
+ {3.0, 2.0, 1.0}}));
+ // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}
+
+ XlaBuilder builder(TestName());
+ auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+ auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+ auto start_constant = builder.ConstantR1<int32>({0, 1});
+ auto dynamic_slice =
+ builder.DynamicSlice(rhs_constant, start_constant, {6, 1});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+
+ Array2D<float> expected({{105.0}, {105.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
+ DotOfGatherOptimizationWithConstRHSReverseMM)))) {
+ std::unique_ptr<Array2D<float>> constant_lhs_array(
+ new Array2D<float>({{1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0},
+ {9.0, 8.0, 7.0},
+ {6.0, 5.0, 4.0},
+ {3.0, 2.0, 1.0}}));
+ std::unique_ptr<Array2D<float>> constant_rhs_array(new Array2D<float>(
+ {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+ // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}
+
+ XlaBuilder builder(TestName());
+ auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+ auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+ auto start_constant = builder.ConstantR1<int32>({0, 1});
+ auto dynamic_slice =
+ builder.DynamicSlice(lhs_constant, start_constant, {6, 1});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(1);
+ auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+
+ Array2D<float> expected({{105.0, 105.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
+ DotOfGatherOptimizationWithConstLHSReverseMM)))) {
+ std::unique_ptr<Array2D<float>> constant_lhs_array(
+ new Array2D<float>({{1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0},
+ {9.0, 8.0, 7.0},
+ {6.0, 5.0, 4.0},
+ {3.0, 2.0, 1.0}}));
+ std::unique_ptr<Array2D<float>> constant_rhs_array(new Array2D<float>(
+ {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+ // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}
+
+ XlaBuilder builder(TestName());
+ auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+ auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+ auto start_constant = builder.ConstantR1<int32>({1, 0});
+ auto dynamic_slice =
+ builder.DynamicSlice(rhs_constant, start_constant, {1, 6});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(1);
+ auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+
+ Array2D<float> expected({{96.0}, {105.0}, {114.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(
+ DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) {
+ std::unique_ptr<Array2D<float>> constant_lhs_array(
+ new Array2D<float>({{1.0, 2.0},
+ {3.0, 4.0},
+ {5.0, 6.0},
+ {6.0, 5.0},
+ {4.0, 3.0},
+ {2.0, 1.0}}));
+ std::unique_ptr<Array2D<float>> constant_rhs_array(
+ new Array2D<float>({{1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0},
+ {9.0, 8.0, 7.0},
+ {6.0, 5.0, 4.0},
+ {3.0, 2.0, 1.0}}));
+ // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}
+
+ XlaBuilder builder(TestName());
+ auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+ auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+ auto start_constant = builder.ConstantR1<int32>({0, 1});
+ auto dynamic_slice =
+ builder.DynamicSlice(lhs_constant, start_constant, {6, 1});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+
+ Array2D<float> expected({{126.0, 129.0, 132.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(
+ DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) {
+ std::unique_ptr<Array2D<float>> constant_lhs_array(
+ new Array2D<float>({{1.0, 2.0},
+ {3.0, 4.0},
+ {5.0, 6.0},
+ {6.0, 5.0},
+ {4.0, 3.0},
+ {2.0, 1.0}}));
+ std::unique_ptr<Array2D<float>> constant_rhs_array(
+ new Array2D<float>({{1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0},
+ {9.0, 8.0, 7.0},
+ {6.0, 5.0, 4.0},
+ {3.0, 2.0, 1.0}}));
+ // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}
+
+ XlaBuilder builder(TestName());
+ auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+ auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+ auto start_constant = builder.ConstantR1<int32>({0, 1});
+ auto dynamic_slice =
+ builder.DynamicSlice(rhs_constant, start_constant, {6, 1});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+
+ Array2D<float> expected({{129.0}, {129.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(
+ DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) {
+ std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
+ {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+ std::unique_ptr<Array2D<float>> constant_rhs_array(
+ new Array2D<float>({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0, 9.0, 8.0, 7.0},
+ {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+ // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}
+
+ XlaBuilder builder(TestName());
+ auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+ auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+ auto start_constant = builder.ConstantR1<int32>({1, 0});
+ auto dynamic_slice =
+ builder.DynamicSlice(lhs_constant, start_constant, {1, 6});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(1);
+ auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+
+ Array2D<float> expected({{56.0, 168.0, 91.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(
+ DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) {
+ std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
+ {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+ std::unique_ptr<Array2D<float>> constant_rhs_array(
+ new Array2D<float>({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0, 9.0, 8.0, 7.0},
+ {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+ // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}
+
+ XlaBuilder builder(TestName());
+ auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+ auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+ auto start_constant = builder.ConstantR1<int32>({1, 0});
+ auto dynamic_slice =
+ builder.DynamicSlice(rhs_constant, start_constant, {1, 6});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(1);
+ auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+
+ Array2D<float> expected({{168.0}, {168.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
} // namespace
} // namespace xla