aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction_test.cc
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-02-26 10:17:15 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-26 10:21:06 -0800
commit3b08cd35bc108f48b4f63d73af7a53eb8a1169f9 (patch)
tree4acb6d08b8978c51499073b25ec187e3f0e57fc1 /tensorflow/compiler/xla/service/hlo_instruction_test.cc
parentf4e70be18b104fbb2efeefeb83bea190aec12727 (diff)
Generalize the gather_indices dimension that stores indices
This is now exposed as a index_vector_dim dimension number. Also fixed an off-by-one error in ValidateGatherDimensionNumbers in the expression computing output_shape_rank. PiperOrigin-RevId: 187040748
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc43
1 files changed, 40 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 32d3ed272b..f2980d309d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1271,7 +1271,7 @@ TEST_F(HloInstructionTest, Stringification) {
"true_computation=%TransposeDot, false_computation=%TransposeDot");
}
-TEST_F(HloInstructionTest, StringifyGather) {
+TEST_F(HloInstructionTest, StringifyGather_0) {
Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
Shape gather_indices_tensor_shape =
ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
@@ -1291,7 +1291,8 @@ TEST_F(HloInstructionTest, StringifyGather) {
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
/*window_bounds=*/{30, 29, 28, 27, 26}));
HloModule module(TestName());
@@ -1303,7 +1304,43 @@ TEST_F(HloInstructionTest, StringifyGather) {
"s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), "
"output_window_dims={4,5,6,7,8}, elided_window_dims={}, "
"gather_dims_to_operand_dims={0,1,2,3,4}, "
- "window_bounds={30,29,28,27,26}");
+ "index_vector_dim=4, window_bounds={30,29,28,27,26}");
+}
+
+TEST_F(HloInstructionTest, StringifyGather_1) {
+ Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
+ Shape gather_indices_tensor_shape =
+ ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
+ Shape gather_result_shape =
+ ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
+
+ HloComputation::Builder builder("Gather");
+ HloInstruction* input = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
+ HloInstruction* gather_indices =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, gather_indices_tensor_shape, "gather_indices"));
+
+ HloInstruction* gather_instruction =
+ builder.AddInstruction(HloInstruction::CreateGather(
+ gather_result_shape, input, gather_indices,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/2),
+ /*window_bounds=*/{30, 29, 28, 27, 26}));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(gather_instruction->ToString(),
+ "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
+ "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
+ "s64[10,9,5,7,6]{4,3,2,1,0} %gather_indices), "
+ "output_window_dims={4,5,6,7,8}, elided_window_dims={}, "
+ "gather_dims_to_operand_dims={0,1,2,3,4}, "
+ "index_vector_dim=2, window_bounds={30,29,28,27,26}");
}
} // namespace