aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-02-26 10:17:15 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-26 10:21:06 -0800
commit3b08cd35bc108f48b4f63d73af7a53eb8a1169f9 (patch)
tree4acb6d08b8978c51499073b25ec187e3f0e57fc1 /tensorflow/compiler/xla/service/shape_inference_test.cc
parentf4e70be18b104fbb2efeefeb83bea190aec12727 (diff)
Generalize the gather_indices dimension that stores indices
This is now exposed as a index_vector_dim dimension number. Also fixed an off-by-one error in ValidateGatherDimensionNumbers in the expression computing output_shape_rank. PiperOrigin-RevId: 187040748
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc191
1 files changed, 154 insertions, 37 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 7eb120843f..029d2b3b86 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -1530,11 +1530,17 @@ TEST_F(ShapeInferenceTest, BadSlice) {
class GatherShapeInferenceTest : public ShapeInferenceTest {
protected:
+ const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});
+ const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5});
const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32});
const Shape s64_4d_tensor_10_9_8_7_1_ =
ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1});
const Shape s64_4d_tensor_10_9_8_7_5_ =
ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
+ const Shape s64_4d_tensor_5_10_9_7_6_ =
+ ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6});
+ const Shape s64_4d_tensor_10_9_5_7_6_ =
+ ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
const Shape f32_5d_tensor_50_49_48_47_46_ =
ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
@@ -1548,7 +1554,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGather) {
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{0},
/*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1}),
+ /*gather_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1),
/*window_bounds=*/{64, 1}));
EXPECT_TRUE(
ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32})))
@@ -1562,7 +1569,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) {
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{1},
/*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0}),
+ /*gather_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/1),
/*window_bounds=*/{1, 48}));
EXPECT_TRUE(
ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48})))
@@ -1576,7 +1584,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) {
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4},
/*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0}),
+ /*gather_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/4),
/*window_bounds=*/{1, 48}));
EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
@@ -1591,7 +1600,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
/*window_bounds=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(
gather_shape,
@@ -1599,12 +1609,85 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
<< ShapeUtil::HumanString(gather_shape);
}
+TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/2),
+ /*window_bounds=*/{30, 29, 28, 27, 26}));
+
+ EXPECT_TRUE(ShapeUtil::Equal(
+ gather_shape,
+ ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26})))
+ << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/0),
+ /*window_bounds=*/{30, 29, 28, 27, 26}));
+
+ EXPECT_TRUE(ShapeUtil::Equal(
+ gather_shape,
+ ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26})))
+ << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) {
+ // This is equivalent to a dynamic slice.
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{0, 1, 2, 3, 4},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/0),
+ /*window_bounds=*/{30, 29, 28, 27, 26}));
+
+ EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
+ ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26})))
+ << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) {
+ // The gather indices "tensor" is a scalar S here that's used to slice out
+ // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result.
+ TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{0, 1, 2, 3},
+ /*elided_window_dims=*/{0},
+ /*gather_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/0),
+ /*window_bounds=*/{1, 30, 29, 28, 27}));
+
+ EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
+ ShapeUtil::MakeShape(F32, {30, 29, 28, 27})))
+ << ShapeUtil::HumanString(gather_shape);
+}
+
TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
tuple_shape_, s64_vector_32_,
HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
/*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1}),
+ /*gather_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1),
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
@@ -1617,7 +1700,8 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
s64_vector_32_, tuple_shape_,
HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
/*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1}),
+ /*gather_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/0),
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
@@ -1625,25 +1709,13 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest, ScalarGatherIndicesInput) {
- StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
- s64_vector_32_, s32_,
- HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1}),
- /*window_bounds=*/{64, 1});
- ASSERT_FALSE(statusor.ok());
- EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Gather indices parameter must at least of rank 1"))
- << statusor.status();
-}
-
TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, vector_32_,
HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
/*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1}),
+ /*gather_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/0),
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
@@ -1658,7 +1730,8 @@ TEST_F(GatherShapeInferenceTest,
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 8, 7},
/*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
/*window_bounds=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
@@ -1674,7 +1747,8 @@ TEST_F(GatherShapeInferenceTest,
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 7},
/*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
/*window_bounds=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
@@ -1690,7 +1764,8 @@ TEST_F(GatherShapeInferenceTest,
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 99, 100, 101},
/*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
/*window_bounds=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
@@ -1699,13 +1774,30 @@ TEST_F(GatherShapeInferenceTest,
}
TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 9},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Window index 4 in gather op is out of bounds"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{4},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
/*window_bounds=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
@@ -1722,7 +1814,8 @@ TEST_F(GatherShapeInferenceTest,
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{0, 1, 2, 3, 19},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
/*window_bounds=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
@@ -1738,7 +1831,8 @@ TEST_F(GatherShapeInferenceTest,
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{0, 1, 2, 3, 3},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
/*window_bounds=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
@@ -1755,15 +1849,15 @@ TEST_F(GatherShapeInferenceTest,
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}),
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3},
+ /*index_vector_dim=*/4),
/*window_bounds=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
- HasSubstr(
- "There must be exactly as many elements in "
- "gather_dims_to_operand_dims "
- "as there are elements in the last dimension of %gather_indices"))
+ HasSubstr("Gather op has 4 elements in gather_dims_to_operand_dims and "
+ "the bound of dimension index_vector_dim=4 of "
+ "gather_indices is 5. These two numbers must be equal."))
<< statusor.status();
}
@@ -1774,7 +1868,8 @@ TEST_F(GatherShapeInferenceTest,
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}),
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7},
+ /*index_vector_dim=*/4),
/*window_bounds=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
@@ -1791,7 +1886,8 @@ TEST_F(GatherShapeInferenceTest,
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}),
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3},
+ /*index_vector_dim=*/4),
/*window_bounds=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
@@ -1808,7 +1904,8 @@ TEST_F(GatherShapeInferenceTest,
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{2, 1},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
/*window_bounds=*/{1, 1, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
@@ -1822,7 +1919,8 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) {
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7},
/*elided_window_dims=*/{2},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
/*window_bounds=*/{30, 29, 1, 300, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
@@ -1838,7 +1936,8 @@ TEST_F(GatherShapeInferenceTest,
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
/*window_bounds=*/{30, 29, 28, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
@@ -1855,7 +1954,8 @@ TEST_F(GatherShapeInferenceTest,
HloInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7},
/*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
/*window_bounds=*/{30, 29, 28, 26, 20});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
@@ -1864,5 +1964,22 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
+TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/32),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Gather index leaf dimension must be within [0, "
+ "rank(gather_indices) + 1)"))
+ << statusor.status();
+}
+
} // namespace
} // namespace xla