aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc33
1 files changed, 30 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 7cff042a48..8c731ae297 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -682,16 +682,43 @@ TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
auto inferred_status =
- ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64});
+ ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {1, 1});
ASSERT_IS_OK(inferred_status.status());
Shape inferred = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), inferred));
}
+TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
+ Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
+ auto inferred_status =
+ ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {2, 4});
+ ASSERT_IS_OK(inferred_status.status());
+ Shape inferred = inferred_status.ValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), inferred));
+}
+
+TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
+ Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
+ auto inferred_status =
+ ShapeInference::InferSliceShape(matrix_shape, {15, 0}, {20, 13}, {2, 4});
+ ASSERT_IS_OK(inferred_status.status());
+ Shape inferred = inferred_status.ValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}), inferred));
+}
+
+TEST_F(ShapeInferenceTest, InferInvalidStride) {
+ Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
+ auto inferred_status =
+ ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {0, 1});
+ ASSERT_FALSE(inferred_status.ok());
+ ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
+ inferred_status.status().code());
+}
+
TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
auto inferred_status =
- ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2});
+ ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {1, 1});
ASSERT_FALSE(inferred_status.ok());
ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
inferred_status.status().code());
@@ -700,7 +727,7 @@ TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
TEST_F(ShapeInferenceTest, InferSliceShapeRank1) {
Shape vector_shape = ShapeUtil::MakeShape(F32, {17});
auto inferred_status =
- ShapeInference::InferSliceShape(vector_shape, {2}, {4});
+ ShapeInference::InferSliceShape(vector_shape, {2}, {4}, {1});
ASSERT_TRUE(inferred_status.ok());
Shape inferred = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {2})));