diff options
author | 2018-07-12 22:19:33 -0700 | |
---|---|---|
committer | 2018-07-12 22:23:34 -0700 | |
commit | 1b765165987f5277e294251c118f321166c70932 (patch) | |
tree | a3b879ec7d36389da1aa66a855aed762917140f8 /tensorflow/compiler/xla/service/shape_inference_test.cc | |
parent | 385bb78761489cfa6d6808f239fe884152e71653 (diff) |
[XLA] Split out HloGatherInstruction as subclass from HloInstruction.
PiperOrigin-RevId: 204421652
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference_test.cc | 122 |
1 files changed, 63 insertions, 59 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index bafe14d6f4..9b1ce143c6 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include <string> +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -1543,45 +1544,45 @@ class GatherShapeInferenceTest : public ShapeInferenceTest { }; TEST_F(GatherShapeInferenceTest, TensorFlowGather) { - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, - HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/1), - /*window_bounds=*/{64, 1})); + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + matrix_64_48_, s64_vector_32_, + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), + /*window_bounds=*/{64, 1})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, - HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{1}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, - /*index_vector_dim=*/1), - /*window_bounds=*/{1, 48})); + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + matrix_64_48_, s64_vector_32_, + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{1}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1), + /*window_bounds=*/{1, 48})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape(matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, - HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, - /*index_vector_dim=*/4), - /*window_bounds=*/{1, 48})); + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4), + /*window_bounds=*/{1, 48})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) << ShapeUtil::HumanString(gather_shape); @@ -1592,7 +1593,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1609,7 +1610,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1627,7 +1628,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1646,7 +1647,7 @@ TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{0, 1, 2, 3, 4}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1664,7 +1665,7 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_scalar_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{0, 1, 2, 3}, /*elided_window_dims=*/{0}, /*gather_dims_to_operand_dims=*/{0}, @@ -1679,10 +1680,11 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/1), + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1693,10 +1695,11 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( s64_vector_32_, tuple_shape_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/0), + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1707,10 +1710,11 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/0), + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1722,7 +1726,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingWindowIndices) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 8, 7}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1739,7 +1743,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowIndices) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 7}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1756,7 +1760,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexOutOfBounds) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 99, 100, 101}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1772,7 +1776,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 9}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1788,7 +1792,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingElidedWindowDims) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{4}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1806,7 +1810,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{0, 1, 2, 3, 19}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1823,7 +1827,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{0, 1, 2, 3, 3}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1841,7 +1845,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingGatherToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}, @@ -1860,7 +1864,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}, @@ -1878,7 +1882,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedGatherToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}, @@ -1896,7 +1900,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingElidedWindowDims) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{2, 1}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1911,7 +1915,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7}, /*elided_window_dims=*/{2}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1928,7 +1932,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1946,7 +1950,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7}, /*elided_window_dims=*/{1}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1962,7 +1966,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, |