diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-08-16 14:44:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 14:47:49 -0700 |
commit | d43820b9eff0cc863de2bbfb142afe92bf5afd00 (patch) | |
tree | e019bdc0bb52496436e0ff92d76b728165e3a420 /tensorflow/compiler/xla/service/hlo_instruction_test.cc | |
parent | 8235e83c442744a1285ce97e5dfc2a6556f9f667 (diff) |
Improve gather ergonomics by renaming fields.
This CL renames the various inputs to the Gather HLO to be more mnemonic by
making it more obviously a batch dynamic-slice. The replacements I made are:
s/elided_window_dims/collapsed_slice_dims/g
s/window_bounds/slice_sizes/g
s/gather_dims_to_operand_dims/start_index_map/g
s/gather_indices/start_indices/g
s/output_window_dims/offset_dims/g
PiperOrigin-RevId: 209051067
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction_test.cc | 66 |
1 files changed, 32 insertions, 34 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 8a694dde80..504b13043f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1355,7 +1355,7 @@ TEST_F(HloInstructionTest, Stringification) { TEST_F(HloInstructionTest, StringifyGather_0) { Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); - Shape gather_indices_tensor_shape = + Shape start_indices_tensor_shape = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); Shape gather_result_shape = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}); @@ -1363,19 +1363,18 @@ TEST_F(HloInstructionTest, StringifyGather_0) { HloComputation::Builder builder("Gather"); HloInstruction* input = builder.AddInstruction( HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); - HloInstruction* gather_indices = + HloInstruction* start_indices = builder.AddInstruction(HloInstruction::CreateParameter( - 1, gather_indices_tensor_shape, "gather_indices")); - - HloInstruction* gather_instruction = - builder.AddInstruction(HloInstruction::CreateGather( - gather_result_shape, input, gather_indices, - HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, - /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26})); + 1, start_indices_tensor_shape, "start_indices")); + + HloInstruction* gather_instruction = builder.AddInstruction( + HloInstruction::CreateGather(gather_result_shape, input, start_indices, + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*slice_sizes=*/{30, 29, 28, 27, 26})); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1383,15 +1382,15 @@ TEST_F(HloInstructionTest, StringifyGather_0) { EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " - "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), " - "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " - "gather_dims_to_operand_dims={0,1,2,3,4}, " - "index_vector_dim=4, window_bounds={30,29,28,27,26}"); + "s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), " + "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, " + "start_index_map={0,1,2,3,4}, " + "index_vector_dim=4, slice_sizes={30,29,28,27,26}"); } TEST_F(HloInstructionTest, StringifyGather_1) { Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); - Shape gather_indices_tensor_shape = + Shape start_indices_tensor_shape = ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); Shape gather_result_shape = ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}); @@ -1399,19 +1398,18 @@ TEST_F(HloInstructionTest, StringifyGather_1) { HloComputation::Builder builder("Gather"); HloInstruction* input = builder.AddInstruction( HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); - HloInstruction* gather_indices = + HloInstruction* start_indices = builder.AddInstruction(HloInstruction::CreateParameter( - 1, gather_indices_tensor_shape, "gather_indices")); - - HloInstruction* gather_instruction = - builder.AddInstruction(HloInstruction::CreateGather( - gather_result_shape, input, gather_indices, - HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, - /*index_vector_dim=*/2), - /*window_bounds=*/{30, 29, 28, 27, 26})); + 1, start_indices_tensor_shape, "start_indices")); + + HloInstruction* gather_instruction = builder.AddInstruction( + HloInstruction::CreateGather(gather_result_shape, input, start_indices, + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2), + /*slice_sizes=*/{30, 29, 28, 27, 26})); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1419,10 +1417,10 @@ TEST_F(HloInstructionTest, StringifyGather_1) { EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " - "s64[10,9,5,7,6]{4,3,2,1,0} %gather_indices), " - "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " - "gather_dims_to_operand_dims={0,1,2,3,4}, " - "index_vector_dim=2, window_bounds={30,29,28,27,26}"); + "s64[10,9,5,7,6]{4,3,2,1,0} %start_indices), " + "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, " + "start_index_map={0,1,2,3,4}, " + "index_vector_dim=2, slice_sizes={30,29,28,27,26}"); } TEST_F(HloInstructionTest, StringifyScatter) { |