about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
author: A. Unique TensorFlower &lt;gardener@tensorflow.org&gt; 2018-06-08 17:39:58 -0700
committer: TensorFlower Gardener &lt;gardener@tensorflow.org&gt; 2018-06-08 17:42:02 -0700
commit 80459fe0fdcb86b286311559c65a7ec43525e278 (patch)
tree ccdb1d36d150c4b8a261c9c48d2a0b2ef7fac637 /tensorflow/compiler/xla/service/shape_inference_test.cc
parent f81f62a0d35ccf7c4e83e09510447d93933ef87e (diff)
Cleanup shape_inference.
PiperOrigin-RevId: 199876297
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r-- tensorflow/compiler/xla/service/shape_inference_test.cc | 144
1 file changed, 70 insertions(+), 74 deletions(-)
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 0e61994a78..6d017dffe2 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -101,8 +101,8 @@ class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest {
TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
- auto inferred_status = ShapeInference::InferUnaryOpShape(
- UnaryOperation::UNOP_NEGATE, matrix_shape);
+ auto inferred_status =
+ ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie()));
}
@@ -110,14 +110,14 @@ TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) {
Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_});
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, pred_, tuple, tuple);
+ HloOpcode::kSelect, pred_, tuple, tuple);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(tuple, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_64_48_);
+ HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
@@ -125,34 +125,34 @@ TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) {
auto predarray = ShapeUtil::MakeShape(PRED, {64, 48});
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, predarray, matrix_64_48_, matrix_64_48_);
+ HloOpcode::kSelect, predarray, matrix_64_48_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, SelectBadShapes) {
auto inferred_status_error1 = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_);
+ HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_);
ASSERT_FALSE(inferred_status_error1.ok());
ASSERT_THAT(inferred_status_error1.status().error_message(),
HasSubstr("Operands to select must be the same shape"));
auto inferred_status_error2 = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_);
+ HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_);
ASSERT_FALSE(inferred_status_error2.ok());
ASSERT_THAT(inferred_status_error2.status().error_message(),
HasSubstr("pred operand must have PRED"));
auto inferred_status_error3 = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}),
- matrix_64_48_, matrix_64_48_);
+ HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_,
+ matrix_64_48_);
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
HasSubstr("with non-scalar predicate with dimensionality"));
// Tuples have a TUPLE element type and cannot be the pred of a select.
auto inferred_status_error4 = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeTupleShape({pred_, pred_}),
+ HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}),
ShapeUtil::MakeTupleShape({f32_, f32_}),
ShapeUtil::MakeTupleShape({f32_, f32_}));
ASSERT_FALSE(inferred_status_error4.ok());
@@ -162,102 +162,98 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) {
TEST_F(ShapeInferenceTest, ClampAllMatrix) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_,
- matrix_64_48_);
+ HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampAllScalar) {
- auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_);
+ auto inferred_status =
+ ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampMinScalar) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_);
+ HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampMaxScalar) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_);
+ HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampOperandScalar) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_);
+ HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampMinMatrix) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_);
+ HloOpcode::kClamp, matrix_64_48_, f32_, f32_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_);
+ HloOpcode::kClamp, f32_, f32_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_);
+ HloOpcode::kClamp, f32_, matrix_64_48_, f32_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampBadShapes) {
// Type mismatch
- ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_)
- .ok());
- ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_)
- .ok());
- ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_)
- .ok());
- // Dimension mismatch
ASSERT_FALSE(
- ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
- vector_64_, vector_32_, vector_32_)
+ ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, s32_, f32_, f32_)
.ok());
ASSERT_FALSE(
- ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
- vector_32_, vector_64_, vector_32_)
+ ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, s32_, f32_)
.ok());
ASSERT_FALSE(
- ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
- vector_32_, vector_32_, vector_64_)
+ ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, s32_)
.ok());
- // Dimension mismatch, where one operand is a scalar
+ // Dimension mismatch
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_)
+ HloOpcode::kClamp, vector_64_, vector_32_, vector_32_)
.ok());
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_)
+ HloOpcode::kClamp, vector_32_, vector_64_, vector_32_)
.ok());
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_)
+ HloOpcode::kClamp, vector_32_, vector_32_, vector_64_)
+ .ok());
+ // Dimension mismatch, where one operand is a scalar
+ ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp,
+ vector_64_, vector_32_, f32_)
+ .ok());
+ ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp,
+ vector_64_, f32_, vector_32_)
+ .ok());
+ ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_,
+ vector_64_, vector_32_)
.ok());
}
TEST_F(ShapeInferenceTest, Complex) {
auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
const tensorflow::gtl::ArraySlice<int64>& bcast) {
- return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX,
- lhs, rhs, bcast);
+ return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs,
+ bcast);
};
// Inputs must be FP.
ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
@@ -292,8 +288,8 @@ TEST_F(ShapeInferenceTest, Complex) {
}
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
- StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
- VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});
+ StatusOr<Shape> result =
+ ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_});
ASSERT_IS_OK(result.status());
ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
ShapeUtil::MakeTupleShape({s32_, f32_})));
@@ -804,8 +800,8 @@ TEST_F(ShapeInferenceTest, InferConstIndexShape) {
TEST_F(ShapeInferenceTest, InferPowShape) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
- auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_POW, ten_floats, f32_, {});
+ auto inferred_status = ShapeInference::InferBinaryOpShape(
+ HloOpcode::kPower, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie()));
}
@@ -813,7 +809,7 @@ TEST_F(ShapeInferenceTest, InferPowShape) {
TEST_F(ShapeInferenceTest, InferCompareShapeEq) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_EQ, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kEq, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -822,7 +818,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeEq) {
TEST_F(ShapeInferenceTest, InferCompareShapeGe) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_GE, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kGe, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -831,7 +827,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGe) {
TEST_F(ShapeInferenceTest, InferCompareShapeGt) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_GT, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kGt, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -840,7 +836,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGt) {
TEST_F(ShapeInferenceTest, InferCompareShapeLe) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_LE, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kLe, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -849,7 +845,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLe) {
TEST_F(ShapeInferenceTest, InferCompareShapeLt) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_LT, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kLt, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -858,7 +854,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLt) {
TEST_F(ShapeInferenceTest, InferCompareShapeNe) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_NE, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kNe, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -1111,22 +1107,22 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
const Shape vec16 = ShapeUtil::MakeShape(F32, {16});
- auto inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, mat, vec8, {1});
+ auto inferred_status_match =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1});
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat));
- auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, mat, vec8, {0});
+ auto inferred_status_mismatch =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0});
ASSERT_FALSE(inferred_status_mismatch.ok());
- inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, mat, vec16, {0});
+ inferred_status_match =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0});
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat));
- inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, mat, vec16, {1});
+ inferred_status_mismatch =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1});
ASSERT_FALSE(inferred_status_mismatch.ok());
}
@@ -1138,17 +1134,17 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) {
const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8});
auto inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, cube, matrix8_4, {1, 2});
+ HloOpcode::kAdd, cube, matrix8_4, {1, 2});
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, cube, matrix16_4, {0, 2});
+ HloOpcode::kAdd, cube, matrix16_4, {0, 2});
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, cube, matrix16_8, {0, 1});
+ HloOpcode::kAdd, cube, matrix16_8, {0, 1});
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
}
@@ -1162,43 +1158,43 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8});
// "magical" broadcast rejected
- auto inferred_status_error1 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, vec8, {});
+ auto inferred_status_error1 =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {});
ASSERT_FALSE(inferred_status_error1.ok());
ASSERT_THAT(inferred_status_error1.status().error_message(),
HasSubstr("Automatic"));
// broadcast_dimension out of bounds for tensor's rank
- auto inferred_status_error2 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, vec8, {3});
+ auto inferred_status_error2 =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3});
ASSERT_FALSE(inferred_status_error2.ok());
ASSERT_THAT(inferred_status_error2.status().error_message(),
ContainsRegex("Broadcast dimension number .* too large"));
// broadcast_dimension doesn't match corresponding dimension
- auto inferred_status_error3 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, vec8, {0});
+ auto inferred_status_error3 =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0});
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
HasSubstr("Broadcast dimension 0 mismatch"));
// broadcast_dimensions list too long
auto inferred_status_error4 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2});
+ HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2});
ASSERT_FALSE(inferred_status_error4.ok());
ASSERT_THAT(inferred_status_error4.status().error_message(),
HasSubstr("broadcast_dimensions has to match"));
// there's a dimension above the rank of the tensor
auto inferred_status_error5 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0});
+ HloOpcode::kAdd, tensor, matrix8_4, {3, 0});
ASSERT_FALSE(inferred_status_error5.ok());
ASSERT_THAT(inferred_status_error5.status().error_message(),
ContainsRegex("dimension number .* too large"));
// broadcasting dimensions don't match in this order
auto inferred_status_error6 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1});
+ HloOpcode::kAdd, tensor, matrix8_4, {2, 1});
ASSERT_FALSE(inferred_status_error6.ok());
ASSERT_THAT(inferred_status_error6.status().error_message(),
HasSubstr("dimension 0 mismatch"));
@@ -1207,13 +1203,13 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
// in a proper (strictly increasing) order, even if the lower-rank array
// matches the higher-rank array in many different ways.
auto inferred_status_error7 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0});
+ HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0});
ASSERT_FALSE(inferred_status_error7.ok());
ASSERT_THAT(inferred_status_error7.status().error_message(),
HasSubstr("dimensions order is wrong"));
auto inferred_status_error8 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0});
+ HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0});
ASSERT_FALSE(inferred_status_error8.ok());
ASSERT_THAT(inferred_status_error8.status().error_message(),
HasSubstr("dimensions order is wrong"));