aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_opcode.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-31 11:54:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-31 12:01:38 -0700
commit35939d2d37a03d95c86708ad0bf52865fbbd3c90 (patch)
treee5467342b116a4b58fdbbb1d8c2394ad66a19fda /tensorflow/compiler/xla/service/hlo_opcode.cc
parent3b845c80d512703a78e6ac567c70ab65801468ef (diff)
[TF:XLA] Fix string to HLO opcode conversion for atan2, complex, imag and real.
Make sure that we can't forget opcodes by auto-generating the conversion functions. Add auto-generated functions to test HLOs for properties (like IsVariadic, IsComparison, etc.) This makes changing HLO more robust and easier because there are fewer places to update when adding or removing an HLO opcode. Also: * Fix IsElementwiseBinary for atan2. * Add a unit test for HLO opcode helpers. * Express IsElementwiseBinary in terms of IsElementwise() and operand_count() to avoid having to keep the two in sync manually. PiperOrigin-RevId: 174069664
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_opcode.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.cc282
1 files changed, 33 insertions, 249 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc
index 157d19f5a9..d1eaf35785 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode.cc
@@ -21,243 +21,22 @@ limitations under the License.
namespace xla {
string HloOpcodeString(HloOpcode opcode) {
- // Note: Do not use ':' in opcode strings. It is used as a special character
- // in these places:
- // - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to
- // separate the opcode from the fusion kind
- // - In fully qualified names (HloInstruction::FullyQualifiedName()), to
- // separate the qualifiers (name of the computation and potentially the
- // fusion instruction) from the name
switch (opcode) {
- case HloOpcode::kAbs:
- return "abs";
- case HloOpcode::kAdd:
- return "add";
- case HloOpcode::kAnd:
- return "and";
- case HloOpcode::kAtan2:
- return "atan2";
- case HloOpcode::kBatchNormTraining:
- return "batch-norm-training";
- case HloOpcode::kBatchNormInference:
- return "batch-norm-inference";
- case HloOpcode::kBatchNormGrad:
- return "batch-norm-grad";
- case HloOpcode::kBitcast:
- return "bitcast";
- case HloOpcode::kBroadcast:
- return "broadcast";
- case HloOpcode::kCall:
- return "call";
- case HloOpcode::kClamp:
- return "clamp";
- case HloOpcode::kComplex:
- return "complex";
- case HloOpcode::kConcatenate:
- return "concatenate";
- case HloOpcode::kConstant:
- return "constant";
- case HloOpcode::kConvert:
- return "convert";
- case HloOpcode::kConvolution:
- return "convolution";
- case HloOpcode::kCos:
- return "cosine";
- case HloOpcode::kCrossReplicaSum:
- return "cross-replica-sum";
- case HloOpcode::kCustomCall:
- return "custom-call";
- case HloOpcode::kCopy:
- return "copy";
- case HloOpcode::kDivide:
- return "divide";
- case HloOpcode::kDot:
- return "dot";
- case HloOpcode::kDynamicSlice:
- return "dynamic-slice";
- case HloOpcode::kDynamicUpdateSlice:
- return "dynamic-update-slice";
- case HloOpcode::kEq:
- return "equal-to";
- case HloOpcode::kExp:
- return "exponential";
- case HloOpcode::kFloor:
- return "floor";
- case HloOpcode::kCeil:
- return "ceil";
- case HloOpcode::kFusion:
- return "fusion";
- case HloOpcode::kGe:
- return "greater-than-or-equal-to";
- case HloOpcode::kGetTupleElement:
- return "get-tuple-element";
- case HloOpcode::kGt:
- return "greater-than";
- case HloOpcode::kImag:
- return "imag";
- case HloOpcode::kInfeed:
- return "infeed";
- case HloOpcode::kIsFinite:
- return "is-finite";
- case HloOpcode::kLe:
- return "less-than-or-equal-to";
- case HloOpcode::kLog:
- return "log";
- case HloOpcode::kLt:
- return "less-than";
- case HloOpcode::kMap:
- return "map";
- case HloOpcode::kMaximum:
- return "maximum";
- case HloOpcode::kMinimum:
- return "minimum";
- case HloOpcode::kMultiply:
- return "multiply";
- case HloOpcode::kNe:
- return "not-equal-to";
- case HloOpcode::kNegate:
- return "negate";
- case HloOpcode::kNot:
- return "not";
- case HloOpcode::kOr:
- return "or";
- case HloOpcode::kOutfeed:
- return "outfeed";
- case HloOpcode::kPad:
- return "pad";
- case HloOpcode::kParameter:
- return "parameter";
- case HloOpcode::kPower:
- return "power";
- case HloOpcode::kReal:
- return "real";
- case HloOpcode::kRecv:
- return "recv";
- case HloOpcode::kReduce:
- return "reduce";
- case HloOpcode::kReducePrecision:
- return "reduce-precision";
- case HloOpcode::kReduceWindow:
- return "reduce-window";
- case HloOpcode::kRemainder:
- return "remainder";
- case HloOpcode::kReshape:
- return "reshape";
- case HloOpcode::kReverse:
- return "reverse";
- case HloOpcode::kRng:
- return "rng";
- case HloOpcode::kRoundNearestAfz:
- return "round-nearest-afz";
- case HloOpcode::kSelectAndScatter:
- return "select-and-scatter";
- case HloOpcode::kSelect:
- return "select";
- case HloOpcode::kSend:
- return "send";
- case HloOpcode::kShiftLeft:
- return "shift-left";
- case HloOpcode::kShiftRightArithmetic:
- return "shift-right-arithmetic";
- case HloOpcode::kShiftRightLogical:
- return "shift-right-logical";
- case HloOpcode::kSign:
- return "sign";
- case HloOpcode::kSin:
- return "sine";
- case HloOpcode::kSlice:
- return "slice";
- case HloOpcode::kSort:
- return "sort";
- case HloOpcode::kSubtract:
- return "subtract";
- case HloOpcode::kTanh:
- return "tanh";
- case HloOpcode::kTrace:
- return "trace";
- case HloOpcode::kTranspose:
- return "transpose";
- case HloOpcode::kTuple:
- return "tuple";
- case HloOpcode::kWhile:
- return "while";
+#define CASE_OPCODE_STRING(enum_name, opcode_name, ...) \
+ case HloOpcode::enum_name: \
+ return opcode_name;
+ HLO_OPCODE_LIST(CASE_OPCODE_STRING)
+#undef CASE_OPCODE_STRING
}
}
StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
- static auto* opcode_map = new tensorflow::gtl::FlatMap<string, HloOpcode>(
- {{"abs", HloOpcode::kAbs},
- {"add", HloOpcode::kAdd},
- {"and", HloOpcode::kAnd},
- {"batch-norm-training", HloOpcode::kBatchNormTraining},
- {"batch-norm-inference", HloOpcode::kBatchNormInference},
- {"batch-norm-grad", HloOpcode::kBatchNormGrad},
- {"bitcast", HloOpcode::kBitcast},
- {"broadcast", HloOpcode::kBroadcast},
- {"call", HloOpcode::kCall},
- {"clamp", HloOpcode::kClamp},
- {"concatenate", HloOpcode::kConcatenate},
- {"constant", HloOpcode::kConstant},
- {"convert", HloOpcode::kConvert},
- {"convolution", HloOpcode::kConvolution},
- {"cosine", HloOpcode::kCos},
- {"cross-replica-sum", HloOpcode::kCrossReplicaSum},
- {"custom-call", HloOpcode::kCustomCall},
- {"copy", HloOpcode::kCopy},
- {"divide", HloOpcode::kDivide},
- {"dot", HloOpcode::kDot},
- {"dynamic-slice", HloOpcode::kDynamicSlice},
- {"dynamic-update-slice", HloOpcode::kDynamicUpdateSlice},
- {"equal-to", HloOpcode::kEq},
- {"exponential", HloOpcode::kExp},
- {"floor", HloOpcode::kFloor},
- {"ceil", HloOpcode::kCeil},
- {"fusion", HloOpcode::kFusion},
- {"greater-than-or-equal-to", HloOpcode::kGe},
- {"get-tuple-element", HloOpcode::kGetTupleElement},
- {"greater-than", HloOpcode::kGt},
- {"infeed", HloOpcode::kInfeed},
- {"is-finite", HloOpcode::kIsFinite},
- {"less-than-or-equal-to", HloOpcode::kLe},
- {"log", HloOpcode::kLog},
- {"less-than", HloOpcode::kLt},
- {"map", HloOpcode::kMap},
- {"maximum", HloOpcode::kMaximum},
- {"minimum", HloOpcode::kMinimum},
- {"multiply", HloOpcode::kMultiply},
- {"not", HloOpcode::kNot},
- {"not-equal-to", HloOpcode::kNe},
- {"negate", HloOpcode::kNegate},
- {"or", HloOpcode::kOr},
- {"outfeed", HloOpcode::kOutfeed},
- {"pad", HloOpcode::kPad},
- {"parameter", HloOpcode::kParameter},
- {"power", HloOpcode::kPower},
- {"recv", HloOpcode::kRecv},
- {"reduce", HloOpcode::kReduce},
- {"reduce-precision", HloOpcode::kReducePrecision},
- {"reduce-window", HloOpcode::kReduceWindow},
- {"remainder", HloOpcode::kRemainder},
- {"reshape", HloOpcode::kReshape},
- {"reverse", HloOpcode::kReverse},
- {"rng", HloOpcode::kRng},
- {"round-nearest-afz", HloOpcode::kRoundNearestAfz},
- {"select-and-scatter", HloOpcode::kSelectAndScatter},
- {"select", HloOpcode::kSelect},
- {"send", HloOpcode::kSend},
- {"shift-left", HloOpcode::kShiftLeft},
- {"shift-right-arithmetic", HloOpcode::kShiftRightArithmetic},
- {"shift-right-logical", HloOpcode::kShiftRightLogical},
- {"sign", HloOpcode::kSign},
- {"sine", HloOpcode::kSin},
- {"slice", HloOpcode::kSlice},
- {"sort", HloOpcode::kSort},
- {"subtract", HloOpcode::kSubtract},
- {"tanh", HloOpcode::kTanh},
- {"trace", HloOpcode::kTrace},
- {"transpose", HloOpcode::kTranspose},
- {"tuple", HloOpcode::kTuple},
- {"while", HloOpcode::kWhile}});
+ static auto* opcode_map = new tensorflow::gtl::FlatMap<string, HloOpcode>({
+#define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) \
+ {opcode_name, HloOpcode::enum_name},
+ HLO_OPCODE_LIST(STRING_TO_OPCODE_ENTRY)
+#undef STRING_TO_OPCODE_ENTRY
+ });
auto it = opcode_map->find(opcode_name);
if (it == opcode_map->end()) {
return InvalidArgument("Unknown opcode: %s", opcode_name.c_str());
@@ -265,31 +44,36 @@ StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
return it->second;
}
+#define CHECK_DEFAULT(property_name, opcode_name) false
+#define CHECK_PROPERTY(property_name, opcode_name, value) \
+ (value & property_name)
+#define RESOLVE(_1, _2, target, ...) target
+#define HAS_PROPERTY(property, ...) \
+ RESOLVE(__VA_ARGS__, CHECK_PROPERTY, CHECK_DEFAULT)(property, __VA_ARGS__)
+
bool HloOpcodeIsComparison(HloOpcode opcode) {
switch (opcode) {
- case HloOpcode::kGe:
- case HloOpcode::kGt:
- case HloOpcode::kLe:
- case HloOpcode::kLt:
- case HloOpcode::kEq:
- case HloOpcode::kNe:
- return true;
- default:
- return false;
+#define CASE_IS_COMPARISON(enum_name, ...) \
+ case HloOpcode::enum_name: \
+ return HAS_PROPERTY(kHloOpcodeIsComparison, __VA_ARGS__);
+ HLO_OPCODE_LIST(CASE_IS_COMPARISON)
+#undef CASE_IS_COMPARISON
}
}
bool HloOpcodeIsVariadic(HloOpcode opcode) {
switch (opcode) {
- case HloOpcode::kCall:
- case HloOpcode::kConcatenate:
- case HloOpcode::kFusion:
- case HloOpcode::kMap:
- case HloOpcode::kTuple:
- return true;
- default:
- return false;
+#define CASE_IS_VARIADIC(enum_name, ...) \
+ case HloOpcode::enum_name: \
+ return HAS_PROPERTY(kHloOpcodeIsVariadic, __VA_ARGS__);
+ HLO_OPCODE_LIST(CASE_IS_VARIADIC)
+#undef CASE_IS_VARIADIC
}
}
+#undef HAS_PROPERTY
+#undef RESOLVE
+#undef CHECK_DEFAULT
+#undef CHECK_PROPERTY
+
} // namespace xla