diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_opcode.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_opcode.cc | 282 |
1 files changed, 33 insertions, 249 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 157d19f5a9..d1eaf35785 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -21,243 +21,22 @@ limitations under the License. namespace xla { string HloOpcodeString(HloOpcode opcode) { - // Note: Do not use ':' in opcode strings. It is used as a special character - // in these places: - // - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to - // separate the opcode from the fusion kind - // - In fully qualified names (HloInstruction::FullyQualifiedName()), to - // separate the qualifiers (name of the computation and potentially the - // fusion instruction) from the name switch (opcode) { - case HloOpcode::kAbs: - return "abs"; - case HloOpcode::kAdd: - return "add"; - case HloOpcode::kAnd: - return "and"; - case HloOpcode::kAtan2: - return "atan2"; - case HloOpcode::kBatchNormTraining: - return "batch-norm-training"; - case HloOpcode::kBatchNormInference: - return "batch-norm-inference"; - case HloOpcode::kBatchNormGrad: - return "batch-norm-grad"; - case HloOpcode::kBitcast: - return "bitcast"; - case HloOpcode::kBroadcast: - return "broadcast"; - case HloOpcode::kCall: - return "call"; - case HloOpcode::kClamp: - return "clamp"; - case HloOpcode::kComplex: - return "complex"; - case HloOpcode::kConcatenate: - return "concatenate"; - case HloOpcode::kConstant: - return "constant"; - case HloOpcode::kConvert: - return "convert"; - case HloOpcode::kConvolution: - return "convolution"; - case HloOpcode::kCos: - return "cosine"; - case HloOpcode::kCrossReplicaSum: - return "cross-replica-sum"; - case HloOpcode::kCustomCall: - return "custom-call"; - case HloOpcode::kCopy: - return "copy"; - case HloOpcode::kDivide: - return "divide"; - case HloOpcode::kDot: - return "dot"; - case HloOpcode::kDynamicSlice: - return "dynamic-slice"; - case HloOpcode::kDynamicUpdateSlice: - return "dynamic-update-slice"; - case HloOpcode::kEq: - return "equal-to"; - case HloOpcode::kExp: - return "exponential"; - case HloOpcode::kFloor: - return "floor"; - case HloOpcode::kCeil: - return "ceil"; - case HloOpcode::kFusion: - return "fusion"; - case HloOpcode::kGe: - return "greater-than-or-equal-to"; - case HloOpcode::kGetTupleElement: - return "get-tuple-element"; - case HloOpcode::kGt: - return "greater-than"; - case HloOpcode::kImag: - return "imag"; - case HloOpcode::kInfeed: - return "infeed"; - case HloOpcode::kIsFinite: - return "is-finite"; - case HloOpcode::kLe: - return "less-than-or-equal-to"; - case HloOpcode::kLog: - return "log"; - case HloOpcode::kLt: - return "less-than"; - case HloOpcode::kMap: - return "map"; - case HloOpcode::kMaximum: - return "maximum"; - case HloOpcode::kMinimum: - return "minimum"; - case HloOpcode::kMultiply: - return "multiply"; - case HloOpcode::kNe: - return "not-equal-to"; - case HloOpcode::kNegate: - return "negate"; - case HloOpcode::kNot: - return "not"; - case HloOpcode::kOr: - return "or"; - case HloOpcode::kOutfeed: - return "outfeed"; - case HloOpcode::kPad: - return "pad"; - case HloOpcode::kParameter: - return "parameter"; - case HloOpcode::kPower: - return "power"; - case HloOpcode::kReal: - return "real"; - case HloOpcode::kRecv: - return "recv"; - case HloOpcode::kReduce: - return "reduce"; - case HloOpcode::kReducePrecision: - return "reduce-precision"; - case HloOpcode::kReduceWindow: - return "reduce-window"; - case HloOpcode::kRemainder: - return "remainder"; - case HloOpcode::kReshape: - return "reshape"; - case HloOpcode::kReverse: - return "reverse"; - case HloOpcode::kRng: - return "rng"; - case HloOpcode::kRoundNearestAfz: - return "round-nearest-afz"; - case HloOpcode::kSelectAndScatter: - return "select-and-scatter"; - case HloOpcode::kSelect: - return "select"; - case HloOpcode::kSend: - return "send"; - case HloOpcode::kShiftLeft: - return "shift-left"; - case HloOpcode::kShiftRightArithmetic: - return "shift-right-arithmetic"; - case HloOpcode::kShiftRightLogical: - return "shift-right-logical"; - case HloOpcode::kSign: - return "sign"; - case HloOpcode::kSin: - return "sine"; - case HloOpcode::kSlice: - return "slice"; - case HloOpcode::kSort: - return "sort"; - case HloOpcode::kSubtract: - return "subtract"; - case HloOpcode::kTanh: - return "tanh"; - case HloOpcode::kTrace: - return "trace"; - case HloOpcode::kTranspose: - return "transpose"; - case HloOpcode::kTuple: - return "tuple"; - case HloOpcode::kWhile: - return "while"; +#define CASE_OPCODE_STRING(enum_name, opcode_name, ...) \ + case HloOpcode::enum_name: \ + return opcode_name; + HLO_OPCODE_LIST(CASE_OPCODE_STRING) +#undef CASE_OPCODE_STRING } } StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) { - static auto* opcode_map = new tensorflow::gtl::FlatMap<string, HloOpcode>( - {{"abs", HloOpcode::kAbs}, - {"add", HloOpcode::kAdd}, - {"and", HloOpcode::kAnd}, - {"batch-norm-training", HloOpcode::kBatchNormTraining}, - {"batch-norm-inference", HloOpcode::kBatchNormInference}, - {"batch-norm-grad", HloOpcode::kBatchNormGrad}, - {"bitcast", HloOpcode::kBitcast}, - {"broadcast", HloOpcode::kBroadcast}, - {"call", HloOpcode::kCall}, - {"clamp", HloOpcode::kClamp}, - {"concatenate", HloOpcode::kConcatenate}, - {"constant", HloOpcode::kConstant}, - {"convert", HloOpcode::kConvert}, - {"convolution", HloOpcode::kConvolution}, - {"cosine", HloOpcode::kCos}, - {"cross-replica-sum", HloOpcode::kCrossReplicaSum}, - {"custom-call", HloOpcode::kCustomCall}, - {"copy", HloOpcode::kCopy}, - {"divide", HloOpcode::kDivide}, - {"dot", HloOpcode::kDot}, - {"dynamic-slice", HloOpcode::kDynamicSlice}, - {"dynamic-update-slice", HloOpcode::kDynamicUpdateSlice}, - {"equal-to", HloOpcode::kEq}, - {"exponential", HloOpcode::kExp}, - {"floor", HloOpcode::kFloor}, - {"ceil", HloOpcode::kCeil}, - {"fusion", HloOpcode::kFusion}, - {"greater-than-or-equal-to", HloOpcode::kGe}, - {"get-tuple-element", HloOpcode::kGetTupleElement}, - {"greater-than", HloOpcode::kGt}, - {"infeed", HloOpcode::kInfeed}, - {"is-finite", HloOpcode::kIsFinite}, - {"less-than-or-equal-to", HloOpcode::kLe}, - {"log", HloOpcode::kLog}, - {"less-than", HloOpcode::kLt}, - {"map", HloOpcode::kMap}, - {"maximum", HloOpcode::kMaximum}, - {"minimum", HloOpcode::kMinimum}, - {"multiply", HloOpcode::kMultiply}, - {"not", HloOpcode::kNot}, - {"not-equal-to", HloOpcode::kNe}, - {"negate", HloOpcode::kNegate}, - {"or", HloOpcode::kOr}, - {"outfeed", HloOpcode::kOutfeed}, - {"pad", HloOpcode::kPad}, - {"parameter", HloOpcode::kParameter}, - {"power", HloOpcode::kPower}, - {"recv", HloOpcode::kRecv}, - {"reduce", HloOpcode::kReduce}, - {"reduce-precision", HloOpcode::kReducePrecision}, - {"reduce-window", HloOpcode::kReduceWindow}, - {"remainder", HloOpcode::kRemainder}, - {"reshape", HloOpcode::kReshape}, - {"reverse", HloOpcode::kReverse}, - {"rng", HloOpcode::kRng}, - {"round-nearest-afz", HloOpcode::kRoundNearestAfz}, - {"select-and-scatter", HloOpcode::kSelectAndScatter}, - {"select", HloOpcode::kSelect}, - {"send", HloOpcode::kSend}, - {"shift-left", HloOpcode::kShiftLeft}, - {"shift-right-arithmetic", HloOpcode::kShiftRightArithmetic}, - {"shift-right-logical", HloOpcode::kShiftRightLogical}, - {"sign", HloOpcode::kSign}, - {"sine", HloOpcode::kSin}, - {"slice", HloOpcode::kSlice}, - {"sort", HloOpcode::kSort}, - {"subtract", HloOpcode::kSubtract}, - {"tanh", HloOpcode::kTanh}, - {"trace", HloOpcode::kTrace}, - {"transpose", HloOpcode::kTranspose}, - {"tuple", HloOpcode::kTuple}, - {"while", HloOpcode::kWhile}}); + static auto* opcode_map = new tensorflow::gtl::FlatMap<string, HloOpcode>({ +#define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) \ + {opcode_name, HloOpcode::enum_name}, + HLO_OPCODE_LIST(STRING_TO_OPCODE_ENTRY) +#undef STRING_TO_OPCODE_ENTRY + }); auto it = opcode_map->find(opcode_name); if (it == opcode_map->end()) { return InvalidArgument("Unknown opcode: %s", opcode_name.c_str()); @@ -265,31 +44,36 @@ StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) { return it->second; } +#define CHECK_DEFAULT(property_name, opcode_name) false +#define CHECK_PROPERTY(property_name, opcode_name, value) \ + (value & property_name) +#define RESOLVE(_1, _2, target, ...) target +#define HAS_PROPERTY(property, ...) \ + RESOLVE(__VA_ARGS__, CHECK_PROPERTY, CHECK_DEFAULT)(property, __VA_ARGS__) + bool HloOpcodeIsComparison(HloOpcode opcode) { switch (opcode) { - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kEq: - case HloOpcode::kNe: - return true; - default: - return false; +#define CASE_IS_COMPARISON(enum_name, ...) \ + case HloOpcode::enum_name: \ + return HAS_PROPERTY(kHloOpcodeIsComparison, __VA_ARGS__); + HLO_OPCODE_LIST(CASE_IS_COMPARISON) +#undef CASE_IS_COMPARISON } } bool HloOpcodeIsVariadic(HloOpcode opcode) { switch (opcode) { - case HloOpcode::kCall: - case HloOpcode::kConcatenate: - case HloOpcode::kFusion: - case HloOpcode::kMap: - case HloOpcode::kTuple: - return true; - default: - return false; +#define CASE_IS_VARIADIC(enum_name, ...) \ + case HloOpcode::enum_name: \ + return HAS_PROPERTY(kHloOpcodeIsVariadic, __VA_ARGS__); + HLO_OPCODE_LIST(CASE_IS_VARIADIC) +#undef CASE_IS_VARIADIC } } +#undef HAS_PROPERTY +#undef RESOLVE +#undef CHECK_DEFAULT +#undef CHECK_PROPERTY + } // namespace xla |