aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-02-16 15:29:35 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-16 15:33:12 -0800
commitba019dc689d6393d8dba04ca57e8b01b374db14f (patch)
tree6bd132cd6a1d3b6c8c833cb3e575db571ebd19a1 /tensorflow/compiler/xla/service/shape_inference_test.cc
parent1873ed4faab980ad239c06e8b92b8f4a85154fe3 (diff)
[XLA] Add some plumbing, documentation, verification and shape inference for Gather
Pretty much everything other than HLO verification and shape inference will fail for Gather with Unimplemented. Note that this CL is intentionally incomplete -- I figured it would be nicer to get some of the boiler-platey stuff out of the way early. Let me know if you want me to send in a larger but more complete CL instead. PiperOrigin-RevId: 186055521
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc341
1 files changed, 339 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 026c021165..7eb120843f 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -18,15 +18,16 @@ limitations under the License.
#include <string>
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace {
+using ::tensorflow::gtl::ArraySlice;
using ::testing::ContainsRegex;
using ::testing::HasSubstr;
@@ -1527,5 +1528,341 @@ TEST_F(ShapeInferenceTest, BadSlice) {
<< statusor.status();
}
+class GatherShapeInferenceTest : public ShapeInferenceTest {
+ protected:
+ const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32});
+ const Shape s64_4d_tensor_10_9_8_7_1_ =
+ ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1});
+ const Shape s64_4d_tensor_10_9_8_7_5_ =
+ ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
+ const Shape f32_5d_tensor_50_49_48_47_46_ =
+ ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
+ const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
+ {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_});
+};
+
+TEST_F(GatherShapeInferenceTest, TensorFlowGather) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape gather_shape,
+ ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1}),
+ /*window_bounds=*/{64, 1}));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32})))
+ << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape gather_shape,
+ ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{1},
+ /*elided_window_dims=*/{0},
+ /*gather_dims_to_operand_dims=*/{0}),
+ /*window_bounds=*/{1, 48}));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48})))
+ << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape gather_shape,
+ ShapeInference::InferGatherShape(matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4},
+ /*elided_window_dims=*/{0},
+ /*gather_dims_to_operand_dims=*/{0}),
+ /*window_bounds=*/{1, 48}));
+ EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
+ << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26}));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ gather_shape,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26})))
+ << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ tuple_shape_, s64_vector_32_,
+ HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1}),
+ /*window_bounds=*/{64, 1});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Expected non-tuple argument for input"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ s64_vector_32_, tuple_shape_,
+ HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1}),
+ /*window_bounds=*/{64, 1});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Expected non-tuple argument for gather indices"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest, ScalarGatherIndicesInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ s64_vector_32_, s32_,
+ HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1}),
+ /*window_bounds=*/{64, 1});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Gather indices parameter must at least of rank 1"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ s64_vector_32_, vector_32_,
+ HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1}),
+ /*window_bounds=*/{64, 1});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Gather indices parameter must be an integral tensor"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_NonAscendingWindowIndices) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 8, 7},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Output window dimensions in gather op must be ascending"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_RepeatedWindowIndices) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 7},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Output window dimensions in gather op must not repeat"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 99, 100, 101},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Window index 2 in gather op is out of bounds"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{4},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("All components of the window index in a gather op must either "
+ "be a output window index or explicitly elided"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{0, 1, 2, 3, 19},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Invalid elided_window_dims set in gather op; valid "
+ "range is [0, 5), got: 19"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{0, 1, 2, 3, 3},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Repeated dimensions not allowed in elided_window_dims in gather op"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "There must be exactly as many elements in "
+ "gather_dims_to_operand_dims "
+ "as there are elements in the last dimension of %gather_indices"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is "
+ "[0, 5), got: 4->7"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Repeated dimensions are not allowed in gather_dims_to_operand_dims"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{2, 1},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{1, 1, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("elided_window_dims in gather op must be sorted"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7},
+ /*elided_window_dims=*/{2},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 1, 300, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Window bound at index 3 in gather op is out of range, "
+ "must be within [0, 48), got 300"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Gather op must have one window bound for every input dimension"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 26, 20});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Gather op can only elide window indices with bound 1, "
+ "but bound is 29 for index 1 at position 0"))
+ << statusor.status();
+}
+
} // namespace
} // namespace xla