diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/indexed_array_analysis_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/indexed_array_analysis_test.cc | 165 |
1 files changed, 165 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index fc2befe05b..5f4b42799b 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -799,5 +799,170 @@ ENTRY main { AssertArrayForRootExpressionIs(hlo_text, "%add"); } +TEST_F(IndexedArrayAnalysisTest, DotOpBasic_0) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + indices = s32[5] parameter(0) + dot_lhs = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"( +(scalar-indexed-const + (constant s32[3,3] s32[3,3] { + { 70, 80, 90 }, + { 158, 184, 210 }, + { 246, 288, 330 } }) + %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, DotOpBasic_1) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}}) + indices = s32[5] parameter(0) + dot_lhs = s32[3,5] gather(gather_operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3,1} + ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"( +(scalar-indexed-const + (constant s32[4,3] s32[4,3] { + { 84, 99, 114 }, + { 96, 114, 132 }, + { 108, 129, 150 }, + { 120, 144, 168 } }) + %indices 0->[1]))"); +} + +TEST_F(IndexedArrayAnalysisTest, DotOpBasic_2) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + indices = s32[5] parameter(0) + dot_rhs = s32[3,5] gather(gather_operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3,1} + ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"( +(scalar-indexed-const + (constant s32[4,4] s32[4,4] { + { 38, 44, 50, 56 }, + { 83, 98, 113, 128 }, + { 128, 152, 176, 200 }, + { 173, 206, 239, 272 } }) + %indices 1->[1]) +)"); +} + +TEST_F(IndexedArrayAnalysisTest, DotOpBasic_3) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + indices = s32[5] parameter(0) + dot_rhs = s32[5,3] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} + ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"( +(scalar-indexed-const + (constant s32[4,4] s32[4,4] { + { 14, 32, 50, 68 }, + { 32, 77, 122, 167 }, + { 50, 122, 194, 266 }, + { 68, 167, 266, 365 } }) + %indices 1->[0]) +)"); +} + +TEST_F(IndexedArrayAnalysisTest, DotOpWithBatch) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[2,3,2] constant(s32[2,3,2]{{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}}) + dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) + indices = s32[4] parameter(0) + dot_rhs = s32[2,3,4] gather(gather_operand, indices), + output_window_dims={0,1}, + elided_window_dims={2}, + gather_dims_to_operand_dims={2}, + index_vector_dim=1, + window_bounds={2,3,1} + ROOT dot = s32[2,2,4] dot(dot_lhs_constant, dot_rhs), + lhs_contracting_dims={2}, rhs_contracting_dims={1}, + lhs_batch_dims={0}, rhs_batch_dims={0} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"( +(scalar-indexed-const + (constant s32[2,2,2] s32[2,2,2] { + { { 22, 28 }, + { 49, 64 } }, + { { 220, 244 }, + { 301, 334 } } }) + %indices 3->[2]) +)"); +} + +TEST_F(IndexedArrayAnalysisTest, DotOpNegative) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}}) + indices = s32[2] parameter(0) + dot_lhs = s32[3,2] gather(gather_operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3,1} + ROOT dot = s32[3,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, "%dot"); +} + } // namespace } // namespace xla |