aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-11 08:52:06 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-11 10:05:47 -0700
commit5c8acccfc9e90d694a8394f5522097bfe87379b2 (patch)
tree14a599c7ca3e73bcb09cd565a07da1aa87ef20d5 /tensorflow/compiler/xla/service/shape_inference_test.cc
parent315443bbb01719a9e417826c4a077c47056bc071 (diff)
Using GMock matchers in XLA tests.
Change: 152823724
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc225
1 files changed, 108 insertions, 117 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 6f968ded56..255d64b320 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -20,12 +20,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
namespace {
+using ::testing::ContainsRegex;
+
class ShapeInferenceTest : public ::testing::Test {
protected:
// Some handy scalar shapes.
@@ -128,23 +131,21 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) {
auto inferred_status_error1 = ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_);
ASSERT_FALSE(inferred_status_error1.ok());
- ASSERT_MATCH(
- inferred_status_error1.status().error_message(),
- testing::ContainsRegex("operands to select must be the same shape"));
+ ASSERT_THAT(inferred_status_error1.status().error_message(),
+ ContainsRegex("operands to select must be the same shape"));
auto inferred_status_error2 = ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_);
ASSERT_FALSE(inferred_status_error2.ok());
- ASSERT_MATCH(inferred_status_error2.status().error_message(),
- testing::ContainsRegex("pred operand must have PRED"));
+ ASSERT_THAT(inferred_status_error2.status().error_message(),
+ ContainsRegex("pred operand must have PRED"));
auto inferred_status_error3 = ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}),
matrix_64_48_, matrix_64_48_);
ASSERT_FALSE(inferred_status_error3.ok());
- ASSERT_MATCH(
- inferred_status_error3.status().error_message(),
- testing::ContainsRegex("with non-scalar predicate with dimensionality"));
+ ASSERT_THAT(inferred_status_error3.status().error_message(),
+ ContainsRegex("with non-scalar predicate with dimensionality"));
// Tuples have a TUPLE element type and cannot be the pred of a select.
auto inferred_status_error4 = ShapeInference::InferTernaryOpShape(
@@ -152,9 +153,8 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) {
ShapeUtil::MakeTupleShape({f32_, f32_}),
ShapeUtil::MakeTupleShape({f32_, f32_}));
ASSERT_FALSE(inferred_status_error4.ok());
- ASSERT_MATCH(
- inferred_status_error4.status().error_message(),
- testing::ContainsRegex("pred operand must have PRED element type"));
+ ASSERT_THAT(inferred_status_error4.status().error_message(),
+ ContainsRegex("pred operand must have PRED element type"));
}
TEST_F(ShapeInferenceTest, ClampAllMatrix) {
@@ -298,8 +298,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) {
operand_shape_, select_program_shape_, window_, source_shape_fail,
init_value_shape_, scatter_program_shape_);
ASSERT_FALSE(inferred_status_fail.ok());
- ASSERT_MATCH(inferred_status_fail.status().error_message(),
- testing::ContainsRegex("source shape does not match"));
+ ASSERT_THAT(inferred_status_fail.status().error_message(),
+ ContainsRegex("source shape does not match"));
}
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
@@ -309,9 +309,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
operand_shape_, select_program_shape_fail, window_, source_shape_,
init_value_shape_, scatter_program_shape_);
ASSERT_FALSE(inferred_status_fail.ok());
- ASSERT_MATCH(
- inferred_status_fail.status().error_message(),
- testing::ContainsRegex("select function must take 2 parameters"));
+ ASSERT_THAT(inferred_status_fail.status().error_message(),
+ ContainsRegex("select function must take 2 parameters"));
}
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
@@ -321,8 +320,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
operand_shape_, select_program_shape_fail, window_, source_shape_,
init_value_shape_, scatter_program_shape_);
ASSERT_FALSE(inferred_status_fail.ok());
- ASSERT_MATCH(inferred_status_fail.status().error_message(),
- testing::ContainsRegex("select function must have rank-0 PRED"));
+ ASSERT_THAT(inferred_status_fail.status().error_message(),
+ ContainsRegex("select function must have rank-0 PRED"));
}
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) {
@@ -332,8 +331,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) {
operand_shape_, select_program_shape_fail, window_, source_shape_,
init_value_shape_, scatter_program_shape_);
ASSERT_FALSE(inferred_status_fail.ok());
- ASSERT_MATCH(inferred_status_fail.status().error_message(),
- testing::ContainsRegex("select function's first parameter"));
+ ASSERT_THAT(inferred_status_fail.status().error_message(),
+ ContainsRegex("select function's first parameter"));
}
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) {
@@ -343,8 +342,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) {
operand_shape_, select_program_shape_fail, window_, source_shape_,
init_value_shape_, scatter_program_shape_);
ASSERT_FALSE(inferred_status_fail.ok());
- ASSERT_MATCH(inferred_status_fail.status().error_message(),
- testing::ContainsRegex("select function's second parameter"));
+ ASSERT_THAT(inferred_status_fail.status().error_message(),
+ ContainsRegex("select function's second parameter"));
}
TEST_F(ShapeInferenceTest, Convolve) {
@@ -498,8 +497,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
auto inferred_status =
ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
ASSERT_FALSE(inferred_status.ok());
- ASSERT_MATCH(inferred_status.status().error_message(),
- testing::ContainsRegex("each dimension exactly once"));
+ ASSERT_THAT(inferred_status.status().error_message(),
+ ContainsRegex("each dimension exactly once"));
}
TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
@@ -536,43 +535,42 @@ TEST_F(ShapeInferenceTest, Map) {
auto no_args_error = ShapeInference::InferMapShape(
{}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
ASSERT_FALSE(no_args_error.ok());
- ASSERT_MATCH(no_args_error.status().error_message(),
- testing::ContainsRegex("expects at least one argument"));
+ ASSERT_THAT(no_args_error.status().error_message(),
+ ContainsRegex("expects at least one argument"));
auto args_diff_shapes_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_64_},
ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
ASSERT_FALSE(args_diff_shapes_error.ok());
- ASSERT_MATCH(
- args_diff_shapes_error.status().error_message(),
- testing::ContainsRegex("requires all operands to have the same shape"));
+ ASSERT_THAT(args_diff_shapes_error.status().error_message(),
+ ContainsRegex("requires all operands to have the same shape"));
auto arity_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_));
ASSERT_FALSE(arity_error.ok());
- ASSERT_MATCH(arity_error.status().error_message(),
- testing::ContainsRegex("function arity must match"));
+ ASSERT_THAT(arity_error.status().error_message(),
+ ContainsRegex("function arity must match"));
auto output_shape_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_32_},
ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_));
ASSERT_FALSE(output_shape_error.ok());
- ASSERT_MATCH(output_shape_error.status().error_message(),
- testing::ContainsRegex("result has to be a scalar"));
+ ASSERT_THAT(output_shape_error.status().error_message(),
+ ContainsRegex("result has to be a scalar"));
auto param_shape_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_32_},
ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_));
ASSERT_FALSE(param_shape_error.ok());
- ASSERT_MATCH(param_shape_error.status().error_message(),
- testing::ContainsRegex("parameter has to be a scalar"));
+ ASSERT_THAT(param_shape_error.status().error_message(),
+ ContainsRegex("parameter has to be a scalar"));
auto param_element_type_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_32_},
ShapeUtil::MakeProgramShape({f32_, s32_}, f32_));
ASSERT_FALSE(param_element_type_error.ok());
- ASSERT_MATCH(param_element_type_error.status().error_message(),
- testing::ContainsRegex("parameter type has to match argument"));
+ ASSERT_THAT(param_element_type_error.status().error_message(),
+ ContainsRegex("parameter type has to match argument"));
Shape arg = ShapeUtil::MakeShape(F32, {20});
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_);
@@ -583,26 +581,26 @@ TEST_F(ShapeInferenceTest, Map) {
auto inferred_status_error1 = ShapeInference::InferMapShape(
{&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
ASSERT_FALSE(inferred_status_error1.ok());
- ASSERT_MATCH(inferred_status_error1.status().error_message(),
- testing::ContainsRegex("arity must match number of arguments"));
+ ASSERT_THAT(inferred_status_error1.status().error_message(),
+ ContainsRegex("arity must match number of arguments"));
auto inferred_status_error2 = ShapeInference::InferMapShape(
{&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_));
ASSERT_FALSE(inferred_status_error2.ok());
- ASSERT_MATCH(inferred_status_error2.status().error_message(),
- testing::ContainsRegex("has to be a scalar"));
+ ASSERT_THAT(inferred_status_error2.status().error_message(),
+ ContainsRegex("has to be a scalar"));
auto inferred_status_error3 = ShapeInference::InferMapShape(
{&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_));
ASSERT_FALSE(inferred_status_error3.ok());
- ASSERT_MATCH(inferred_status_error3.status().error_message(),
- testing::ContainsRegex("has to be a scalar"));
+ ASSERT_THAT(inferred_status_error3.status().error_message(),
+ ContainsRegex("has to be a scalar"));
auto inferred_status_error5 = ShapeInference::InferMapShape(
{&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_));
ASSERT_FALSE(inferred_status_error5.ok());
- ASSERT_MATCH(inferred_status_error5.status().error_message(),
- testing::ContainsRegex("parameter type has to match argument"));
+ ASSERT_THAT(inferred_status_error5.status().error_message(),
+ ContainsRegex("parameter type has to match argument"));
}
TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) {
@@ -656,8 +654,8 @@ TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4},
to_apply);
EXPECT_FALSE(inferred_status.ok());
- EXPECT_MATCH(inferred_status.status().error_message(),
- testing::ContainsRegex("out-of-bounds dimension"));
+ EXPECT_THAT(inferred_status.status().error_message(),
+ ContainsRegex("out-of-bounds dimension"));
}
TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
@@ -666,8 +664,8 @@ TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_,
/*dimensions_to_reduce=*/{0}, to_apply);
EXPECT_FALSE(inferred_status.ok());
- EXPECT_MATCH(inferred_status.status().error_message(),
- testing::ContainsRegex("take 2 parameters"));
+ EXPECT_THAT(inferred_status.status().error_message(),
+ ContainsRegex("take 2 parameters"));
}
TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
@@ -676,8 +674,8 @@ TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_,
/*dimensions_to_reduce=*/{0}, to_apply);
EXPECT_FALSE(inferred_status.ok());
- EXPECT_MATCH(inferred_status.status().error_message(),
- testing::ContainsRegex("first parameter shape differs"));
+ EXPECT_THAT(inferred_status.status().error_message(),
+ ContainsRegex("first parameter shape differs"));
}
TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
@@ -819,8 +817,8 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) {
auto inferred_status =
ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, vector_32_, {});
ASSERT_FALSE(inferred_status.ok());
- ASSERT_MATCH(inferred_status.status().error_message(),
- testing::ContainsRegex("dot only supports rank"));
+ ASSERT_THAT(inferred_status.status().error_message(),
+ ContainsRegex("dot only supports rank"));
}
// 3D <dot> 2D: error
@@ -828,8 +826,8 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
auto inferred_status = ShapeInference::InferBinaryOpShape(
BINOP_DOT, ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, {});
ASSERT_FALSE(inferred_status.ok());
- ASSERT_MATCH(inferred_status.status().error_message(),
- testing::ContainsRegex("dot only supports rank"));
+ ASSERT_THAT(inferred_status.status().error_message(),
+ ContainsRegex("dot only supports rank"));
}
// vector <dot> vector -> scalar
@@ -941,46 +939,43 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
auto inferred_status_error1 = ShapeInference::InferBinaryOpShape(
BinaryOperation::BINOP_ADD, tensor, vec8, {});
ASSERT_FALSE(inferred_status_error1.ok());
- ASSERT_MATCH(inferred_status_error1.status().error_message(),
- testing::ContainsRegex("automatic"));
+ ASSERT_THAT(inferred_status_error1.status().error_message(),
+ ContainsRegex("automatic"));
// broadcast_dimension out of bounds for tensor's rank
auto inferred_status_error2 = ShapeInference::InferBinaryOpShape(
BinaryOperation::BINOP_ADD, tensor, vec8, {3});
ASSERT_FALSE(inferred_status_error2.ok());
- ASSERT_MATCH(
- inferred_status_error2.status().error_message(),
- testing::ContainsRegex("broadcast dimension number .* too large"));
+ ASSERT_THAT(inferred_status_error2.status().error_message(),
+ ContainsRegex("broadcast dimension number .* too large"));
// broadcast_dimension doesn't match corresponding dimension
auto inferred_status_error3 = ShapeInference::InferBinaryOpShape(
BinaryOperation::BINOP_ADD, tensor, vec8, {0});
ASSERT_FALSE(inferred_status_error3.ok());
- ASSERT_MATCH(inferred_status_error3.status().error_message(),
- testing::ContainsRegex("broadcast dimension 0 mismatch"));
+ ASSERT_THAT(inferred_status_error3.status().error_message(),
+ ContainsRegex("broadcast dimension 0 mismatch"));
// broadcast_dimensions list too long
auto inferred_status_error4 = ShapeInference::InferBinaryOpShape(
BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2});
ASSERT_FALSE(inferred_status_error4.ok());
- ASSERT_MATCH(
- inferred_status_error4.status().error_message(),
- testing::ContainsRegex("size of broadcast_dimensions has to match"));
+ ASSERT_THAT(inferred_status_error4.status().error_message(),
+ ContainsRegex("size of broadcast_dimensions has to match"));
// there's a dimension above the rank of the tensor
auto inferred_status_error5 = ShapeInference::InferBinaryOpShape(
BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0});
ASSERT_FALSE(inferred_status_error5.ok());
- ASSERT_MATCH(
- inferred_status_error5.status().error_message(),
- testing::ContainsRegex("broadcast dimension number .* too large"));
+ ASSERT_THAT(inferred_status_error5.status().error_message(),
+ ContainsRegex("broadcast dimension number .* too large"));
// broadcasting dimensions don't match in this order
auto inferred_status_error6 = ShapeInference::InferBinaryOpShape(
BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1});
ASSERT_FALSE(inferred_status_error6.ok());
- ASSERT_MATCH(inferred_status_error6.status().error_message(),
- testing::ContainsRegex("broadcast dimension 0 mismatch"));
+ ASSERT_THAT(inferred_status_error6.status().error_message(),
+ ContainsRegex("broadcast dimension 0 mismatch"));
// The following two tests make sure that broadcasting dimensions are listed
// in a proper (strictly increasing) order, even if the lower-rank array
@@ -988,14 +983,14 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
auto inferred_status_error7 = ShapeInference::InferBinaryOpShape(
BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0});
ASSERT_FALSE(inferred_status_error7.ok());
- ASSERT_MATCH(inferred_status_error7.status().error_message(),
- testing::ContainsRegex("broadcast dimensions order is wrong"));
+ ASSERT_THAT(inferred_status_error7.status().error_message(),
+ ContainsRegex("broadcast dimensions order is wrong"));
auto inferred_status_error8 = ShapeInference::InferBinaryOpShape(
BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0});
ASSERT_FALSE(inferred_status_error8.ok());
- ASSERT_MATCH(inferred_status_error8.status().error_message(),
- testing::ContainsRegex("broadcast dimensions order is wrong"));
+ ASSERT_THAT(inferred_status_error8.status().error_message(),
+ ContainsRegex("broadcast dimensions order is wrong"));
}
// Tests for the while instruction with proper shapes.
@@ -1020,30 +1015,30 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) {
auto inferred_status_error1 =
ShapeInference::InferWhileShape(bad_shape_1, body, result_shape);
ASSERT_FALSE(inferred_status_error1.ok());
- ASSERT_MATCH(inferred_status_error1.status().error_message(),
- testing::ContainsRegex("condition must take 1 arguments"));
+ ASSERT_THAT(inferred_status_error1.status().error_message(),
+ ContainsRegex("condition must take 1 arguments"));
auto bad_shape_2 =
ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape);
auto inferred_status_error2 =
ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape);
ASSERT_FALSE(inferred_status_error2.ok());
- ASSERT_MATCH(inferred_status_error2.status().error_message(),
- testing::ContainsRegex("body must take 1 arguments"));
+ ASSERT_THAT(inferred_status_error2.status().error_message(),
+ ContainsRegex("body must take 1 arguments"));
auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_);
auto inferred_status_error3 =
ShapeInference::InferWhileShape(bad_shape_3, body, result_shape);
ASSERT_FALSE(inferred_status_error3.ok());
- ASSERT_MATCH(inferred_status_error3.status().error_message(),
- testing::ContainsRegex("condition must return a boolean"));
+ ASSERT_THAT(inferred_status_error3.status().error_message(),
+ ContainsRegex("condition must return a boolean"));
auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_);
auto inferred_status_error4 =
ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape);
ASSERT_FALSE(inferred_status_error4.ok());
- ASSERT_MATCH(inferred_status_error4.status().error_message(),
- testing::ContainsRegex("parameter of condition and body"));
+ ASSERT_THAT(inferred_status_error4.status().error_message(),
+ ContainsRegex("parameter of condition and body"));
}
// Tests for the concatenate instruction with proper shapes.
@@ -1073,49 +1068,45 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
auto inferred_status_error1 =
ShapeInference::InferConcatOpShape({}, /*dimension=*/0);
ASSERT_FALSE(inferred_status_error1.ok());
- ASSERT_MATCH(
- inferred_status_error1.status().error_message(),
- testing::ContainsRegex("Concatenate expects at least one argument"));
+ ASSERT_THAT(inferred_status_error1.status().error_message(),
+ ContainsRegex("Concatenate expects at least one argument"));
auto inferred_status_error2 =
ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1);
ASSERT_FALSE(inferred_status_error2.ok());
- ASSERT_MATCH(inferred_status_error2.status().error_message(),
- testing::ContainsRegex(
- "dimension to concatenate along out of bounds: -1"));
+ ASSERT_THAT(
+ inferred_status_error2.status().error_message(),
+ ContainsRegex("dimension to concatenate along out of bounds: -1"));
auto inferred_status_error3 =
ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1);
ASSERT_FALSE(inferred_status_error3.ok());
- ASSERT_MATCH(inferred_status_error3.status().error_message(),
- testing::ContainsRegex(
- "dimension to concatenate along out of bounds: 1"));
+ ASSERT_THAT(inferred_status_error3.status().error_message(),
+ ContainsRegex("dimension to concatenate along out of bounds: 1"));
Shape tuple = ShapeUtil::MakeTupleShape({vector_32_});
auto inferred_status_error4 = ShapeInference::InferConcatOpShape(
{&vector_32_, &tuple}, /*dimension=*/0);
ASSERT_FALSE(inferred_status_error4.ok());
- ASSERT_MATCH(
- inferred_status_error4.status().error_message(),
- testing::ContainsRegex(
- "Expected non-tuple argument for operand of concatenation."));
+ ASSERT_THAT(inferred_status_error4.status().error_message(),
+ ContainsRegex(
+ "Expected non-tuple argument for operand of concatenation."));
const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
auto inferred_status_error5 = ShapeInference::InferConcatOpShape(
{&vector_32_, &vector_s32}, /*dimension=*/0);
ASSERT_FALSE(inferred_status_error5.ok());
- ASSERT_MATCH(inferred_status_error5.status().error_message(),
- testing::ContainsRegex(
- "cannot concatenate arrays with different element types"));
+ ASSERT_THAT(
+ inferred_status_error5.status().error_message(),
+ ContainsRegex("cannot concatenate arrays with different element types"));
auto inferred_status_error6 = ShapeInference::InferConcatOpShape(
{&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0);
ASSERT_FALSE(inferred_status_error6.ok());
- ASSERT_MATCH(
- inferred_status_error6.status().error_message(),
- testing::ContainsRegex("cannot concatenate arrays that differ in "
- "dimensions other than the one being "
- "concatenated"));
+ ASSERT_THAT(inferred_status_error6.status().error_message(),
+ ContainsRegex("cannot concatenate arrays that differ in "
+ "dimensions other than the one being "
+ "concatenated"));
}
TEST_F(ShapeInferenceTest, Pad) {
@@ -1156,27 +1147,27 @@ TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
auto inferred_status_error0 =
ShapeInference::InferReverseShape(input_shape, {0, 2});
ASSERT_FALSE(inferred_status_error0.ok());
- ASSERT_MATCH(inferred_status_error0.status().error_message(),
- testing::ContainsRegex("out-of-bounds"));
+ ASSERT_THAT(inferred_status_error0.status().error_message(),
+ ContainsRegex("out-of-bounds"));
auto inferred_status_error1 =
ShapeInference::InferReverseShape(input_shape, {0, -1});
ASSERT_FALSE(inferred_status_error1.ok());
- ASSERT_MATCH(inferred_status_error1.status().error_message(),
- testing::ContainsRegex("out-of-bounds"));
+ ASSERT_THAT(inferred_status_error1.status().error_message(),
+ ContainsRegex("out-of-bounds"));
auto inferred_status_error2 =
ShapeInference::InferReverseShape(input_shape, {0, 0});
ASSERT_FALSE(inferred_status_error2.ok());
- ASSERT_MATCH(inferred_status_error2.status().error_message(),
- testing::ContainsRegex("duplicated"));
+ ASSERT_THAT(inferred_status_error2.status().error_message(),
+ ContainsRegex("duplicated"));
Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape});
auto inferred_status_error3 =
ShapeInference::InferReverseShape(tuple_shape, {0});
ASSERT_FALSE(inferred_status_error3.ok());
- ASSERT_MATCH(inferred_status_error3.status().error_message(),
- testing::ContainsRegex("Expected non-tuple argument"));
+ ASSERT_THAT(inferred_status_error3.status().error_message(),
+ ContainsRegex("Expected non-tuple argument"));
}
TEST_F(ShapeInferenceTest, Call) {
@@ -1196,20 +1187,20 @@ TEST_F(ShapeInferenceTest, Call) {
auto inferred_status_error0 = ShapeInference::InferCallShape(
{}, ShapeUtil::MakeProgramShape({f32_}, f32_));
EXPECT_FALSE(inferred_status_error0.ok());
- EXPECT_MATCH(inferred_status_error0.status().error_message(),
- testing::ContainsRegex("arity must match"));
+ EXPECT_THAT(inferred_status_error0.status().error_message(),
+ ContainsRegex("arity must match"));
auto inferred_status_error1 = ShapeInference::InferCallShape(
{&f32_}, ShapeUtil::MakeProgramShape({}, f32_));
EXPECT_FALSE(inferred_status_error1.ok());
- EXPECT_MATCH(inferred_status_error1.status().error_message(),
- testing::ContainsRegex("arity must match"));
+ EXPECT_THAT(inferred_status_error1.status().error_message(),
+ ContainsRegex("arity must match"));
auto inferred_status_error2 = ShapeInference::InferCallShape(
{&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_));
EXPECT_FALSE(inferred_status_error2.ok());
- EXPECT_MATCH(inferred_status_error2.status().error_message(),
- testing::ContainsRegex("parameter must match argument"));
+ EXPECT_THAT(inferred_status_error2.status().error_message(),
+ ContainsRegex("parameter must match argument"));
}
TEST_F(ShapeInferenceTest, Transpose) {