diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-04-12 09:29:33 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-04-12 10:50:37 -0700 |
commit | 9992074410d0b8d7102b7a63ff5f01a1a4554357 (patch) | |
tree | 9d4be04123492819014119cc98f8beafecfefd63 /tensorflow/compiler/xla/service/shape_inference_test.cc | |
parent | 26332f733e3432b815e2d3dbd5a2f123a38af4e7 (diff) |
Simplified XLA tests by using gmock matchers.
Change: 152954991
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference_test.cc | 108 |
1 files changed, 54 insertions, 54 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 255d64b320..7cff042a48 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -28,6 +28,7 @@ namespace xla { namespace { using ::testing::ContainsRegex; +using ::testing::HasSubstr; class ShapeInferenceTest : public ::testing::Test { protected: @@ -132,20 +133,20 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - ContainsRegex("operands to select must be the same shape")); + HasSubstr("operands to select must be the same shape")); auto inferred_status_error2 = ShapeInference::InferTernaryOpShape( TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), - ContainsRegex("pred operand must have PRED")); + HasSubstr("pred operand must have PRED")); auto inferred_status_error3 = ShapeInference::InferTernaryOpShape( TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_, matrix_64_48_); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - ContainsRegex("with non-scalar predicate with dimensionality")); + HasSubstr("with non-scalar predicate with dimensionality")); // Tuples have a TUPLE element type and cannot be the pred of a select. auto inferred_status_error4 = ShapeInference::InferTernaryOpShape( @@ -154,7 +155,7 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { ShapeUtil::MakeTupleShape({f32_, f32_})); ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT(inferred_status_error4.status().error_message(), - ContainsRegex("pred operand must have PRED element type")); + HasSubstr("pred operand must have PRED element type")); } TEST_F(ShapeInferenceTest, ClampAllMatrix) { @@ -299,7 +300,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - ContainsRegex("source shape does not match")); + HasSubstr("source shape does not match")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { @@ -310,7 +311,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - ContainsRegex("select function must take 2 parameters")); + HasSubstr("select function must take 2 parameters")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { @@ -321,7 +322,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - ContainsRegex("select function must have rank-0 PRED")); + HasSubstr("select function must have rank-0 PRED")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { @@ -332,7 +333,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - ContainsRegex("select function's first parameter")); + HasSubstr("select function's first parameter")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { @@ -343,7 +344,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - ContainsRegex("select function's second parameter")); + HasSubstr("select function's second parameter")); } TEST_F(ShapeInferenceTest, Convolve) { @@ -498,7 +499,7 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - ContainsRegex("each dimension exactly once")); + HasSubstr("each dimension exactly once")); } TEST_F(ShapeInferenceTest, MapThatChangesElementType) { @@ -536,41 +537,41 @@ TEST_F(ShapeInferenceTest, Map) { {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); ASSERT_FALSE(no_args_error.ok()); ASSERT_THAT(no_args_error.status().error_message(), - ContainsRegex("expects at least one argument")); + HasSubstr("expects at least one argument")); auto args_diff_shapes_error = ShapeInference::InferMapShape( {&vector_32_, &vector_64_}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); ASSERT_FALSE(args_diff_shapes_error.ok()); ASSERT_THAT(args_diff_shapes_error.status().error_message(), - ContainsRegex("requires all operands to have the same shape")); + HasSubstr("requires all operands to have the same shape")); auto arity_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_)); ASSERT_FALSE(arity_error.ok()); ASSERT_THAT(arity_error.status().error_message(), - ContainsRegex("function arity must match")); + HasSubstr("function arity must match")); auto output_shape_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_)); ASSERT_FALSE(output_shape_error.ok()); ASSERT_THAT(output_shape_error.status().error_message(), - ContainsRegex("result has to be a scalar")); + HasSubstr("result has to be a scalar")); auto param_shape_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_)); ASSERT_FALSE(param_shape_error.ok()); ASSERT_THAT(param_shape_error.status().error_message(), - ContainsRegex("parameter has to be a scalar")); + HasSubstr("parameter has to be a scalar")); auto param_element_type_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_, s32_}, f32_)); ASSERT_FALSE(param_element_type_error.ok()); ASSERT_THAT(param_element_type_error.status().error_message(), - ContainsRegex("parameter type has to match argument")); + HasSubstr("parameter type has to match argument")); Shape arg = ShapeUtil::MakeShape(F32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_); @@ -582,25 +583,25 @@ TEST_F(ShapeInferenceTest, Map) { {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - ContainsRegex("arity must match number of arguments")); + HasSubstr("arity must match number of arguments")); auto inferred_status_error2 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_)); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), - ContainsRegex("has to be a scalar")); + HasSubstr("has to be a scalar")); auto inferred_status_error3 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_)); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - ContainsRegex("has to be a scalar")); + HasSubstr("has to be a scalar")); auto inferred_status_error5 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_)); ASSERT_FALSE(inferred_status_error5.ok()); ASSERT_THAT(inferred_status_error5.status().error_message(), - ContainsRegex("parameter type has to match argument")); + HasSubstr("parameter type has to match argument")); } TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) { @@ -655,7 +656,7 @@ TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), - ContainsRegex("out-of-bounds dimension")); + HasSubstr("out-of-bounds dimension")); } TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { @@ -665,7 +666,7 @@ TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), - ContainsRegex("take 2 parameters")); + HasSubstr("take 2 parameters")); } TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { @@ -675,7 +676,7 @@ TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), - ContainsRegex("first parameter shape differs")); + HasSubstr("first parameter shape differs")); } TEST_F(ShapeInferenceTest, InferSliceShapeRank2) { @@ -818,7 +819,7 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) { ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, vector_32_, {}); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - ContainsRegex("dot only supports rank")); + HasSubstr("dot only supports rank")); } // 3D <dot> 2D: error @@ -827,7 +828,7 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { BINOP_DOT, ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, {}); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - ContainsRegex("dot only supports rank")); + HasSubstr("dot only supports rank")); } // vector <dot> vector -> scalar @@ -940,7 +941,7 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { BinaryOperation::BINOP_ADD, tensor, vec8, {}); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - ContainsRegex("automatic")); + HasSubstr("automatic")); // broadcast_dimension out of bounds for tensor's rank auto inferred_status_error2 = ShapeInference::InferBinaryOpShape( @@ -954,14 +955,14 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { BinaryOperation::BINOP_ADD, tensor, vec8, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - ContainsRegex("broadcast dimension 0 mismatch")); + HasSubstr("broadcast dimension 0 mismatch")); // broadcast_dimensions list too long auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2}); ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT(inferred_status_error4.status().error_message(), - ContainsRegex("size of broadcast_dimensions has to match")); + HasSubstr("size of broadcast_dimensions has to match")); // there's a dimension above the rank of the tensor auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( @@ -975,7 +976,7 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1}); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), - ContainsRegex("broadcast dimension 0 mismatch")); + HasSubstr("broadcast dimension 0 mismatch")); // The following two tests make sure that broadcasting dimensions are listed // in a proper (strictly increasing) order, even if the lower-rank array @@ -984,13 +985,13 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0}); ASSERT_FALSE(inferred_status_error7.ok()); ASSERT_THAT(inferred_status_error7.status().error_message(), - ContainsRegex("broadcast dimensions order is wrong")); + HasSubstr("broadcast dimensions order is wrong")); auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0}); ASSERT_FALSE(inferred_status_error8.ok()); ASSERT_THAT(inferred_status_error8.status().error_message(), - ContainsRegex("broadcast dimensions order is wrong")); + HasSubstr("broadcast dimensions order is wrong")); } // Tests for the while instruction with proper shapes. @@ -1016,7 +1017,7 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) { ShapeInference::InferWhileShape(bad_shape_1, body, result_shape); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - ContainsRegex("condition must take 1 arguments")); + HasSubstr("condition must take 1 arguments")); auto bad_shape_2 = ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape); @@ -1024,21 +1025,21 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) { ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), - ContainsRegex("body must take 1 arguments")); + HasSubstr("body must take 1 arguments")); auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_); auto inferred_status_error3 = ShapeInference::InferWhileShape(bad_shape_3, body, result_shape); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - ContainsRegex("condition must return a boolean")); + HasSubstr("condition must return a boolean")); auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_); auto inferred_status_error4 = ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape); ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT(inferred_status_error4.status().error_message(), - ContainsRegex("parameter of condition and body")); + HasSubstr("parameter of condition and body")); } // Tests for the concatenate instruction with proper shapes. @@ -1069,28 +1070,27 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { ShapeInference::InferConcatOpShape({}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - ContainsRegex("Concatenate expects at least one argument")); + HasSubstr("Concatenate expects at least one argument")); auto inferred_status_error2 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_THAT( - inferred_status_error2.status().error_message(), - ContainsRegex("dimension to concatenate along out of bounds: -1")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("dimension to concatenate along out of bounds: -1")); auto inferred_status_error3 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - ContainsRegex("dimension to concatenate along out of bounds: 1")); + HasSubstr("dimension to concatenate along out of bounds: 1")); Shape tuple = ShapeUtil::MakeTupleShape({vector_32_}); auto inferred_status_error4 = ShapeInference::InferConcatOpShape( {&vector_32_, &tuple}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_THAT(inferred_status_error4.status().error_message(), - ContainsRegex( - "Expected non-tuple argument for operand of concatenation.")); + ASSERT_THAT( + inferred_status_error4.status().error_message(), + HasSubstr("Expected non-tuple argument for operand of concatenation.")); const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); auto inferred_status_error5 = ShapeInference::InferConcatOpShape( @@ -1098,15 +1098,15 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { ASSERT_FALSE(inferred_status_error5.ok()); ASSERT_THAT( inferred_status_error5.status().error_message(), - ContainsRegex("cannot concatenate arrays with different element types")); + HasSubstr("cannot concatenate arrays with different element types")); auto inferred_status_error6 = ShapeInference::InferConcatOpShape( {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), - ContainsRegex("cannot concatenate arrays that differ in " - "dimensions other than the one being " - "concatenated")); + HasSubstr("cannot concatenate arrays that differ in " + "dimensions other than the one being " + "concatenated")); } TEST_F(ShapeInferenceTest, Pad) { @@ -1148,26 +1148,26 @@ TEST_F(ShapeInferenceTest, ReverseInvalidDimension) { ShapeInference::InferReverseShape(input_shape, {0, 2}); ASSERT_FALSE(inferred_status_error0.ok()); ASSERT_THAT(inferred_status_error0.status().error_message(), - ContainsRegex("out-of-bounds")); + HasSubstr("out-of-bounds")); auto inferred_status_error1 = ShapeInference::InferReverseShape(input_shape, {0, -1}); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - ContainsRegex("out-of-bounds")); + HasSubstr("out-of-bounds")); auto inferred_status_error2 = ShapeInference::InferReverseShape(input_shape, {0, 0}); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), - ContainsRegex("duplicated")); + HasSubstr("duplicated")); Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape}); auto inferred_status_error3 = ShapeInference::InferReverseShape(tuple_shape, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - ContainsRegex("Expected non-tuple argument")); + HasSubstr("Expected non-tuple argument")); } TEST_F(ShapeInferenceTest, Call) { @@ -1188,19 +1188,19 @@ TEST_F(ShapeInferenceTest, Call) { {}, ShapeUtil::MakeProgramShape({f32_}, f32_)); EXPECT_FALSE(inferred_status_error0.ok()); EXPECT_THAT(inferred_status_error0.status().error_message(), - ContainsRegex("arity must match")); + HasSubstr("arity must match")); auto inferred_status_error1 = ShapeInference::InferCallShape( {&f32_}, ShapeUtil::MakeProgramShape({}, f32_)); EXPECT_FALSE(inferred_status_error1.ok()); EXPECT_THAT(inferred_status_error1.status().error_message(), - ContainsRegex("arity must match")); + HasSubstr("arity must match")); auto inferred_status_error2 = ShapeInference::InferCallShape( {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_)); EXPECT_FALSE(inferred_status_error2.ok()); EXPECT_THAT(inferred_status_error2.status().error_message(), - ContainsRegex("parameter must match argument")); + HasSubstr("parameter must match argument")); } TEST_F(ShapeInferenceTest, Transpose) { |