diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference_test.cc | 33 |
1 files changed, 30 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 7cff042a48..8c731ae297 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -682,16 +682,43 @@ TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { TEST_F(ShapeInferenceTest, InferSliceShapeRank2) { Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); auto inferred_status = - ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}); + ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {1, 1}); ASSERT_IS_OK(inferred_status.status()); Shape inferred = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), inferred)); } +TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) { + Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); + auto inferred_status = + ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {2, 4}); + ASSERT_IS_OK(inferred_status.status()); + Shape inferred = inferred_status.ValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), inferred)); +} + +TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) { + Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); + auto inferred_status = + ShapeInference::InferSliceShape(matrix_shape, {15, 0}, {20, 13}, {2, 4}); + ASSERT_IS_OK(inferred_status.status()); + Shape inferred = inferred_status.ValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}), inferred)); +} + +TEST_F(ShapeInferenceTest, InferInvalidStride) { + Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); + auto inferred_status = + ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {0, 1}); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT, + inferred_status.status().code()); +} + TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) { Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); auto inferred_status = - ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}); + ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {1, 1}); ASSERT_FALSE(inferred_status.ok()); ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT, inferred_status.status().code()); @@ -700,7 +727,7 @@ TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) { TEST_F(ShapeInferenceTest, InferSliceShapeRank1) { Shape vector_shape = ShapeUtil::MakeShape(F32, {17}); auto inferred_status = - ShapeInference::InferSliceShape(vector_shape, {2}, {4}); + ShapeInference::InferSliceShape(vector_shape, {2}, {4}, {1}); ASSERT_TRUE(inferred_status.ok()); Shape inferred = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {2}))); |