aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-31 11:54:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-31 12:01:38 -0700
commit35939d2d37a03d95c86708ad0bf52865fbbd3c90 (patch)
treee5467342b116a4b58fdbbb1d8c2394ad66a19fda /tensorflow/compiler
parent3b845c80d512703a78e6ac567c70ab65801468ef (diff)
[TF:XLA] Fix string to HLO opcode conversion for atan2, complex, imag and real.
Make sure that we can't forget opcodes by auto-generating the conversion functions. Add auto-generated functions to test HLOs for properties (like IsVariadic, IsComparison, etc.) This makes changing HLO more robust and easier because there are fewer places to update when adding or removing an HLO opcode. Also: * Fix IsElementwiseBinary for atan2. * Add a unit test for HLO opcode helpers. * Express IsElementwiseBinary in terms of IsElementwise() and operand_count() to avoid having to keep the two in sync manually. PiperOrigin-RevId: 174069664
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc33
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.cc282
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h185
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode_test.cc41
4 files changed, 186 insertions, 355 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index e6a4f68fb3..ecf8cd4065 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -2514,33 +2514,7 @@ std::vector<int64> HloInstruction::OperandIndices(
}
bool HloInstruction::IsElementwiseBinary() const {
- switch (opcode_) {
- // Binary elementwise operations. If you update this, please update
- // IsElementwise() accordingly.
- case HloOpcode::kAdd:
- case HloOpcode::kComplex:
- case HloOpcode::kDivide:
- case HloOpcode::kEq:
- case HloOpcode::kGe:
- case HloOpcode::kGt:
- case HloOpcode::kLe:
- case HloOpcode::kLt:
- case HloOpcode::kMaximum:
- case HloOpcode::kMinimum:
- case HloOpcode::kMultiply:
- case HloOpcode::kNe:
- case HloOpcode::kPower:
- case HloOpcode::kRemainder:
- case HloOpcode::kSubtract:
- case HloOpcode::kAnd:
- case HloOpcode::kOr:
- case HloOpcode::kShiftLeft:
- case HloOpcode::kShiftRightArithmetic:
- case HloOpcode::kShiftRightLogical:
- return true;
- default:
- return false;
- }
+ return IsElementwise() && operand_count() == 2;
}
bool HloInstruction::IsElementwise() const {
@@ -2551,7 +2525,6 @@ bool HloInstruction::IsElementwise() const {
// Unary elementwise operations.
case HloOpcode::kAbs:
- case HloOpcode::kAtan2:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kCeil:
case HloOpcode::kConvert:
@@ -2569,11 +2542,12 @@ bool HloInstruction::IsElementwise() const {
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kTanh:
+ CHECK_EQ(1, operand_count());
return true;
// Binary elementwise operations, the same as in IsElementwiseBinary().
- // If you update this, please update IsElementwiseBinary() accordingly.
case HloOpcode::kAdd:
+ case HloOpcode::kAtan2:
case HloOpcode::kComplex:
case HloOpcode::kDivide:
case HloOpcode::kEq:
@@ -2593,6 +2567,7 @@ bool HloInstruction::IsElementwise() const {
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
+ CHECK_EQ(2, operand_count());
return true;
// Ternary elementwise operations.
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc
index 157d19f5a9..d1eaf35785 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode.cc
@@ -21,243 +21,22 @@ limitations under the License.
namespace xla {
string HloOpcodeString(HloOpcode opcode) {
- // Note: Do not use ':' in opcode strings. It is used as a special character
- // in these places:
- // - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to
- // separate the opcode from the fusion kind
- // - In fully qualified names (HloInstruction::FullyQualifiedName()), to
- // separate the qualifiers (name of the computation and potentially the
- // fusion instruction) from the name
switch (opcode) {
- case HloOpcode::kAbs:
- return "abs";
- case HloOpcode::kAdd:
- return "add";
- case HloOpcode::kAnd:
- return "and";
- case HloOpcode::kAtan2:
- return "atan2";
- case HloOpcode::kBatchNormTraining:
- return "batch-norm-training";
- case HloOpcode::kBatchNormInference:
- return "batch-norm-inference";
- case HloOpcode::kBatchNormGrad:
- return "batch-norm-grad";
- case HloOpcode::kBitcast:
- return "bitcast";
- case HloOpcode::kBroadcast:
- return "broadcast";
- case HloOpcode::kCall:
- return "call";
- case HloOpcode::kClamp:
- return "clamp";
- case HloOpcode::kComplex:
- return "complex";
- case HloOpcode::kConcatenate:
- return "concatenate";
- case HloOpcode::kConstant:
- return "constant";
- case HloOpcode::kConvert:
- return "convert";
- case HloOpcode::kConvolution:
- return "convolution";
- case HloOpcode::kCos:
- return "cosine";
- case HloOpcode::kCrossReplicaSum:
- return "cross-replica-sum";
- case HloOpcode::kCustomCall:
- return "custom-call";
- case HloOpcode::kCopy:
- return "copy";
- case HloOpcode::kDivide:
- return "divide";
- case HloOpcode::kDot:
- return "dot";
- case HloOpcode::kDynamicSlice:
- return "dynamic-slice";
- case HloOpcode::kDynamicUpdateSlice:
- return "dynamic-update-slice";
- case HloOpcode::kEq:
- return "equal-to";
- case HloOpcode::kExp:
- return "exponential";
- case HloOpcode::kFloor:
- return "floor";
- case HloOpcode::kCeil:
- return "ceil";
- case HloOpcode::kFusion:
- return "fusion";
- case HloOpcode::kGe:
- return "greater-than-or-equal-to";
- case HloOpcode::kGetTupleElement:
- return "get-tuple-element";
- case HloOpcode::kGt:
- return "greater-than";
- case HloOpcode::kImag:
- return "imag";
- case HloOpcode::kInfeed:
- return "infeed";
- case HloOpcode::kIsFinite:
- return "is-finite";
- case HloOpcode::kLe:
- return "less-than-or-equal-to";
- case HloOpcode::kLog:
- return "log";
- case HloOpcode::kLt:
- return "less-than";
- case HloOpcode::kMap:
- return "map";
- case HloOpcode::kMaximum:
- return "maximum";
- case HloOpcode::kMinimum:
- return "minimum";
- case HloOpcode::kMultiply:
- return "multiply";
- case HloOpcode::kNe:
- return "not-equal-to";
- case HloOpcode::kNegate:
- return "negate";
- case HloOpcode::kNot:
- return "not";
- case HloOpcode::kOr:
- return "or";
- case HloOpcode::kOutfeed:
- return "outfeed";
- case HloOpcode::kPad:
- return "pad";
- case HloOpcode::kParameter:
- return "parameter";
- case HloOpcode::kPower:
- return "power";
- case HloOpcode::kReal:
- return "real";
- case HloOpcode::kRecv:
- return "recv";
- case HloOpcode::kReduce:
- return "reduce";
- case HloOpcode::kReducePrecision:
- return "reduce-precision";
- case HloOpcode::kReduceWindow:
- return "reduce-window";
- case HloOpcode::kRemainder:
- return "remainder";
- case HloOpcode::kReshape:
- return "reshape";
- case HloOpcode::kReverse:
- return "reverse";
- case HloOpcode::kRng:
- return "rng";
- case HloOpcode::kRoundNearestAfz:
- return "round-nearest-afz";
- case HloOpcode::kSelectAndScatter:
- return "select-and-scatter";
- case HloOpcode::kSelect:
- return "select";
- case HloOpcode::kSend:
- return "send";
- case HloOpcode::kShiftLeft:
- return "shift-left";
- case HloOpcode::kShiftRightArithmetic:
- return "shift-right-arithmetic";
- case HloOpcode::kShiftRightLogical:
- return "shift-right-logical";
- case HloOpcode::kSign:
- return "sign";
- case HloOpcode::kSin:
- return "sine";
- case HloOpcode::kSlice:
- return "slice";
- case HloOpcode::kSort:
- return "sort";
- case HloOpcode::kSubtract:
- return "subtract";
- case HloOpcode::kTanh:
- return "tanh";
- case HloOpcode::kTrace:
- return "trace";
- case HloOpcode::kTranspose:
- return "transpose";
- case HloOpcode::kTuple:
- return "tuple";
- case HloOpcode::kWhile:
- return "while";
+#define CASE_OPCODE_STRING(enum_name, opcode_name, ...) \
+ case HloOpcode::enum_name: \
+ return opcode_name;
+ HLO_OPCODE_LIST(CASE_OPCODE_STRING)
+#undef CASE_OPCODE_STRING
}
}
StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
- static auto* opcode_map = new tensorflow::gtl::FlatMap<string, HloOpcode>(
- {{"abs", HloOpcode::kAbs},
- {"add", HloOpcode::kAdd},
- {"and", HloOpcode::kAnd},
- {"batch-norm-training", HloOpcode::kBatchNormTraining},
- {"batch-norm-inference", HloOpcode::kBatchNormInference},
- {"batch-norm-grad", HloOpcode::kBatchNormGrad},
- {"bitcast", HloOpcode::kBitcast},
- {"broadcast", HloOpcode::kBroadcast},
- {"call", HloOpcode::kCall},
- {"clamp", HloOpcode::kClamp},
- {"concatenate", HloOpcode::kConcatenate},
- {"constant", HloOpcode::kConstant},
- {"convert", HloOpcode::kConvert},
- {"convolution", HloOpcode::kConvolution},
- {"cosine", HloOpcode::kCos},
- {"cross-replica-sum", HloOpcode::kCrossReplicaSum},
- {"custom-call", HloOpcode::kCustomCall},
- {"copy", HloOpcode::kCopy},
- {"divide", HloOpcode::kDivide},
- {"dot", HloOpcode::kDot},
- {"dynamic-slice", HloOpcode::kDynamicSlice},
- {"dynamic-update-slice", HloOpcode::kDynamicUpdateSlice},
- {"equal-to", HloOpcode::kEq},
- {"exponential", HloOpcode::kExp},
- {"floor", HloOpcode::kFloor},
- {"ceil", HloOpcode::kCeil},
- {"fusion", HloOpcode::kFusion},
- {"greater-than-or-equal-to", HloOpcode::kGe},
- {"get-tuple-element", HloOpcode::kGetTupleElement},
- {"greater-than", HloOpcode::kGt},
- {"infeed", HloOpcode::kInfeed},
- {"is-finite", HloOpcode::kIsFinite},
- {"less-than-or-equal-to", HloOpcode::kLe},
- {"log", HloOpcode::kLog},
- {"less-than", HloOpcode::kLt},
- {"map", HloOpcode::kMap},
- {"maximum", HloOpcode::kMaximum},
- {"minimum", HloOpcode::kMinimum},
- {"multiply", HloOpcode::kMultiply},
- {"not", HloOpcode::kNot},
- {"not-equal-to", HloOpcode::kNe},
- {"negate", HloOpcode::kNegate},
- {"or", HloOpcode::kOr},
- {"outfeed", HloOpcode::kOutfeed},
- {"pad", HloOpcode::kPad},
- {"parameter", HloOpcode::kParameter},
- {"power", HloOpcode::kPower},
- {"recv", HloOpcode::kRecv},
- {"reduce", HloOpcode::kReduce},
- {"reduce-precision", HloOpcode::kReducePrecision},
- {"reduce-window", HloOpcode::kReduceWindow},
- {"remainder", HloOpcode::kRemainder},
- {"reshape", HloOpcode::kReshape},
- {"reverse", HloOpcode::kReverse},
- {"rng", HloOpcode::kRng},
- {"round-nearest-afz", HloOpcode::kRoundNearestAfz},
- {"select-and-scatter", HloOpcode::kSelectAndScatter},
- {"select", HloOpcode::kSelect},
- {"send", HloOpcode::kSend},
- {"shift-left", HloOpcode::kShiftLeft},
- {"shift-right-arithmetic", HloOpcode::kShiftRightArithmetic},
- {"shift-right-logical", HloOpcode::kShiftRightLogical},
- {"sign", HloOpcode::kSign},
- {"sine", HloOpcode::kSin},
- {"slice", HloOpcode::kSlice},
- {"sort", HloOpcode::kSort},
- {"subtract", HloOpcode::kSubtract},
- {"tanh", HloOpcode::kTanh},
- {"trace", HloOpcode::kTrace},
- {"transpose", HloOpcode::kTranspose},
- {"tuple", HloOpcode::kTuple},
- {"while", HloOpcode::kWhile}});
+ static auto* opcode_map = new tensorflow::gtl::FlatMap<string, HloOpcode>({
+#define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) \
+ {opcode_name, HloOpcode::enum_name},
+ HLO_OPCODE_LIST(STRING_TO_OPCODE_ENTRY)
+#undef STRING_TO_OPCODE_ENTRY
+ });
auto it = opcode_map->find(opcode_name);
if (it == opcode_map->end()) {
return InvalidArgument("Unknown opcode: %s", opcode_name.c_str());
@@ -265,31 +44,36 @@ StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
return it->second;
}
+#define CHECK_DEFAULT(property_name, opcode_name) false
+#define CHECK_PROPERTY(property_name, opcode_name, value) \
+ (value & property_name)
+#define RESOLVE(_1, _2, target, ...) target
+#define HAS_PROPERTY(property, ...) \
+ RESOLVE(__VA_ARGS__, CHECK_PROPERTY, CHECK_DEFAULT)(property, __VA_ARGS__)
+
bool HloOpcodeIsComparison(HloOpcode opcode) {
switch (opcode) {
- case HloOpcode::kGe:
- case HloOpcode::kGt:
- case HloOpcode::kLe:
- case HloOpcode::kLt:
- case HloOpcode::kEq:
- case HloOpcode::kNe:
- return true;
- default:
- return false;
+#define CASE_IS_COMPARISON(enum_name, ...) \
+ case HloOpcode::enum_name: \
+ return HAS_PROPERTY(kHloOpcodeIsComparison, __VA_ARGS__);
+ HLO_OPCODE_LIST(CASE_IS_COMPARISON)
+#undef CASE_IS_COMPARISON
}
}
bool HloOpcodeIsVariadic(HloOpcode opcode) {
switch (opcode) {
- case HloOpcode::kCall:
- case HloOpcode::kConcatenate:
- case HloOpcode::kFusion:
- case HloOpcode::kMap:
- case HloOpcode::kTuple:
- return true;
- default:
- return false;
+#define CASE_IS_VARIADIC(enum_name, ...) \
+ case HloOpcode::enum_name: \
+ return HAS_PROPERTY(kHloOpcodeIsVariadic, __VA_ARGS__);
+ HLO_OPCODE_LIST(CASE_IS_VARIADIC)
+#undef CASE_IS_VARIADIC
}
}
+#undef HAS_PROPERTY
+#undef RESOLVE
+#undef CHECK_DEFAULT
+#undef CHECK_PROPERTY
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 07c2d26f00..d68fc20321 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -28,83 +28,112 @@ namespace xla {
// present in the XLA service protobuf.
//
// See the XLA documentation for the semantics of each opcode.
+//
+// Each entry has the format:
+// (enum_name, opcode_name)
+// or
+// (enum_name, opcode_name, p1 | p2 | ...)
+//
+// with p1, p2, ... are members of HloOpcodeProperty. They are combined
+// using bitwise-or.
+//
+// Note: Do not use ':' in opcode names. It is used as a special character
+// in these places:
+// - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to
+// separate the opcode from the fusion kind
+// - In fully qualified names (HloInstruction::FullyQualifiedName()), to
+// separate the qualifiers (name of the computation and potentially the
+// fusion instruction) from the name
+#define HLO_OPCODE_LIST(V) \
+ V(kAbs, "abs") \
+ V(kAdd, "add") \
+ V(kAtan2, "atan2") \
+ V(kBatchNormGrad, "batch-norm-grad") \
+ V(kBatchNormInference, "batch-norm-inference") \
+ V(kBatchNormTraining, "batch-norm-training") \
+ V(kBitcast, "bitcast") \
+ V(kBroadcast, "broadcast") \
+ V(kCall, "call", kHloOpcodeIsVariadic) \
+ V(kCeil, "ceil") \
+ V(kClamp, "clamp") \
+ V(kComplex, "complex") \
+ V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \
+ V(kConstant, "constant") \
+ V(kConvert, "convert") \
+ V(kConvolution, "convolution") \
+ V(kCopy, "copy") \
+ V(kCos, "cosine") \
+ V(kCrossReplicaSum, "cross-replica-sum") \
+ V(kCustomCall, "custom-call") \
+ V(kDivide, "divide") \
+ V(kDot, "dot") \
+ V(kDynamicSlice, "dynamic-slice") \
+ V(kDynamicUpdateSlice, "dynamic-update-slice") \
+ V(kEq, "equal-to", kHloOpcodeIsComparison) \
+ V(kExp, "exponential") \
+ V(kFloor, "floor") \
+ V(kFusion, "fusion", kHloOpcodeIsVariadic) \
+ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \
+ V(kGetTupleElement, "get-tuple-element") \
+ V(kGt, "greater-than", kHloOpcodeIsComparison) \
+ V(kImag, "imag") \
+ V(kInfeed, "infeed") \
+ V(kIsFinite, "is-finite") \
+ V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \
+ V(kLog, "log") \
+ V(kAnd, "and") \
+ V(kNot, "not") \
+ V(kOr, "or") \
+ V(kLt, "less-than", kHloOpcodeIsComparison) \
+ V(kMap, "map", kHloOpcodeIsVariadic) \
+ V(kMaximum, "maximum") \
+ V(kMinimum, "minimum") \
+ V(kMultiply, "multiply") \
+ V(kNe, "not-equal-to", kHloOpcodeIsComparison) \
+ V(kNegate, "negate") \
+ V(kOutfeed, "outfeed") \
+ V(kPad, "pad") \
+ V(kParameter, "parameter") \
+ V(kPower, "power") \
+ V(kReal, "real") \
+ V(kRecv, "recv") \
+ V(kReduce, "reduce") \
+ V(kReducePrecision, "reduce-precision") \
+ V(kReduceWindow, "reduce-window") \
+ V(kRemainder, "remainder") \
+ V(kReshape, "reshape") \
+ V(kReverse, "reverse") \
+ V(kRng, "rng") \
+ V(kRoundNearestAfz, "round-nearest-afz") \
+ V(kSelect, "select") \
+ V(kSelectAndScatter, "select-and-scatter") \
+ V(kSend, "send") \
+ V(kShiftLeft, "shift-left") \
+ V(kShiftRightArithmetic, "shift-right-arithmetic") \
+ V(kShiftRightLogical, "shift-right-logical") \
+ V(kSign, "sign") \
+ V(kSin, "sine") \
+ V(kSlice, "slice") \
+ V(kSort, "sort") \
+ V(kSubtract, "subtract") \
+ V(kTanh, "tanh") \
+ V(kTrace, "trace") \
+ V(kTranspose, "transpose") \
+ V(kTuple, "tuple", kHloOpcodeIsVariadic) \
+ V(kWhile, "while")
+
enum class HloOpcode {
- kAbs,
- kAdd,
- kAtan2,
- kBatchNormGrad,
- kBatchNormInference,
- kBatchNormTraining,
- kBitcast,
- kBroadcast,
- kCall,
- kCeil,
- kClamp,
- kComplex,
- kConcatenate,
- kConstant,
- kConvert,
- kConvolution,
- kCopy,
- kCos,
- kCrossReplicaSum,
- kCustomCall,
- kDivide,
- kDot,
- kDynamicSlice,
- kDynamicUpdateSlice,
- kEq,
- kExp,
- kFloor,
- kFusion,
- kGe,
- kGetTupleElement,
- kGt,
- kImag,
- kInfeed,
- kIsFinite,
- kLe,
- kLog,
- kAnd,
- kNot,
- kOr,
- kLt,
- kMap,
- kMaximum,
- kMinimum,
- kMultiply,
- kNe,
- kNegate,
- kOutfeed,
- kPad,
- kParameter,
- kPower,
- kReal,
- kRecv,
- kReduce,
- kReducePrecision,
- kReduceWindow,
- kRemainder,
- kReshape,
- kReverse,
- kRng,
- kRoundNearestAfz,
- kSelect,
- kSelectAndScatter,
- kSend,
- kShiftLeft,
- kShiftRightArithmetic,
- kShiftRightLogical,
- kSign,
- kSin,
- kSlice,
- kSort,
- kSubtract,
- kTanh,
- kTrace,
- kTranspose,
- kTuple,
- kWhile,
+#define DECLARE_ENUM(enum_name, opcode_name, ...) enum_name,
+ HLO_OPCODE_LIST(DECLARE_ENUM)
+#undef DECLARE_ENUM
+};
+
+// List of properties associated with opcodes.
+// Properties are defined as increasing powers of two, so that we can use
+// bitwise-or to combine properties, and bitwise-and to test for them.
+enum HloOpcodeProperty {
+ kHloOpcodeIsComparison = 1 << 0,
+ kHloOpcodeIsVariadic = 1 << 1,
};
// Returns a string representation of the opcode.
@@ -125,7 +154,9 @@ bool HloOpcodeIsVariadic(HloOpcode opcode);
// Returns the number of HloOpcode values.
inline const uint32_t HloOpcodeCount() {
- return static_cast<uint32_t>(HloOpcode::kWhile) + 1;
+#define HLO_COUNT_ONE(...) +1
+#define HLO_XLIST_LENGTH(list) list(HLO_COUNT_ONE)
+ return HLO_XLIST_LENGTH(HLO_OPCODE_LIST);
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc
index 892c89f9df..cd2ce5c69f 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc
@@ -26,5 +26,46 @@ TEST(HloOpcodeTest, StringifyMultiply) {
ASSERT_EQ("multiply", HloOpcodeString(HloOpcode::kMultiply));
}
+TEST(HloOpcodeTest, OpcodeProperties) {
+ // Test counting macro.
+#define SOME_LIST(X) \
+ X(One) \
+ X(Two) \
+ X(Three)
+ EXPECT_EQ(3, HLO_XLIST_LENGTH(SOME_LIST));
+#undef SOME_LIST
+
+ for (int i = 0; i < HloOpcodeCount(); ++i) {
+ auto opcode = static_cast<HloOpcode>(i);
+ // Test round-trip conversion to and from string.
+ EXPECT_EQ(opcode, StringToHloOpcode(HloOpcodeString(opcode)).ValueOrDie());
+
+ // Test some properties.
+ switch (opcode) {
+ case HloOpcode::kEq:
+ case HloOpcode::kNe:
+ case HloOpcode::kGt:
+ case HloOpcode::kLt:
+ case HloOpcode::kGe:
+ case HloOpcode::kLe:
+ EXPECT_TRUE(HloOpcodeIsComparison(opcode));
+ break;
+ default:
+ EXPECT_FALSE(HloOpcodeIsComparison(opcode));
+ }
+ switch (opcode) {
+ case HloOpcode::kCall:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kFusion:
+ case HloOpcode::kMap:
+ case HloOpcode::kTuple:
+ EXPECT_TRUE(HloOpcodeIsVariadic(opcode));
+ break;
+ default:
+ EXPECT_FALSE(HloOpcodeIsVariadic(opcode));
+ }
+ }
+}
+
} // namespace
} // namespace xla