diff options
author | 2018-03-02 13:37:41 -0800 | |
---|---|---|
committer | 2018-03-02 13:41:16 -0800 | |
commit | 4b038da7006c81e3e6cd542a7015d4a84d5c2385 (patch) | |
tree | c8c93bfe396177b45922834deb8dc03b5b20f0c6 /tensorflow/compiler/xla/service/shape_inference_test.cc | |
parent | 1e2c2f1cddd52ed86f8d5d7f10faa6498f13dded (diff) |
Make shape inference error messages more consistent.
PiperOrigin-RevId: 187662562
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference_test.cc | 59 |
1 files changed, 29 insertions, 30 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 029d2b3b86..0e61994a78 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -135,7 +135,7 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - HasSubstr("operands to select must be the same shape")); + HasSubstr("Operands to select must be the same shape")); auto inferred_status_error2 = ShapeInference::InferTernaryOpShape( TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_); @@ -340,7 +340,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("source shape does not match")); + HasSubstr("Source shape does not match")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { @@ -351,7 +351,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("select function must take 2 parameters")); + HasSubstr("Select function must take 2 parameters")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { @@ -362,7 +362,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("select function must have rank-0 PRED")); + HasSubstr("Select function must have rank-0 PRED")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { @@ -373,7 +373,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("select function's first parameter")); + HasSubstr("Select function's first parameter")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { @@ -384,7 +384,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("select function's second parameter")); + HasSubstr("Select function's second parameter")); } TEST_F(ShapeInferenceTest, Convolve) { @@ -906,7 +906,7 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) { ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("dot only supports rank")); + HasSubstr("Dot only supports rank")); } // 3D <dot> 2D: error @@ -918,7 +918,7 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("batch and contracting dimension number mismatch")); + HasSubstr("Batch and contracting dimension number mismatch")); } // vector <dot> vector -> scalar @@ -1024,7 +1024,7 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("must specify one contracting dimension for both " + HasSubstr("Must specify one contracting dimension for both " "lhs and rhs")); } @@ -1044,7 +1044,7 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("batch dimension numbers and sizes must match")); + HasSubstr("Batch dimension numbers and sizes must match")); } // BatchMatMul with different batch dimension numbers fails. @@ -1063,7 +1063,7 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("batch dimension numbers must precede non-batch")); + HasSubstr("Batch dimension numbers must precede non-batch")); } // BatchMatMul with out-of-range dimension numbers fails. @@ -1166,42 +1166,42 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { BinaryOperation::BINOP_ADD, tensor, vec8, {}); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - HasSubstr("automatic")); + HasSubstr("Automatic")); // broadcast_dimension out of bounds for tensor's rank auto inferred_status_error2 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, vec8, {3}); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), - ContainsRegex("broadcast dimension number .* too large")); + ContainsRegex("Broadcast dimension number .* too large")); // broadcast_dimension doesn't match corresponding dimension auto inferred_status_error3 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, vec8, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("broadcast dimension 0 mismatch")); + HasSubstr("Broadcast dimension 0 mismatch")); // broadcast_dimensions list too long auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2}); ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT(inferred_status_error4.status().error_message(), - HasSubstr("size of broadcast_dimensions has to match")); + HasSubstr("broadcast_dimensions has to match")); // there's a dimension above the rank of the tensor auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0}); ASSERT_FALSE(inferred_status_error5.ok()); ASSERT_THAT(inferred_status_error5.status().error_message(), - ContainsRegex("broadcast dimension number .* too large")); + ContainsRegex("dimension number .* too large")); // broadcasting dimensions don't match in this order auto inferred_status_error6 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1}); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), - HasSubstr("broadcast dimension 0 mismatch")); + HasSubstr("dimension 0 mismatch")); // The following two tests make sure that broadcasting dimensions are listed // in a proper (strictly increasing) order, even if the lower-rank array @@ -1210,13 +1210,13 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0}); ASSERT_FALSE(inferred_status_error7.ok()); ASSERT_THAT(inferred_status_error7.status().error_message(), - HasSubstr("broadcast dimensions order is wrong")); + HasSubstr("dimensions order is wrong")); auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0}); ASSERT_FALSE(inferred_status_error8.ok()); ASSERT_THAT(inferred_status_error8.status().error_message(), - HasSubstr("broadcast dimensions order is wrong")); + HasSubstr("dimensions order is wrong")); } // Tests for the while instruction with proper shapes. @@ -1242,7 +1242,7 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) { ShapeInference::InferWhileShape(bad_shape_1, body, result_shape); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - HasSubstr("condition must take 1 arguments")); + HasSubstr("Condition must take 1 arguments")); auto bad_shape_2 = ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape); @@ -1250,14 +1250,14 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) { ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), - HasSubstr("body must take 1 arguments")); + HasSubstr("Body must take 1 arguments")); auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_); auto inferred_status_error3 = ShapeInference::InferWhileShape(bad_shape_3, body, result_shape); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("condition must return a boolean")); + HasSubstr("Condition must return a boolean")); auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_); auto inferred_status_error4 = @@ -1301,13 +1301,13 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), - HasSubstr("dimension to concatenate along out of bounds: -1")); + HasSubstr("dimension out of bounds: -1")); auto inferred_status_error3 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("dimension to concatenate along out of bounds: 1")); + HasSubstr("dimension out of bounds: 1")); Shape tuple = ShapeUtil::MakeTupleShape({vector_32_}); auto inferred_status_error4 = ShapeInference::InferConcatOpShape( @@ -1315,21 +1315,20 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT( inferred_status_error4.status().error_message(), - HasSubstr("Expected non-tuple argument for operand of concatenation.")); + HasSubstr("Expected non-tuple argument for operand of concatenation")); const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); auto inferred_status_error5 = ShapeInference::InferConcatOpShape( {&vector_32_, &vector_s32}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_THAT( - inferred_status_error5.status().error_message(), - HasSubstr("cannot concatenate arrays with different element types")); + ASSERT_THAT(inferred_status_error5.status().error_message(), + HasSubstr("concatenate arrays with different element types")); auto inferred_status_error6 = ShapeInference::InferConcatOpShape( {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), - HasSubstr("cannot concatenate arrays that differ in " + HasSubstr("concatenate arrays that differ in " "dimensions other than the one being " "concatenated")); } @@ -1467,7 +1466,7 @@ TEST_F(ShapeInferenceTest, Conditional) { ShapeUtil::MakeProgramShape({vector_64_}, f32_)); EXPECT_FALSE(inferred_status_error0.ok()); EXPECT_THAT(inferred_status_error0.status().error_message(), - HasSubstr("predicate must be a boolean")); + HasSubstr("Predicate must be a boolean")); auto inferred_status_error1 = ShapeInference::InferConditionalShape( pred_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_, |