diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/gather_operation_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/gather_operation_test.cc | 199 |
1 files changed, 118 insertions, 81 deletions
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index b8404826b1..2008d69237 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -22,9 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -// NB! TODO(b/74360564): These tests do not test out of bounds behavior since -// that hasn't been specced yet. - namespace xla { namespace { @@ -63,8 +61,9 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2}); + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> gather_indices = + LiteralUtil::CreateR1<int32>({0, 2}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -84,8 +83,9 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2}); + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> gather_indices = + LiteralUtil::CreateR1<int32>({0, 2}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -105,9 +105,9 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr<Literal> gather_indices = - Literal::CreateR2<int32>({{0, 2}, {2, 1}}); + LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -127,9 +127,9 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr<Literal> gather_indices = - Literal::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -149,9 +149,9 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr<Literal> gather_indices = - Literal::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -171,11 +171,11 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr<Literal> gather_indices = - Literal::CreateR2<int32>({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -195,11 +195,11 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr<Literal> gather_indices = - Literal::CreateR2<int32>({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -219,8 +219,9 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1}); + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> gather_indices = + LiteralUtil::CreateR1<int32>({1, 1}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -240,9 +241,9 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr<Literal> gather_indices = - Literal::CreateR2<int32>({{2, 1}, {1, 1}}); + LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -261,18 +262,15 @@ ENTRY main { window_bounds={1, 0} } )"; - std::unique_ptr<Literal> operand = Literal::CreateR2<int32>({{}, {}, {}}); - std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2}); + std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}}); + std::unique_ptr<Literal> gather_indices = + LiteralUtil::CreateR1<int32>({0, 2}); RunTest(hlo_text, operand.get(), gather_indices.get()); } XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) { // Out of bounds indices must not crash, and the indices in range should // produce the same values across all backends. - // - // TODO(b/74360564): Once we have a well defined semantics for OOB accesses, - // we should get rid of the mask and check that backends produce the same - // value for OOB indices too. const string hlo_text = R"( HloModule BatchDynamicSlice @@ -286,29 +284,45 @@ ENTRY main { gather_dims_to_operand_dims={0,1}, index_vector_dim=1, window_bounds={1,1} - gather_reshaped = s32[6]{0} reshape(gather) - in_bounds_mask = s32[6]{0} parameter(2) - ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask) + ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr<Literal> operand = - Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = Literal::CreateR2<int32>( + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - std::unique_ptr<Literal> in_bounds_mask = - Literal::CreateR1<int32>({0, 1, 1, 0, 0, 1}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) { + // Out of bounds indices must not crash, and the indices in range should + // produce the same values across all backends. - RunTest(hlo_text, - {operand.get(), gather_indices.get(), in_bounds_mask.get()}); + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = u32[6,2]{1,0} parameter(1) + gather = s32[6,1,1]{2,1,0} gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} + ROOT result = s32[6]{0} reshape(gather) +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<uint32>( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); } XLA_TEST_F(GatherOperationTest, NegativeIndex) { // Negative indices must not crash, and the indices in range should produce // the same values across all backends. - // - // TODO(b/74360564): Once we have a well defined semantics for negative - // accesses, we should get rid of the mask and check that backends produce the - // same value for negative indices too. const string hlo_text = R"( HloModule BatchDynamicSlice @@ -322,20 +336,40 @@ ENTRY main { gather_dims_to_operand_dims={0,1}, index_vector_dim=1, window_bounds={1,1} - gather_reshaped = s32[6]{0} reshape(gather) - in_bounds_mask = s32[6]{0} parameter(2) - ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask) + ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr<Literal> operand = - Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = Literal::CreateR2<int32>( + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - std::unique_ptr<Literal> in_bounds_mask = - Literal::CreateR1<int32>({0, 1, 1, 0, 0, 1}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) { + // Negative indices must not crash, and the indices in range should produce + // the same values across all backends. - RunTest(hlo_text, - {operand.get(), gather_indices.get(), in_bounds_mask.get()}); + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = u32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + gather = u32[6,1,1]{2,1,0} gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} + ROOT result = u32[6]{0} reshape(gather) +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR2<uint32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>( + {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); } XLA_TEST_F(GatherOperationTest, OneScalarIndex) { @@ -353,9 +387,9 @@ ENTRY main { window_bounds={1,3,2} } )"; - std::unique_ptr<Literal> operand = Literal::CreateR3<int32>( + std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr<Literal> gather_indices = Literal::CreateR0<int32>(1); + std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -374,8 +408,8 @@ ENTRY main { window_bounds={1} } )"; - std::unique_ptr<Literal> operand = Literal::CreateR1<int32>({1, 2, 3, 4}); - std::unique_ptr<Literal> gather_indices = Literal::CreateR0<int32>(1); + std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4}); + std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -395,8 +429,8 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({}); + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR1<int32>({}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -419,8 +453,9 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2}); + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> gather_indices = + LiteralUtil::CreateR1<int32>({0, 2}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -443,9 +478,9 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr<Literal> gather_indices = - Literal::CreateR2<int32>({{0, 2}, {2, 1}}); + LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -468,9 +503,9 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr<Literal> gather_indices = - Literal::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -493,11 +528,11 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr<Literal> gather_indices = - Literal::CreateR2<int32>({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -521,11 +556,11 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr<Literal> gather_indices = - Literal::CreateR2<int32>({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -548,8 +583,9 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1}); + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> gather_indices = + LiteralUtil::CreateR1<int32>({1, 1}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -572,9 +608,9 @@ ENTRY main { } )"; std::unique_ptr<Literal> operand = - Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr<Literal> gather_indices = - Literal::CreateR2<int32>({{2, 1}, {1, 1}}); + LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -609,12 +645,13 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { Gather(operand, indices, dim_numbers, {1, 3}); std::vector<int32> expected = {}; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> operand_arg, - client_->TransferToServer(*Literal::CreateR2<int32>( - {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<GlobalData> operand_arg, + client_->TransferToServer( + *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<GlobalData> indices_arg, - client_->TransferToServer(*Literal::CreateR1<int32>({0, 2}))); + client_->TransferToServer(*LiteralUtil::CreateR1<int32>({0, 2}))); TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices, client_->GetDeviceHandles(1)); xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); |