aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-12 22:19:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-12 22:23:34 -0700
commit1b765165987f5277e294251c118f321166c70932 (patch)
treea3b879ec7d36389da1aa66a855aed762917140f8 /tensorflow/compiler/xla/service/shape_inference_test.cc
parent385bb78761489cfa6d6808f239fe884152e71653 (diff)
[XLA] Split out HloGatherInstruction as subclass from HloInstruction.
PiperOrigin-RevId: 204421652
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc122
1 files changed, 63 insertions, 59 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index bafe14d6f4..9b1ce143c6 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <string>
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -1543,45 +1544,45 @@ class GatherShapeInferenceTest : public ShapeInferenceTest {
};
TEST_F(GatherShapeInferenceTest, TensorFlowGather) {
- TF_ASSERT_OK_AND_ASSIGN(
- Shape gather_shape,
- ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
- HloInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
- /*index_vector_dim=*/1),
- /*window_bounds=*/{64, 1}));
+ TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ matrix_64_48_, s64_vector_32_,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1),
+ /*window_bounds=*/{64, 1}));
EXPECT_TRUE(
ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32})))
<< ShapeUtil::HumanString(gather_shape);
}
TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) {
- TF_ASSERT_OK_AND_ASSIGN(
- Shape gather_shape,
- ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
- HloInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{1},
- /*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0},
- /*index_vector_dim=*/1),
- /*window_bounds=*/{1, 48}));
+ TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ matrix_64_48_, s64_vector_32_,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{1},
+ /*elided_window_dims=*/{0},
+ /*gather_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/1),
+ /*window_bounds=*/{1, 48}));
EXPECT_TRUE(
ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48})))
<< ShapeUtil::HumanString(gather_shape);
}
TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) {
- TF_ASSERT_OK_AND_ASSIGN(
- Shape gather_shape,
- ShapeInference::InferGatherShape(matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
- HloInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4},
- /*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0},
- /*index_vector_dim=*/4),
- /*window_bounds=*/{1, 48}));
+ TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4},
+ /*elided_window_dims=*/{0},
+ /*gather_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/4),
+ /*window_bounds=*/{1, 48}));
EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
<< ShapeUtil::HumanString(gather_shape);
@@ -1592,7 +1593,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
Shape gather_shape,
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1609,7 +1610,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
Shape gather_shape,
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1627,7 +1628,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
Shape gather_shape,
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1646,7 +1647,7 @@ TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) {
Shape gather_shape,
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{0, 1, 2, 3, 4},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1664,7 +1665,7 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) {
TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{0, 1, 2, 3},
/*elided_window_dims=*/{0},
/*gather_dims_to_operand_dims=*/{0},
@@ -1679,10 +1680,11 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) {
TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
tuple_shape_, s64_vector_32_,
- HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
- /*index_vector_dim=*/1),
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1),
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
@@ -1693,10 +1695,11 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, tuple_shape_,
- HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
- /*index_vector_dim=*/0),
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/0),
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
@@ -1707,10 +1710,11 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, vector_32_,
- HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
- /*index_vector_dim=*/0),
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/0),
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
@@ -1722,7 +1726,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_NonAscendingWindowIndices) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 8, 7},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1739,7 +1743,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_RepeatedWindowIndices) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 7},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1756,7 +1760,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 99, 100, 101},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1772,7 +1776,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 9},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1788,7 +1792,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{4},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1806,7 +1810,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{0, 1, 2, 3, 19},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1823,7 +1827,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{0, 1, 2, 3, 3},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1841,7 +1845,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3},
@@ -1860,7 +1864,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7},
@@ -1878,7 +1882,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3},
@@ -1896,7 +1900,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{2, 1},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1911,7 +1915,7 @@ TEST_F(GatherShapeInferenceTest,
TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7},
/*elided_window_dims=*/{2},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1928,7 +1932,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1946,7 +1950,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7},
/*elided_window_dims=*/{1},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1962,7 +1966,7 @@ TEST_F(GatherShapeInferenceTest,
TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},