diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-08-16 14:44:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 14:47:49 -0700 |
commit | d43820b9eff0cc863de2bbfb142afe92bf5afd00 (patch) | |
tree | e019bdc0bb52496436e0ff92d76b728165e3a420 /tensorflow/compiler/xla/service/shape_inference_test.cc | |
parent | 8235e83c442744a1285ce97e5dfc2a6556f9f667 (diff) |
Improve gather ergonomics by renaming fields.
This CL renames the various inputs to the Gather HLO to be more mnemonic by
making it more obviously a batch dynamic-slice. The replacements I made are:
s/elided_window_dims/collapsed_slice_dims/g
s/window_bounds/slice_sizes/g
s/gather_dims_to_operand_dims/start_index_map/g
s/gather_indices/start_indices/g
s/output_window_dims/offset_dims/g
PiperOrigin-RevId: 209051067
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference_test.cc | 269 |
1 files changed, 131 insertions, 138 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index a73fa181cd..4ed8fc6b86 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -1654,11 +1654,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) { ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/1), - /*window_bounds=*/{64, 1})); + /*slice_sizes=*/{64, 1})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) << ShapeUtil::HumanString(gather_shape); @@ -1669,11 +1669,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) { ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{1}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, + /*offset_dims=*/{1}, + /*collapsed_slice_dims=*/{0}, + /*start_index_map=*/{0}, /*index_vector_dim=*/1), - /*window_bounds=*/{1, 48})); + /*slice_sizes=*/{1, 48})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) << ShapeUtil::HumanString(gather_shape); @@ -1684,11 +1684,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) { ShapeInference::InferGatherShape( matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, + /*offset_dims=*/{4}, + /*collapsed_slice_dims=*/{0}, + /*start_index_map=*/{0}, /*index_vector_dim=*/4), - /*window_bounds=*/{1, 48})); + /*slice_sizes=*/{1, 48})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) << ShapeUtil::HumanString(gather_shape); @@ -1700,11 +1700,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26})); + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}))) @@ -1717,11 +1717,11 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/2), - /*window_bounds=*/{30, 29, 28, 27, 26})); + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, @@ -1735,11 +1735,11 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/0), - /*window_bounds=*/{30, 29, 28, 27, 26})); + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, @@ -1749,16 +1749,15 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) { // This is equivalent to a dynamic slice. - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape( - f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, - HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0, 1, 2, 3, 4}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, - /*index_vector_dim=*/0), - /*window_bounds=*/{30, 29, 28, 27, 26})); + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{0, 1, 2, 3, 4}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0), + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}))) @@ -1772,11 +1771,11 @@ TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) { ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_scalar_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0, 1, 2, 3}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, + /*offset_dims=*/{0, 1, 2, 3}, + /*collapsed_slice_dims=*/{0}, + /*start_index_map=*/{0}, /*index_vector_dim=*/0), - /*window_bounds=*/{1, 30, 29, 28, 27})); + /*slice_sizes=*/{1, 30, 29, 28, 27})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {30, 29, 28, 27}))) @@ -1787,11 +1786,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/1), - /*window_bounds=*/{64, 1}); + /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Expected array argument for input")) @@ -1802,11 +1801,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( s64_vector_32_, tuple_shape_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/0), - /*window_bounds=*/{64, 1}); + /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Expected array argument for gather indices")) @@ -1817,11 +1816,11 @@ TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/0), - /*window_bounds=*/{64, 1}); + /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Gather indices parameter must be an integral tensor")) @@ -1833,11 +1832,11 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 8, 7}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 8, 7}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), @@ -1850,11 +1849,11 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 7}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 7}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), @@ -1867,14 +1866,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 99, 100, 101}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 99, 100, 101}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Window index 2 in gather op is out of bounds")) + HasSubstr("Offset dimension 2 in gather op is out of bounds")) << statusor.status(); } @@ -1883,14 +1882,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 9}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 9}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Window index 4 in gather op is out of bounds")) + HasSubstr("Offset dimension 4 in gather op is out of bounds")) << statusor.status(); } @@ -1899,16 +1898,16 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{4}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{4}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr("All components of the window index in a gather op must either " - "be a output window index or explicitly elided")) + HasSubstr("All components of the offset index in a gather op must either " + "be a offset dimension or explicitly collapsed")) << statusor.status(); } @@ -1917,14 +1916,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{0, 1, 2, 3, 19}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{0, 1, 2, 3, 19}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Invalid elided_window_dims set in gather op; valid " + HasSubstr("Invalid collapsed_slice_dims set in gather op; valid " "range is [0, 5), got: 19")) << statusor.status(); } @@ -1934,16 +1933,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{0, 1, 2, 3, 3}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{0, 1, 2, 3, 3}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT( - statusor.status().error_message(), - HasSubstr( - "Repeated dimensions not allowed in elided_window_dims in gather op")) + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Repeated dimensions not allowed in " + "collapsed_slice_dims in gather op")) << statusor.status(); } @@ -1952,17 +1950,16 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT( - statusor.status().error_message(), - HasSubstr("Gather op has 4 elements in gather_dims_to_operand_dims and " - "the bound of dimension index_vector_dim=4 of " - "gather_indices is 5. These two numbers must be equal.")) + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather op has 4 elements in start_index_map and " + "the bound of dimension index_vector_dim=4 of " + "start_indices is 5. These two numbers must be equal.")) << statusor.status(); } @@ -1971,16 +1968,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 7}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT( - statusor.status().error_message(), - HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is " - "[0, 5), got: 4->7")) + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7")) << statusor.status(); } @@ -1989,16 +1984,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 3}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr( - "Repeated dimensions are not allowed in gather_dims_to_operand_dims")) + HasSubstr("Repeated dimensions are not allowed in start_index_map")) << statusor.status(); } @@ -2007,14 +2001,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{2, 1}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{2, 1}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{1, 1, 28, 27, 26}); + /*slice_sizes=*/{1, 1, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("elided_window_dims in gather op must be sorted")) + HasSubstr("collapsed_slice_dims in gather op must be sorted")) << statusor.status(); } @@ -2023,15 +2017,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7}, - /*elided_window_dims=*/{2}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7}, + /*collapsed_slice_dims=*/{2}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 1, 300, 26}); + /*slice_sizes=*/{30, 29, 1, 300, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Window bound at index 3 in gather op is out of range, " - "must be within [0, 48), got 300")) + HasSubstr("Slice size at index 3 in gather op is out of range, " + "must be within [0, 48), got 300.")) << statusor.status(); } @@ -2040,16 +2034,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 26}); + /*slice_sizes=*/{30, 29, 28, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr( - "Gather op must have one window bound for every input dimension")) + HasSubstr("Gather op must have one slice size for every input dimension")) << statusor.status(); } @@ -2058,15 +2051,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 26, 20}); + /*slice_sizes=*/{30, 29, 28, 26, 20}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Gather op can only elide window indices with bound 1, " - "but bound is 29 for index 1 at position 0")) + HasSubstr("Gather op can only collapse slice dims with bound 1, " + "but bound is 29 for index 1 at position 0.")) << statusor.status(); } @@ -2074,16 +2067,16 @@ TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/32), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Gather index leaf dimension must be within [0, " - "rank(gather_indices) + 1)")) + "rank(start_indices) + 1)")) << statusor.status(); } |