aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_opcode.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_opcode.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.cc282
1 files changed, 33 insertions, 249 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc
index 157d19f5a9..d1eaf35785 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode.cc
@@ -21,243 +21,22 @@ limitations under the License.
namespace xla {
string HloOpcodeString(HloOpcode opcode) {
- // Note: Do not use ':' in opcode strings. It is used as a special character
- // in these places:
- // - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to
- // separate the opcode from the fusion kind
- // - In fully qualified names (HloInstruction::FullyQualifiedName()), to
- // separate the qualifiers (name of the computation and potentially the
- // fusion instruction) from the name
switch (opcode) {
- case HloOpcode::kAbs:
- return "abs";
- case HloOpcode::kAdd:
- return "add";
- case HloOpcode::kAnd:
- return "and";
- case HloOpcode::kAtan2:
- return "atan2";
- case HloOpcode::kBatchNormTraining:
- return "batch-norm-training";
- case HloOpcode::kBatchNormInference:
- return "batch-norm-inference";
- case HloOpcode::kBatchNormGrad:
- return "batch-norm-grad";
- case HloOpcode::kBitcast:
- return "bitcast";
- case HloOpcode::kBroadcast:
- return "broadcast";
- case HloOpcode::kCall:
- return "call";
- case HloOpcode::kClamp:
- return "clamp";
- case HloOpcode::kComplex:
- return "complex";
- case HloOpcode::kConcatenate:
- return "concatenate";
- case HloOpcode::kConstant:
- return "constant";
- case HloOpcode::kConvert:
- return "convert";
- case HloOpcode::kConvolution:
- return "convolution";
- case HloOpcode::kCos:
- return "cosine";
- case HloOpcode::kCrossReplicaSum:
- return "cross-replica-sum";
- case HloOpcode::kCustomCall:
- return "custom-call";
- case HloOpcode::kCopy:
- return "copy";
- case HloOpcode::kDivide:
- return "divide";
- case HloOpcode::kDot:
- return "dot";
- case HloOpcode::kDynamicSlice:
- return "dynamic-slice";
- case HloOpcode::kDynamicUpdateSlice:
- return "dynamic-update-slice";
- case HloOpcode::kEq:
- return "equal-to";
- case HloOpcode::kExp:
- return "exponential";
- case HloOpcode::kFloor:
- return "floor";
- case HloOpcode::kCeil:
- return "ceil";
- case HloOpcode::kFusion:
- return "fusion";
- case HloOpcode::kGe:
- return "greater-than-or-equal-to";
- case HloOpcode::kGetTupleElement:
- return "get-tuple-element";
- case HloOpcode::kGt:
- return "greater-than";
- case HloOpcode::kImag:
- return "imag";
- case HloOpcode::kInfeed:
- return "infeed";
- case HloOpcode::kIsFinite:
- return "is-finite";
- case HloOpcode::kLe:
- return "less-than-or-equal-to";
- case HloOpcode::kLog:
- return "log";
- case HloOpcode::kLt:
- return "less-than";
- case HloOpcode::kMap:
- return "map";
- case HloOpcode::kMaximum:
- return "maximum";
- case HloOpcode::kMinimum:
- return "minimum";
- case HloOpcode::kMultiply:
- return "multiply";
- case HloOpcode::kNe:
- return "not-equal-to";
- case HloOpcode::kNegate:
- return "negate";
- case HloOpcode::kNot:
- return "not";
- case HloOpcode::kOr:
- return "or";
- case HloOpcode::kOutfeed:
- return "outfeed";
- case HloOpcode::kPad:
- return "pad";
- case HloOpcode::kParameter:
- return "parameter";
- case HloOpcode::kPower:
- return "power";
- case HloOpcode::kReal:
- return "real";
- case HloOpcode::kRecv:
- return "recv";
- case HloOpcode::kReduce:
- return "reduce";
- case HloOpcode::kReducePrecision:
- return "reduce-precision";
- case HloOpcode::kReduceWindow:
- return "reduce-window";
- case HloOpcode::kRemainder:
- return "remainder";
- case HloOpcode::kReshape:
- return "reshape";
- case HloOpcode::kReverse:
- return "reverse";
- case HloOpcode::kRng:
- return "rng";
- case HloOpcode::kRoundNearestAfz:
- return "round-nearest-afz";
- case HloOpcode::kSelectAndScatter:
- return "select-and-scatter";
- case HloOpcode::kSelect:
- return "select";
- case HloOpcode::kSend:
- return "send";
- case HloOpcode::kShiftLeft:
- return "shift-left";
- case HloOpcode::kShiftRightArithmetic:
- return "shift-right-arithmetic";
- case HloOpcode::kShiftRightLogical:
- return "shift-right-logical";
- case HloOpcode::kSign:
- return "sign";
- case HloOpcode::kSin:
- return "sine";
- case HloOpcode::kSlice:
- return "slice";
- case HloOpcode::kSort:
- return "sort";
- case HloOpcode::kSubtract:
- return "subtract";
- case HloOpcode::kTanh:
- return "tanh";
- case HloOpcode::kTrace:
- return "trace";
- case HloOpcode::kTranspose:
- return "transpose";
- case HloOpcode::kTuple:
- return "tuple";
- case HloOpcode::kWhile:
- return "while";
+#define CASE_OPCODE_STRING(enum_name, opcode_name, ...) \
+ case HloOpcode::enum_name: \
+ return opcode_name;
+ HLO_OPCODE_LIST(CASE_OPCODE_STRING)
+#undef CASE_OPCODE_STRING
}
}
StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
- static auto* opcode_map = new tensorflow::gtl::FlatMap<string, HloOpcode>(
- {{"abs", HloOpcode::kAbs},
- {"add", HloOpcode::kAdd},
- {"and", HloOpcode::kAnd},
- {"batch-norm-training", HloOpcode::kBatchNormTraining},
- {"batch-norm-inference", HloOpcode::kBatchNormInference},
- {"batch-norm-grad", HloOpcode::kBatchNormGrad},
- {"bitcast", HloOpcode::kBitcast},
- {"broadcast", HloOpcode::kBroadcast},
- {"call", HloOpcode::kCall},
- {"clamp", HloOpcode::kClamp},
- {"concatenate", HloOpcode::kConcatenate},
- {"constant", HloOpcode::kConstant},
- {"convert", HloOpcode::kConvert},
- {"convolution", HloOpcode::kConvolution},
- {"cosine", HloOpcode::kCos},
- {"cross-replica-sum", HloOpcode::kCrossReplicaSum},
- {"custom-call", HloOpcode::kCustomCall},
- {"copy", HloOpcode::kCopy},
- {"divide", HloOpcode::kDivide},
- {"dot", HloOpcode::kDot},
- {"dynamic-slice", HloOpcode::kDynamicSlice},
- {"dynamic-update-slice", HloOpcode::kDynamicUpdateSlice},
- {"equal-to", HloOpcode::kEq},
- {"exponential", HloOpcode::kExp},
- {"floor", HloOpcode::kFloor},
- {"ceil", HloOpcode::kCeil},
- {"fusion", HloOpcode::kFusion},
- {"greater-than-or-equal-to", HloOpcode::kGe},
- {"get-tuple-element", HloOpcode::kGetTupleElement},
- {"greater-than", HloOpcode::kGt},
- {"infeed", HloOpcode::kInfeed},
- {"is-finite", HloOpcode::kIsFinite},
- {"less-than-or-equal-to", HloOpcode::kLe},
- {"log", HloOpcode::kLog},
- {"less-than", HloOpcode::kLt},
- {"map", HloOpcode::kMap},
- {"maximum", HloOpcode::kMaximum},
- {"minimum", HloOpcode::kMinimum},
- {"multiply", HloOpcode::kMultiply},
- {"not", HloOpcode::kNot},
- {"not-equal-to", HloOpcode::kNe},
- {"negate", HloOpcode::kNegate},
- {"or", HloOpcode::kOr},
- {"outfeed", HloOpcode::kOutfeed},
- {"pad", HloOpcode::kPad},
- {"parameter", HloOpcode::kParameter},
- {"power", HloOpcode::kPower},
- {"recv", HloOpcode::kRecv},
- {"reduce", HloOpcode::kReduce},
- {"reduce-precision", HloOpcode::kReducePrecision},
- {"reduce-window", HloOpcode::kReduceWindow},
- {"remainder", HloOpcode::kRemainder},
- {"reshape", HloOpcode::kReshape},
- {"reverse", HloOpcode::kReverse},
- {"rng", HloOpcode::kRng},
- {"round-nearest-afz", HloOpcode::kRoundNearestAfz},
- {"select-and-scatter", HloOpcode::kSelectAndScatter},
- {"select", HloOpcode::kSelect},
- {"send", HloOpcode::kSend},
- {"shift-left", HloOpcode::kShiftLeft},
- {"shift-right-arithmetic", HloOpcode::kShiftRightArithmetic},
- {"shift-right-logical", HloOpcode::kShiftRightLogical},
- {"sign", HloOpcode::kSign},
- {"sine", HloOpcode::kSin},
- {"slice", HloOpcode::kSlice},
- {"sort", HloOpcode::kSort},
- {"subtract", HloOpcode::kSubtract},
- {"tanh", HloOpcode::kTanh},
- {"trace", HloOpcode::kTrace},
- {"transpose", HloOpcode::kTranspose},
- {"tuple", HloOpcode::kTuple},
- {"while", HloOpcode::kWhile}});
+ static auto* opcode_map = new tensorflow::gtl::FlatMap<string, HloOpcode>({
+#define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) \
+ {opcode_name, HloOpcode::enum_name},
+ HLO_OPCODE_LIST(STRING_TO_OPCODE_ENTRY)
+#undef STRING_TO_OPCODE_ENTRY
+ });
auto it = opcode_map->find(opcode_name);
if (it == opcode_map->end()) {
return InvalidArgument("Unknown opcode: %s", opcode_name.c_str());
@@ -265,31 +44,36 @@ StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
return it->second;
}
+#define CHECK_DEFAULT(property_name, opcode_name) false
+#define CHECK_PROPERTY(property_name, opcode_name, value) \
+ (value & property_name)
+#define RESOLVE(_1, _2, target, ...) target
+#define HAS_PROPERTY(property, ...) \
+ RESOLVE(__VA_ARGS__, CHECK_PROPERTY, CHECK_DEFAULT)(property, __VA_ARGS__)
+
bool HloOpcodeIsComparison(HloOpcode opcode) {
switch (opcode) {
- case HloOpcode::kGe:
- case HloOpcode::kGt:
- case HloOpcode::kLe:
- case HloOpcode::kLt:
- case HloOpcode::kEq:
- case HloOpcode::kNe:
- return true;
- default:
- return false;
+#define CASE_IS_COMPARISON(enum_name, ...) \
+ case HloOpcode::enum_name: \
+ return HAS_PROPERTY(kHloOpcodeIsComparison, __VA_ARGS__);
+ HLO_OPCODE_LIST(CASE_IS_COMPARISON)
+#undef CASE_IS_COMPARISON
}
}
bool HloOpcodeIsVariadic(HloOpcode opcode) {
switch (opcode) {
- case HloOpcode::kCall:
- case HloOpcode::kConcatenate:
- case HloOpcode::kFusion:
- case HloOpcode::kMap:
- case HloOpcode::kTuple:
- return true;
- default:
- return false;
+#define CASE_IS_VARIADIC(enum_name, ...) \
+ case HloOpcode::enum_name: \
+ return HAS_PROPERTY(kHloOpcodeIsVariadic, __VA_ARGS__);
+ HLO_OPCODE_LIST(CASE_IS_VARIADIC)
+#undef CASE_IS_VARIADIC
}
}
+#undef HAS_PROPERTY
+#undef RESOLVE
+#undef CHECK_DEFAULT
+#undef CHECK_PROPERTY
+
} // namespace xla