author     A. Unique TensorFlower <gardener@tensorflow.org>   2017-04-12 09:29:33 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2017-04-12 10:50:37 -0700
commit     9992074410d0b8d7102b7a63ff5f01a1a4554357 (patch)
tree       9d4be04123492819014119cc98f8beafecfefd63 /tensorflow/compiler/xla/service/shape_inference_test.cc
parent     26332f733e3432b815e2d3dbd5a2f123a38af4e7 (diff)
Simplified XLA tests by using gmock matchers.
Change: 152954991
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference_test.cc | 108
1 file changed, 54 insertions(+), 54 deletions(-)
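
For context (illustrative only, not part of the commit): the diff below systematically swaps ::testing::ContainsRegex for ::testing::HasSubstr when asserting on inference error messages. A minimal standalone gmock sketch of the two matcher styles, assuming only gtest/gmock (the test name and message string here are made up for illustration):

#include <string>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

using ::testing::ContainsRegex;
using ::testing::HasSubstr;

TEST(MatcherStyleSketch, LiteralSubstringVsRegex) {
  const std::string error_message =
      "operands to select must be the same shape";
  // Old style: ContainsRegex treats the expectation as a regular expression,
  // which is more machinery than a literal error-message check needs and can
  // misfire on regex metacharacters in the expected text.
  EXPECT_THAT(error_message, ContainsRegex("operands to select"));
  // New style: HasSubstr states the intent directly as a plain substring match.
  EXPECT_THAT(error_message, HasSubstr("operands to select"));
}
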
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 255d64b320..7cff042a48 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -28,6 +28,7 @@ namespace xla {
namespace {
using ::testing::ContainsRegex;
+using ::testing::HasSubstr;
class ShapeInferenceTest : public ::testing::Test {
protected:
@@ -132,20 +133,20 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) {
TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_);
ASSERT_FALSE(inferred_status_error1.ok());
ASSERT_THAT(inferred_status_error1.status().error_message(),
- ContainsRegex("operands to select must be the same shape"));
+ HasSubstr("operands to select must be the same shape"));
auto inferred_status_error2 = ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_);
ASSERT_FALSE(inferred_status_error2.ok());
ASSERT_THAT(inferred_status_error2.status().error_message(),
- ContainsRegex("pred operand must have PRED"));
+ HasSubstr("pred operand must have PRED"));
auto inferred_status_error3 = ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}),
matrix_64_48_, matrix_64_48_);
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
- ContainsRegex("with non-scalar predicate with dimensionality"));
+ HasSubstr("with non-scalar predicate with dimensionality"));
// Tuples have a TUPLE element type and cannot be the pred of a select.
auto inferred_status_error4 = ShapeInference::InferTernaryOpShape(
@@ -154,7 +155,7 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) {
ShapeUtil::MakeTupleShape({f32_, f32_}));
ASSERT_FALSE(inferred_status_error4.ok());
ASSERT_THAT(inferred_status_error4.status().error_message(),
- ContainsRegex("pred operand must have PRED element type"));
+ HasSubstr("pred operand must have PRED element type"));
}
TEST_F(ShapeInferenceTest, ClampAllMatrix) {
@@ -299,7 +300,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) {
init_value_shape_, scatter_program_shape_);
ASSERT_FALSE(inferred_status_fail.ok());
ASSERT_THAT(inferred_status_fail.status().error_message(),
- ContainsRegex("source shape does not match"));
+ HasSubstr("source shape does not match"));
}
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
@@ -310,7 +311,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
init_value_shape_, scatter_program_shape_);
ASSERT_FALSE(inferred_status_fail.ok());
ASSERT_THAT(inferred_status_fail.status().error_message(),
- ContainsRegex("select function must take 2 parameters"));
+ HasSubstr("select function must take 2 parameters"));
}
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
@@ -321,7 +322,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
init_value_shape_, scatter_program_shape_);
ASSERT_FALSE(inferred_status_fail.ok());
ASSERT_THAT(inferred_status_fail.status().error_message(),
- ContainsRegex("select function must have rank-0 PRED"));
+ HasSubstr("select function must have rank-0 PRED"));
}
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) {
@@ -332,7 +333,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) {
init_value_shape_, scatter_program_shape_);
ASSERT_FALSE(inferred_status_fail.ok());
ASSERT_THAT(inferred_status_fail.status().error_message(),
- ContainsRegex("select function's first parameter"));
+ HasSubstr("select function's first parameter"));
}
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) {
@@ -343,7 +344,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) {
init_value_shape_, scatter_program_shape_);
ASSERT_FALSE(inferred_status_fail.ok());
ASSERT_THAT(inferred_status_fail.status().error_message(),
- ContainsRegex("select function's second parameter"));
+ HasSubstr("select function's second parameter"));
}
TEST_F(ShapeInferenceTest, Convolve) {
@@ -498,7 +499,7 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
- ContainsRegex("each dimension exactly once"));
+ HasSubstr("each dimension exactly once"));
}
TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
@@ -536,41 +537,41 @@ TEST_F(ShapeInferenceTest, Map) {
{}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
ASSERT_FALSE(no_args_error.ok());
ASSERT_THAT(no_args_error.status().error_message(),
- ContainsRegex("expects at least one argument"));
+ HasSubstr("expects at least one argument"));
auto args_diff_shapes_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_64_},
ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
ASSERT_FALSE(args_diff_shapes_error.ok());
ASSERT_THAT(args_diff_shapes_error.status().error_message(),
- ContainsRegex("requires all operands to have the same shape"));
+ HasSubstr("requires all operands to have the same shape"));
auto arity_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_));
ASSERT_FALSE(arity_error.ok());
ASSERT_THAT(arity_error.status().error_message(),
- ContainsRegex("function arity must match"));
+ HasSubstr("function arity must match"));
auto output_shape_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_32_},
ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_));
ASSERT_FALSE(output_shape_error.ok());
ASSERT_THAT(output_shape_error.status().error_message(),
- ContainsRegex("result has to be a scalar"));
+ HasSubstr("result has to be a scalar"));
auto param_shape_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_32_},
ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_));
ASSERT_FALSE(param_shape_error.ok());
ASSERT_THAT(param_shape_error.status().error_message(),
- ContainsRegex("parameter has to be a scalar"));
+ HasSubstr("parameter has to be a scalar"));
auto param_element_type_error = ShapeInference::InferMapShape(
{&vector_32_, &vector_32_},
ShapeUtil::MakeProgramShape({f32_, s32_}, f32_));
ASSERT_FALSE(param_element_type_error.ok());
ASSERT_THAT(param_element_type_error.status().error_message(),
- ContainsRegex("parameter type has to match argument"));
+ HasSubstr("parameter type has to match argument"));
Shape arg = ShapeUtil::MakeShape(F32, {20});
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_);
@@ -582,25 +583,25 @@ TEST_F(ShapeInferenceTest, Map) {
{&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
ASSERT_FALSE(inferred_status_error1.ok());
ASSERT_THAT(inferred_status_error1.status().error_message(),
- ContainsRegex("arity must match number of arguments"));
+ HasSubstr("arity must match number of arguments"));
auto inferred_status_error2 = ShapeInference::InferMapShape(
{&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_));
ASSERT_FALSE(inferred_status_error2.ok());
ASSERT_THAT(inferred_status_error2.status().error_message(),
- ContainsRegex("has to be a scalar"));
+ HasSubstr("has to be a scalar"));
auto inferred_status_error3 = ShapeInference::InferMapShape(
{&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_));
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
- ContainsRegex("has to be a scalar"));
+ HasSubstr("has to be a scalar"));
auto inferred_status_error5 = ShapeInference::InferMapShape(
{&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_));
ASSERT_FALSE(inferred_status_error5.ok());
ASSERT_THAT(inferred_status_error5.status().error_message(),
- ContainsRegex("parameter type has to match argument"));
+ HasSubstr("parameter type has to match argument"));
}
TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) {
@@ -655,7 +656,7 @@ TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
to_apply);
EXPECT_FALSE(inferred_status.ok());
EXPECT_THAT(inferred_status.status().error_message(),
- ContainsRegex("out-of-bounds dimension"));
+ HasSubstr("out-of-bounds dimension"));
}
TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
@@ -665,7 +666,7 @@ TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
/*dimensions_to_reduce=*/{0}, to_apply);
EXPECT_FALSE(inferred_status.ok());
EXPECT_THAT(inferred_status.status().error_message(),
- ContainsRegex("take 2 parameters"));
+ HasSubstr("take 2 parameters"));
}
TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
@@ -675,7 +676,7 @@ TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
/*dimensions_to_reduce=*/{0}, to_apply);
EXPECT_FALSE(inferred_status.ok());
EXPECT_THAT(inferred_status.status().error_message(),
- ContainsRegex("first parameter shape differs"));
+ HasSubstr("first parameter shape differs"));
}
TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
@@ -818,7 +819,7 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) {
ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, vector_32_, {});
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
- ContainsRegex("dot only supports rank"));
+ HasSubstr("dot only supports rank"));
}
// 3D <dot> 2D: error
@@ -827,7 +828,7 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
BINOP_DOT, ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, {});
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
- ContainsRegex("dot only supports rank"));
+ HasSubstr("dot only supports rank"));
}
// vector <dot> vector -> scalar
@@ -940,7 +941,7 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
BinaryOperation::BINOP_ADD, tensor, vec8, {});
ASSERT_FALSE(inferred_status_error1.ok());
ASSERT_THAT(inferred_status_error1.status().error_message(),
- ContainsRegex("automatic"));
+ HasSubstr("automatic"));
// broadcast_dimension out of bounds for tensor's rank
auto inferred_status_error2 = ShapeInference::InferBinaryOpShape(
@@ -954,14 +955,14 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
BinaryOperation::BINOP_ADD, tensor, vec8, {0});
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
- ContainsRegex("broadcast dimension 0 mismatch"));
+ HasSubstr("broadcast dimension 0 mismatch"));
// broadcast_dimensions list too long
auto inferred_status_error4 = ShapeInference::InferBinaryOpShape(
BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2});
ASSERT_FALSE(inferred_status_error4.ok());
ASSERT_THAT(inferred_status_error4.status().error_message(),
- ContainsRegex("size of broadcast_dimensions has to match"));
+ HasSubstr("size of broadcast_dimensions has to match"));
// there's a dimension above the rank of the tensor
auto inferred_status_error5 = ShapeInference::InferBinaryOpShape(
@@ -975,7 +976,7 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1});
ASSERT_FALSE(inferred_status_error6.ok());
ASSERT_THAT(inferred_status_error6.status().error_message(),
- ContainsRegex("broadcast dimension 0 mismatch"));
+ HasSubstr("broadcast dimension 0 mismatch"));
// The following two tests make sure that broadcasting dimensions are listed
// in a proper (strictly increasing) order, even if the lower-rank array
@@ -984,13 +985,13 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0});
ASSERT_FALSE(inferred_status_error7.ok());
ASSERT_THAT(inferred_status_error7.status().error_message(),
- ContainsRegex("broadcast dimensions order is wrong"));
+ HasSubstr("broadcast dimensions order is wrong"));
auto inferred_status_error8 = ShapeInference::InferBinaryOpShape(
BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0});
ASSERT_FALSE(inferred_status_error8.ok());
ASSERT_THAT(inferred_status_error8.status().error_message(),
- ContainsRegex("broadcast dimensions order is wrong"));
+ HasSubstr("broadcast dimensions order is wrong"));
}
// Tests for the while instruction with proper shapes.
@@ -1016,7 +1017,7 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) {
ShapeInference::InferWhileShape(bad_shape_1, body, result_shape);
ASSERT_FALSE(inferred_status_error1.ok());
ASSERT_THAT(inferred_status_error1.status().error_message(),
- ContainsRegex("condition must take 1 arguments"));
+ HasSubstr("condition must take 1 arguments"));
auto bad_shape_2 =
ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape);
@@ -1024,21 +1025,21 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) {
ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape);
ASSERT_FALSE(inferred_status_error2.ok());
ASSERT_THAT(inferred_status_error2.status().error_message(),
- ContainsRegex("body must take 1 arguments"));
+ HasSubstr("body must take 1 arguments"));
auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_);
auto inferred_status_error3 =
ShapeInference::InferWhileShape(bad_shape_3, body, result_shape);
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
- ContainsRegex("condition must return a boolean"));
+ HasSubstr("condition must return a boolean"));
auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_);
auto inferred_status_error4 =
ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape);
ASSERT_FALSE(inferred_status_error4.ok());
ASSERT_THAT(inferred_status_error4.status().error_message(),
- ContainsRegex("parameter of condition and body"));
+ HasSubstr("parameter of condition and body"));
}
// Tests for the concatenate instruction with proper shapes.
@@ -1069,28 +1070,27 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
ShapeInference::InferConcatOpShape({}, /*dimension=*/0);
ASSERT_FALSE(inferred_status_error1.ok());
ASSERT_THAT(inferred_status_error1.status().error_message(),
- ContainsRegex("Concatenate expects at least one argument"));
+ HasSubstr("Concatenate expects at least one argument"));
auto inferred_status_error2 =
ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1);
ASSERT_FALSE(inferred_status_error2.ok());
- ASSERT_THAT(
- inferred_status_error2.status().error_message(),
- ContainsRegex("dimension to concatenate along out of bounds: -1"));
+ ASSERT_THAT(inferred_status_error2.status().error_message(),
+ HasSubstr("dimension to concatenate along out of bounds: -1"));
auto inferred_status_error3 =
ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1);
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
- ContainsRegex("dimension to concatenate along out of bounds: 1"));
+ HasSubstr("dimension to concatenate along out of bounds: 1"));
Shape tuple = ShapeUtil::MakeTupleShape({vector_32_});
auto inferred_status_error4 = ShapeInference::InferConcatOpShape(
{&vector_32_, &tuple}, /*dimension=*/0);
ASSERT_FALSE(inferred_status_error4.ok());
- ASSERT_THAT(inferred_status_error4.status().error_message(),
- ContainsRegex(
- "Expected non-tuple argument for operand of concatenation."));
+ ASSERT_THAT(
+ inferred_status_error4.status().error_message(),
+ HasSubstr("Expected non-tuple argument for operand of concatenation."));
const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
auto inferred_status_error5 = ShapeInference::InferConcatOpShape(
@@ -1098,15 +1098,15 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
ASSERT_FALSE(inferred_status_error5.ok());
ASSERT_THAT(
inferred_status_error5.status().error_message(),
- ContainsRegex("cannot concatenate arrays with different element types"));
+ HasSubstr("cannot concatenate arrays with different element types"));
auto inferred_status_error6 = ShapeInference::InferConcatOpShape(
{&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0);
ASSERT_FALSE(inferred_status_error6.ok());
ASSERT_THAT(inferred_status_error6.status().error_message(),
- ContainsRegex("cannot concatenate arrays that differ in "
- "dimensions other than the one being "
- "concatenated"));
+ HasSubstr("cannot concatenate arrays that differ in "
+ "dimensions other than the one being "
+ "concatenated"));
}
TEST_F(ShapeInferenceTest, Pad) {
@@ -1148,26 +1148,26 @@ TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
ShapeInference::InferReverseShape(input_shape, {0, 2});
ASSERT_FALSE(inferred_status_error0.ok());
ASSERT_THAT(inferred_status_error0.status().error_message(),
- ContainsRegex("out-of-bounds"));
+ HasSubstr("out-of-bounds"));
auto inferred_status_error1 =
ShapeInference::InferReverseShape(input_shape, {0, -1});
ASSERT_FALSE(inferred_status_error1.ok());
ASSERT_THAT(inferred_status_error1.status().error_message(),
- ContainsRegex("out-of-bounds"));
+ HasSubstr("out-of-bounds"));
auto inferred_status_error2 =
ShapeInference::InferReverseShape(input_shape, {0, 0});
ASSERT_FALSE(inferred_status_error2.ok());
ASSERT_THAT(inferred_status_error2.status().error_message(),
- ContainsRegex("duplicated"));
+ HasSubstr("duplicated"));
Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape});
auto inferred_status_error3 =
ShapeInference::InferReverseShape(tuple_shape, {0});
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
- ContainsRegex("Expected non-tuple argument"));
+ HasSubstr("Expected non-tuple argument"));
}
TEST_F(ShapeInferenceTest, Call) {
@@ -1188,19 +1188,19 @@ TEST_F(ShapeInferenceTest, Call) {
{}, ShapeUtil::MakeProgramShape({f32_}, f32_));
EXPECT_FALSE(inferred_status_error0.ok());
EXPECT_THAT(inferred_status_error0.status().error_message(),
- ContainsRegex("arity must match"));
+ HasSubstr("arity must match"));
auto inferred_status_error1 = ShapeInference::InferCallShape(
{&f32_}, ShapeUtil::MakeProgramShape({}, f32_));
EXPECT_FALSE(inferred_status_error1.ok());
EXPECT_THAT(inferred_status_error1.status().error_message(),
- ContainsRegex("arity must match"));
+ HasSubstr("arity must match"));
auto inferred_status_error2 = ShapeInference::InferCallShape(
{&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_));
EXPECT_FALSE(inferred_status_error2.ok());
EXPECT_THAT(inferred_status_error2.status().error_message(),
- ContainsRegex("parameter must match argument"));
+ HasSubstr("parameter must match argument"));
}
TEST_F(ShapeInferenceTest, Transpose) {