diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-04-11 08:52:06 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-04-11 10:05:47 -0700 |
commit | 5c8acccfc9e90d694a8394f5522097bfe87379b2 (patch) | |
tree | 14a599c7ca3e73bcb09cd565a07da1aa87ef20d5 /tensorflow/compiler/xla/service/shape_inference_test.cc | |
parent | 315443bbb01719a9e417826c4a077c47056bc071 (diff) |
Using GMock matchers in XLA tests.
Change: 152823724
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference_test.cc | 225 |
1 files changed, 108 insertions, 117 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 6f968ded56..255d64b320 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -20,12 +20,15 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace { +using ::testing::ContainsRegex; + class ShapeInferenceTest : public ::testing::Test { protected: // Some handy scalar shapes. @@ -128,23 +131,21 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { auto inferred_status_error1 = ShapeInference::InferTernaryOpShape( TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH( - inferred_status_error1.status().error_message(), - testing::ContainsRegex("operands to select must be the same shape")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + ContainsRegex("operands to select must be the same shape")); auto inferred_status_error2 = ShapeInference::InferTernaryOpShape( TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex("pred operand must have PRED")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + ContainsRegex("pred operand must have PRED")); auto inferred_status_error3 = ShapeInference::InferTernaryOpShape( TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_, matrix_64_48_); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH( - inferred_status_error3.status().error_message(), - testing::ContainsRegex("with non-scalar predicate with dimensionality")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + ContainsRegex("with non-scalar predicate with dimensionality")); // Tuples have a TUPLE element type and cannot be the pred of a select. auto inferred_status_error4 = ShapeInference::InferTernaryOpShape( @@ -152,9 +153,8 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { ShapeUtil::MakeTupleShape({f32_, f32_}), ShapeUtil::MakeTupleShape({f32_, f32_})); ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_MATCH( - inferred_status_error4.status().error_message(), - testing::ContainsRegex("pred operand must have PRED element type")); + ASSERT_THAT(inferred_status_error4.status().error_message(), + ContainsRegex("pred operand must have PRED element type")); } TEST_F(ShapeInferenceTest, ClampAllMatrix) { @@ -298,8 +298,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) { operand_shape_, select_program_shape_, window_, source_shape_fail, init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_MATCH(inferred_status_fail.status().error_message(), - testing::ContainsRegex("source shape does not match")); + ASSERT_THAT(inferred_status_fail.status().error_message(), + ContainsRegex("source shape does not match")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { @@ -309,9 +309,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { operand_shape_, select_program_shape_fail, window_, source_shape_, init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_MATCH( - inferred_status_fail.status().error_message(), - testing::ContainsRegex("select function must take 2 parameters")); + ASSERT_THAT(inferred_status_fail.status().error_message(), + ContainsRegex("select function must take 2 parameters")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { @@ -321,8 +320,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { operand_shape_, select_program_shape_fail, window_, source_shape_, init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_MATCH(inferred_status_fail.status().error_message(), - testing::ContainsRegex("select function must have rank-0 PRED")); + ASSERT_THAT(inferred_status_fail.status().error_message(), + ContainsRegex("select function must have rank-0 PRED")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { @@ -332,8 +331,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { operand_shape_, select_program_shape_fail, window_, source_shape_, init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_MATCH(inferred_status_fail.status().error_message(), - testing::ContainsRegex("select function's first parameter")); + ASSERT_THAT(inferred_status_fail.status().error_message(), + ContainsRegex("select function's first parameter")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { @@ -343,8 +342,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { operand_shape_, select_program_shape_fail, window_, source_shape_, init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_MATCH(inferred_status_fail.status().error_message(), - testing::ContainsRegex("select function's second parameter")); + ASSERT_THAT(inferred_status_fail.status().error_message(), + ContainsRegex("select function's second parameter")); } TEST_F(ShapeInferenceTest, Convolve) { @@ -498,8 +497,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { auto inferred_status = ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); ASSERT_FALSE(inferred_status.ok()); - ASSERT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("each dimension exactly once")); + ASSERT_THAT(inferred_status.status().error_message(), + ContainsRegex("each dimension exactly once")); } TEST_F(ShapeInferenceTest, MapThatChangesElementType) { @@ -536,43 +535,42 @@ TEST_F(ShapeInferenceTest, Map) { auto no_args_error = ShapeInference::InferMapShape( {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); ASSERT_FALSE(no_args_error.ok()); - ASSERT_MATCH(no_args_error.status().error_message(), - testing::ContainsRegex("expects at least one argument")); + ASSERT_THAT(no_args_error.status().error_message(), + ContainsRegex("expects at least one argument")); auto args_diff_shapes_error = ShapeInference::InferMapShape( {&vector_32_, &vector_64_}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); ASSERT_FALSE(args_diff_shapes_error.ok()); - ASSERT_MATCH( - args_diff_shapes_error.status().error_message(), - testing::ContainsRegex("requires all operands to have the same shape")); + ASSERT_THAT(args_diff_shapes_error.status().error_message(), + ContainsRegex("requires all operands to have the same shape")); auto arity_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_)); ASSERT_FALSE(arity_error.ok()); - ASSERT_MATCH(arity_error.status().error_message(), - testing::ContainsRegex("function arity must match")); + ASSERT_THAT(arity_error.status().error_message(), + ContainsRegex("function arity must match")); auto output_shape_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_)); ASSERT_FALSE(output_shape_error.ok()); - ASSERT_MATCH(output_shape_error.status().error_message(), - testing::ContainsRegex("result has to be a scalar")); + ASSERT_THAT(output_shape_error.status().error_message(), + ContainsRegex("result has to be a scalar")); auto param_shape_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_)); ASSERT_FALSE(param_shape_error.ok()); - ASSERT_MATCH(param_shape_error.status().error_message(), - testing::ContainsRegex("parameter has to be a scalar")); + ASSERT_THAT(param_shape_error.status().error_message(), + ContainsRegex("parameter has to be a scalar")); auto param_element_type_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_, s32_}, f32_)); ASSERT_FALSE(param_element_type_error.ok()); - ASSERT_MATCH(param_element_type_error.status().error_message(), - testing::ContainsRegex("parameter type has to match argument")); + ASSERT_THAT(param_element_type_error.status().error_message(), + ContainsRegex("parameter type has to match argument")); Shape arg = ShapeUtil::MakeShape(F32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_); @@ -583,26 +581,26 @@ TEST_F(ShapeInferenceTest, Map) { auto inferred_status_error1 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH(inferred_status_error1.status().error_message(), - testing::ContainsRegex("arity must match number of arguments")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + ContainsRegex("arity must match number of arguments")); auto inferred_status_error2 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_)); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex("has to be a scalar")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + ContainsRegex("has to be a scalar")); auto inferred_status_error3 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_)); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH(inferred_status_error3.status().error_message(), - testing::ContainsRegex("has to be a scalar")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + ContainsRegex("has to be a scalar")); auto inferred_status_error5 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_)); ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_MATCH(inferred_status_error5.status().error_message(), - testing::ContainsRegex("parameter type has to match argument")); + ASSERT_THAT(inferred_status_error5.status().error_message(), + ContainsRegex("parameter type has to match argument")); } TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) { @@ -656,8 +654,8 @@ TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4}, to_apply); EXPECT_FALSE(inferred_status.ok()); - EXPECT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("out-of-bounds dimension")); + EXPECT_THAT(inferred_status.status().error_message(), + ContainsRegex("out-of-bounds dimension")); } TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { @@ -666,8 +664,8 @@ TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); - EXPECT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("take 2 parameters")); + EXPECT_THAT(inferred_status.status().error_message(), + ContainsRegex("take 2 parameters")); } TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { @@ -676,8 +674,8 @@ TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); - EXPECT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("first parameter shape differs")); + EXPECT_THAT(inferred_status.status().error_message(), + ContainsRegex("first parameter shape differs")); } TEST_F(ShapeInferenceTest, InferSliceShapeRank2) { @@ -819,8 +817,8 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) { auto inferred_status = ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, vector_32_, {}); ASSERT_FALSE(inferred_status.ok()); - ASSERT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("dot only supports rank")); + ASSERT_THAT(inferred_status.status().error_message(), + ContainsRegex("dot only supports rank")); } // 3D <dot> 2D: error @@ -828,8 +826,8 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { auto inferred_status = ShapeInference::InferBinaryOpShape( BINOP_DOT, ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, {}); ASSERT_FALSE(inferred_status.ok()); - ASSERT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("dot only supports rank")); + ASSERT_THAT(inferred_status.status().error_message(), + ContainsRegex("dot only supports rank")); } // vector <dot> vector -> scalar @@ -941,46 +939,43 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { auto inferred_status_error1 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, vec8, {}); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH(inferred_status_error1.status().error_message(), - testing::ContainsRegex("automatic")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + ContainsRegex("automatic")); // broadcast_dimension out of bounds for tensor's rank auto inferred_status_error2 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, vec8, {3}); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH( - inferred_status_error2.status().error_message(), - testing::ContainsRegex("broadcast dimension number .* too large")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + ContainsRegex("broadcast dimension number .* too large")); // broadcast_dimension doesn't match corresponding dimension auto inferred_status_error3 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, vec8, {0}); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH(inferred_status_error3.status().error_message(), - testing::ContainsRegex("broadcast dimension 0 mismatch")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + ContainsRegex("broadcast dimension 0 mismatch")); // broadcast_dimensions list too long auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2}); ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_MATCH( - inferred_status_error4.status().error_message(), - testing::ContainsRegex("size of broadcast_dimensions has to match")); + ASSERT_THAT(inferred_status_error4.status().error_message(), + ContainsRegex("size of broadcast_dimensions has to match")); // there's a dimension above the rank of the tensor auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0}); ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_MATCH( - inferred_status_error5.status().error_message(), - testing::ContainsRegex("broadcast dimension number .* too large")); + ASSERT_THAT(inferred_status_error5.status().error_message(), + ContainsRegex("broadcast dimension number .* too large")); // broadcasting dimensions don't match in this order auto inferred_status_error6 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1}); ASSERT_FALSE(inferred_status_error6.ok()); - ASSERT_MATCH(inferred_status_error6.status().error_message(), - testing::ContainsRegex("broadcast dimension 0 mismatch")); + ASSERT_THAT(inferred_status_error6.status().error_message(), + ContainsRegex("broadcast dimension 0 mismatch")); // The following two tests make sure that broadcasting dimensions are listed // in a proper (strictly increasing) order, even if the lower-rank array @@ -988,14 +983,14 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { auto inferred_status_error7 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0}); ASSERT_FALSE(inferred_status_error7.ok()); - ASSERT_MATCH(inferred_status_error7.status().error_message(), - testing::ContainsRegex("broadcast dimensions order is wrong")); + ASSERT_THAT(inferred_status_error7.status().error_message(), + ContainsRegex("broadcast dimensions order is wrong")); auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0}); ASSERT_FALSE(inferred_status_error8.ok()); - ASSERT_MATCH(inferred_status_error8.status().error_message(), - testing::ContainsRegex("broadcast dimensions order is wrong")); + ASSERT_THAT(inferred_status_error8.status().error_message(), + ContainsRegex("broadcast dimensions order is wrong")); } // Tests for the while instruction with proper shapes. @@ -1020,30 +1015,30 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) { auto inferred_status_error1 = ShapeInference::InferWhileShape(bad_shape_1, body, result_shape); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH(inferred_status_error1.status().error_message(), - testing::ContainsRegex("condition must take 1 arguments")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + ContainsRegex("condition must take 1 arguments")); auto bad_shape_2 = ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape); auto inferred_status_error2 = ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex("body must take 1 arguments")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + ContainsRegex("body must take 1 arguments")); auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_); auto inferred_status_error3 = ShapeInference::InferWhileShape(bad_shape_3, body, result_shape); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH(inferred_status_error3.status().error_message(), - testing::ContainsRegex("condition must return a boolean")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + ContainsRegex("condition must return a boolean")); auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_); auto inferred_status_error4 = ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape); ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_MATCH(inferred_status_error4.status().error_message(), - testing::ContainsRegex("parameter of condition and body")); + ASSERT_THAT(inferred_status_error4.status().error_message(), + ContainsRegex("parameter of condition and body")); } // Tests for the concatenate instruction with proper shapes. @@ -1073,49 +1068,45 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { auto inferred_status_error1 = ShapeInference::InferConcatOpShape({}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH( - inferred_status_error1.status().error_message(), - testing::ContainsRegex("Concatenate expects at least one argument")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + ContainsRegex("Concatenate expects at least one argument")); auto inferred_status_error2 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex( - "dimension to concatenate along out of bounds: -1")); + ASSERT_THAT( + inferred_status_error2.status().error_message(), + ContainsRegex("dimension to concatenate along out of bounds: -1")); auto inferred_status_error3 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH(inferred_status_error3.status().error_message(), - testing::ContainsRegex( - "dimension to concatenate along out of bounds: 1")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + ContainsRegex("dimension to concatenate along out of bounds: 1")); Shape tuple = ShapeUtil::MakeTupleShape({vector_32_}); auto inferred_status_error4 = ShapeInference::InferConcatOpShape( {&vector_32_, &tuple}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_MATCH( - inferred_status_error4.status().error_message(), - testing::ContainsRegex( - "Expected non-tuple argument for operand of concatenation.")); + ASSERT_THAT(inferred_status_error4.status().error_message(), + ContainsRegex( + "Expected non-tuple argument for operand of concatenation.")); const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); auto inferred_status_error5 = ShapeInference::InferConcatOpShape( {&vector_32_, &vector_s32}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_MATCH(inferred_status_error5.status().error_message(), - testing::ContainsRegex( - "cannot concatenate arrays with different element types")); + ASSERT_THAT( + inferred_status_error5.status().error_message(), + ContainsRegex("cannot concatenate arrays with different element types")); auto inferred_status_error6 = ShapeInference::InferConcatOpShape( {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error6.ok()); - ASSERT_MATCH( - inferred_status_error6.status().error_message(), - testing::ContainsRegex("cannot concatenate arrays that differ in " - "dimensions other than the one being " - "concatenated")); + ASSERT_THAT(inferred_status_error6.status().error_message(), + ContainsRegex("cannot concatenate arrays that differ in " + "dimensions other than the one being " + "concatenated")); } TEST_F(ShapeInferenceTest, Pad) { @@ -1156,27 +1147,27 @@ TEST_F(ShapeInferenceTest, ReverseInvalidDimension) { auto inferred_status_error0 = ShapeInference::InferReverseShape(input_shape, {0, 2}); ASSERT_FALSE(inferred_status_error0.ok()); - ASSERT_MATCH(inferred_status_error0.status().error_message(), - testing::ContainsRegex("out-of-bounds")); + ASSERT_THAT(inferred_status_error0.status().error_message(), + ContainsRegex("out-of-bounds")); auto inferred_status_error1 = ShapeInference::InferReverseShape(input_shape, {0, -1}); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH(inferred_status_error1.status().error_message(), - testing::ContainsRegex("out-of-bounds")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + ContainsRegex("out-of-bounds")); auto inferred_status_error2 = ShapeInference::InferReverseShape(input_shape, {0, 0}); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex("duplicated")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + ContainsRegex("duplicated")); Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape}); auto inferred_status_error3 = ShapeInference::InferReverseShape(tuple_shape, {0}); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH(inferred_status_error3.status().error_message(), - testing::ContainsRegex("Expected non-tuple argument")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + ContainsRegex("Expected non-tuple argument")); } TEST_F(ShapeInferenceTest, Call) { @@ -1196,20 +1187,20 @@ TEST_F(ShapeInferenceTest, Call) { auto inferred_status_error0 = ShapeInference::InferCallShape( {}, ShapeUtil::MakeProgramShape({f32_}, f32_)); EXPECT_FALSE(inferred_status_error0.ok()); - EXPECT_MATCH(inferred_status_error0.status().error_message(), - testing::ContainsRegex("arity must match")); + EXPECT_THAT(inferred_status_error0.status().error_message(), + ContainsRegex("arity must match")); auto inferred_status_error1 = ShapeInference::InferCallShape( {&f32_}, ShapeUtil::MakeProgramShape({}, f32_)); EXPECT_FALSE(inferred_status_error1.ok()); - EXPECT_MATCH(inferred_status_error1.status().error_message(), - testing::ContainsRegex("arity must match")); + EXPECT_THAT(inferred_status_error1.status().error_message(), + ContainsRegex("arity must match")); auto inferred_status_error2 = ShapeInference::InferCallShape( {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_)); EXPECT_FALSE(inferred_status_error2.ok()); - EXPECT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex("parameter must match argument")); + EXPECT_THAT(inferred_status_error2.status().error_message(), + ContainsRegex("parameter must match argument")); } TEST_F(ShapeInferenceTest, Transpose) { |