From 35939d2d37a03d95c86708ad0bf52865fbbd3c90 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 31 Oct 2017 11:54:57 -0700 Subject: [TF:XLA] Fix string to HLO opcode conversion for atan2, complex, imag and real. Make sure that we can't forget opcodes by auto-generating the conversion functions. Add auto-generated functions to test HLOs for properties (like IsVariadic, IsComparison, etc.) This makes changing HLO more robust and easier because there are fewer places to update when adding or removing an HLO opcode. Also: * Fix IsElementwiseBinary for atan2. * Add a unit test for HLO opcode helpers. * Express IsElementwiseBinary in terms of IsElementwise() and operand_count() to avoid having to keep the two in sync manually. PiperOrigin-RevId: 174069664 --- tensorflow/compiler/xla/service/hlo_instruction.cc | 33 +-- tensorflow/compiler/xla/service/hlo_opcode.cc | 282 +++------------------ tensorflow/compiler/xla/service/hlo_opcode.h | 185 ++++++++------ tensorflow/compiler/xla/service/hlo_opcode_test.cc | 41 +++ 4 files changed, 186 insertions(+), 355 deletions(-) (limited to 'tensorflow/compiler') diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index e6a4f68fb3..ecf8cd4065 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2514,33 +2514,7 @@ std::vector HloInstruction::OperandIndices( } bool HloInstruction::IsElementwiseBinary() const { - switch (opcode_) { - // Binary elementwise operations. If you update this, please update - // IsElementwise() accordingly. - case HloOpcode::kAdd: - case HloOpcode::kComplex: - case HloOpcode::kDivide: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kMaximum: - case HloOpcode::kMinimum: - case HloOpcode::kMultiply: - case HloOpcode::kNe: - case HloOpcode::kPower: - case HloOpcode::kRemainder: - case HloOpcode::kSubtract: - case HloOpcode::kAnd: - case HloOpcode::kOr: - case HloOpcode::kShiftLeft: - case HloOpcode::kShiftRightArithmetic: - case HloOpcode::kShiftRightLogical: - return true; - default: - return false; - } + return IsElementwise() && operand_count() == 2; } bool HloInstruction::IsElementwise() const { @@ -2551,7 +2525,6 @@ bool HloInstruction::IsElementwise() const { // Unary elementwise operations. case HloOpcode::kAbs: - case HloOpcode::kAtan2: case HloOpcode::kRoundNearestAfz: case HloOpcode::kCeil: case HloOpcode::kConvert: @@ -2569,11 +2542,12 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kTanh: + CHECK_EQ(1, operand_count()); return true; // Binary elementwise operations, the same as in IsElementwiseBinary(). - // If you update this, please update IsElementwiseBinary() accordingly. case HloOpcode::kAdd: + case HloOpcode::kAtan2: case HloOpcode::kComplex: case HloOpcode::kDivide: case HloOpcode::kEq: @@ -2593,6 +2567,7 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: + CHECK_EQ(2, operand_count()); return true; // Ternary elementwise operations. diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 157d19f5a9..d1eaf35785 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -21,243 +21,22 @@ limitations under the License. namespace xla { string HloOpcodeString(HloOpcode opcode) { - // Note: Do not use ':' in opcode strings. It is used as a special character - // in these places: - // - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to - // separate the opcode from the fusion kind - // - In fully qualified names (HloInstruction::FullyQualifiedName()), to - // separate the qualifiers (name of the computation and potentially the - // fusion instruction) from the name switch (opcode) { - case HloOpcode::kAbs: - return "abs"; - case HloOpcode::kAdd: - return "add"; - case HloOpcode::kAnd: - return "and"; - case HloOpcode::kAtan2: - return "atan2"; - case HloOpcode::kBatchNormTraining: - return "batch-norm-training"; - case HloOpcode::kBatchNormInference: - return "batch-norm-inference"; - case HloOpcode::kBatchNormGrad: - return "batch-norm-grad"; - case HloOpcode::kBitcast: - return "bitcast"; - case HloOpcode::kBroadcast: - return "broadcast"; - case HloOpcode::kCall: - return "call"; - case HloOpcode::kClamp: - return "clamp"; - case HloOpcode::kComplex: - return "complex"; - case HloOpcode::kConcatenate: - return "concatenate"; - case HloOpcode::kConstant: - return "constant"; - case HloOpcode::kConvert: - return "convert"; - case HloOpcode::kConvolution: - return "convolution"; - case HloOpcode::kCos: - return "cosine"; - case HloOpcode::kCrossReplicaSum: - return "cross-replica-sum"; - case HloOpcode::kCustomCall: - return "custom-call"; - case HloOpcode::kCopy: - return "copy"; - case HloOpcode::kDivide: - return "divide"; - case HloOpcode::kDot: - return "dot"; - case HloOpcode::kDynamicSlice: - return "dynamic-slice"; - case HloOpcode::kDynamicUpdateSlice: - return "dynamic-update-slice"; - case HloOpcode::kEq: - return "equal-to"; - case HloOpcode::kExp: - return "exponential"; - case HloOpcode::kFloor: - return "floor"; - case HloOpcode::kCeil: - return "ceil"; - case HloOpcode::kFusion: - return "fusion"; - case HloOpcode::kGe: - return "greater-than-or-equal-to"; - case HloOpcode::kGetTupleElement: - return "get-tuple-element"; - case HloOpcode::kGt: - return "greater-than"; - case HloOpcode::kImag: - return "imag"; - case HloOpcode::kInfeed: - return "infeed"; - case HloOpcode::kIsFinite: - return "is-finite"; - case HloOpcode::kLe: - return "less-than-or-equal-to"; - case HloOpcode::kLog: - return "log"; - case HloOpcode::kLt: - return "less-than"; - case HloOpcode::kMap: - return "map"; - case HloOpcode::kMaximum: - return "maximum"; - case HloOpcode::kMinimum: - return "minimum"; - case HloOpcode::kMultiply: - return "multiply"; - case HloOpcode::kNe: - return "not-equal-to"; - case HloOpcode::kNegate: - return "negate"; - case HloOpcode::kNot: - return "not"; - case HloOpcode::kOr: - return "or"; - case HloOpcode::kOutfeed: - return "outfeed"; - case HloOpcode::kPad: - return "pad"; - case HloOpcode::kParameter: - return "parameter"; - case HloOpcode::kPower: - return "power"; - case HloOpcode::kReal: - return "real"; - case HloOpcode::kRecv: - return "recv"; - case HloOpcode::kReduce: - return "reduce"; - case HloOpcode::kReducePrecision: - return "reduce-precision"; - case HloOpcode::kReduceWindow: - return "reduce-window"; - case HloOpcode::kRemainder: - return "remainder"; - case HloOpcode::kReshape: - return "reshape"; - case HloOpcode::kReverse: - return "reverse"; - case HloOpcode::kRng: - return "rng"; - case HloOpcode::kRoundNearestAfz: - return "round-nearest-afz"; - case HloOpcode::kSelectAndScatter: - return "select-and-scatter"; - case HloOpcode::kSelect: - return "select"; - case HloOpcode::kSend: - return "send"; - case HloOpcode::kShiftLeft: - return "shift-left"; - case HloOpcode::kShiftRightArithmetic: - return "shift-right-arithmetic"; - case HloOpcode::kShiftRightLogical: - return "shift-right-logical"; - case HloOpcode::kSign: - return "sign"; - case HloOpcode::kSin: - return "sine"; - case HloOpcode::kSlice: - return "slice"; - case HloOpcode::kSort: - return "sort"; - case HloOpcode::kSubtract: - return "subtract"; - case HloOpcode::kTanh: - return "tanh"; - case HloOpcode::kTrace: - return "trace"; - case HloOpcode::kTranspose: - return "transpose"; - case HloOpcode::kTuple: - return "tuple"; - case HloOpcode::kWhile: - return "while"; +#define CASE_OPCODE_STRING(enum_name, opcode_name, ...) \ + case HloOpcode::enum_name: \ + return opcode_name; + HLO_OPCODE_LIST(CASE_OPCODE_STRING) +#undef CASE_OPCODE_STRING } } StatusOr StringToHloOpcode(const string& opcode_name) { - static auto* opcode_map = new tensorflow::gtl::FlatMap( - {{"abs", HloOpcode::kAbs}, - {"add", HloOpcode::kAdd}, - {"and", HloOpcode::kAnd}, - {"batch-norm-training", HloOpcode::kBatchNormTraining}, - {"batch-norm-inference", HloOpcode::kBatchNormInference}, - {"batch-norm-grad", HloOpcode::kBatchNormGrad}, - {"bitcast", HloOpcode::kBitcast}, - {"broadcast", HloOpcode::kBroadcast}, - {"call", HloOpcode::kCall}, - {"clamp", HloOpcode::kClamp}, - {"concatenate", HloOpcode::kConcatenate}, - {"constant", HloOpcode::kConstant}, - {"convert", HloOpcode::kConvert}, - {"convolution", HloOpcode::kConvolution}, - {"cosine", HloOpcode::kCos}, - {"cross-replica-sum", HloOpcode::kCrossReplicaSum}, - {"custom-call", HloOpcode::kCustomCall}, - {"copy", HloOpcode::kCopy}, - {"divide", HloOpcode::kDivide}, - {"dot", HloOpcode::kDot}, - {"dynamic-slice", HloOpcode::kDynamicSlice}, - {"dynamic-update-slice", HloOpcode::kDynamicUpdateSlice}, - {"equal-to", HloOpcode::kEq}, - {"exponential", HloOpcode::kExp}, - {"floor", HloOpcode::kFloor}, - {"ceil", HloOpcode::kCeil}, - {"fusion", HloOpcode::kFusion}, - {"greater-than-or-equal-to", HloOpcode::kGe}, - {"get-tuple-element", HloOpcode::kGetTupleElement}, - {"greater-than", HloOpcode::kGt}, - {"infeed", HloOpcode::kInfeed}, - {"is-finite", HloOpcode::kIsFinite}, - {"less-than-or-equal-to", HloOpcode::kLe}, - {"log", HloOpcode::kLog}, - {"less-than", HloOpcode::kLt}, - {"map", HloOpcode::kMap}, - {"maximum", HloOpcode::kMaximum}, - {"minimum", HloOpcode::kMinimum}, - {"multiply", HloOpcode::kMultiply}, - {"not", HloOpcode::kNot}, - {"not-equal-to", HloOpcode::kNe}, - {"negate", HloOpcode::kNegate}, - {"or", HloOpcode::kOr}, - {"outfeed", HloOpcode::kOutfeed}, - {"pad", HloOpcode::kPad}, - {"parameter", HloOpcode::kParameter}, - {"power", HloOpcode::kPower}, - {"recv", HloOpcode::kRecv}, - {"reduce", HloOpcode::kReduce}, - {"reduce-precision", HloOpcode::kReducePrecision}, - {"reduce-window", HloOpcode::kReduceWindow}, - {"remainder", HloOpcode::kRemainder}, - {"reshape", HloOpcode::kReshape}, - {"reverse", HloOpcode::kReverse}, - {"rng", HloOpcode::kRng}, - {"round-nearest-afz", HloOpcode::kRoundNearestAfz}, - {"select-and-scatter", HloOpcode::kSelectAndScatter}, - {"select", HloOpcode::kSelect}, - {"send", HloOpcode::kSend}, - {"shift-left", HloOpcode::kShiftLeft}, - {"shift-right-arithmetic", HloOpcode::kShiftRightArithmetic}, - {"shift-right-logical", HloOpcode::kShiftRightLogical}, - {"sign", HloOpcode::kSign}, - {"sine", HloOpcode::kSin}, - {"slice", HloOpcode::kSlice}, - {"sort", HloOpcode::kSort}, - {"subtract", HloOpcode::kSubtract}, - {"tanh", HloOpcode::kTanh}, - {"trace", HloOpcode::kTrace}, - {"transpose", HloOpcode::kTranspose}, - {"tuple", HloOpcode::kTuple}, - {"while", HloOpcode::kWhile}}); + static auto* opcode_map = new tensorflow::gtl::FlatMap({ +#define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) \ + {opcode_name, HloOpcode::enum_name}, + HLO_OPCODE_LIST(STRING_TO_OPCODE_ENTRY) +#undef STRING_TO_OPCODE_ENTRY + }); auto it = opcode_map->find(opcode_name); if (it == opcode_map->end()) { return InvalidArgument("Unknown opcode: %s", opcode_name.c_str()); @@ -265,31 +44,36 @@ StatusOr StringToHloOpcode(const string& opcode_name) { return it->second; } +#define CHECK_DEFAULT(property_name, opcode_name) false +#define CHECK_PROPERTY(property_name, opcode_name, value) \ + (value & property_name) +#define RESOLVE(_1, _2, target, ...) target +#define HAS_PROPERTY(property, ...) \ + RESOLVE(__VA_ARGS__, CHECK_PROPERTY, CHECK_DEFAULT)(property, __VA_ARGS__) + bool HloOpcodeIsComparison(HloOpcode opcode) { switch (opcode) { - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kEq: - case HloOpcode::kNe: - return true; - default: - return false; +#define CASE_IS_COMPARISON(enum_name, ...) \ + case HloOpcode::enum_name: \ + return HAS_PROPERTY(kHloOpcodeIsComparison, __VA_ARGS__); + HLO_OPCODE_LIST(CASE_IS_COMPARISON) +#undef CASE_IS_COMPARISON } } bool HloOpcodeIsVariadic(HloOpcode opcode) { switch (opcode) { - case HloOpcode::kCall: - case HloOpcode::kConcatenate: - case HloOpcode::kFusion: - case HloOpcode::kMap: - case HloOpcode::kTuple: - return true; - default: - return false; +#define CASE_IS_VARIADIC(enum_name, ...) \ + case HloOpcode::enum_name: \ + return HAS_PROPERTY(kHloOpcodeIsVariadic, __VA_ARGS__); + HLO_OPCODE_LIST(CASE_IS_VARIADIC) +#undef CASE_IS_VARIADIC } } +#undef HAS_PROPERTY +#undef RESOLVE +#undef CHECK_DEFAULT +#undef CHECK_PROPERTY + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 07c2d26f00..d68fc20321 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -28,83 +28,112 @@ namespace xla { // present in the XLA service protobuf. // // See the XLA documentation for the semantics of each opcode. +// +// Each entry has the format: +// (enum_name, opcode_name) +// or +// (enum_name, opcode_name, p1 | p2 | ...) +// +// with p1, p2, ... are members of HloOpcodeProperty. They are combined +// using bitwise-or. +// +// Note: Do not use ':' in opcode names. It is used as a special character +// in these places: +// - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to +// separate the opcode from the fusion kind +// - In fully qualified names (HloInstruction::FullyQualifiedName()), to +// separate the qualifiers (name of the computation and potentially the +// fusion instruction) from the name +#define HLO_OPCODE_LIST(V) \ + V(kAbs, "abs") \ + V(kAdd, "add") \ + V(kAtan2, "atan2") \ + V(kBatchNormGrad, "batch-norm-grad") \ + V(kBatchNormInference, "batch-norm-inference") \ + V(kBatchNormTraining, "batch-norm-training") \ + V(kBitcast, "bitcast") \ + V(kBroadcast, "broadcast") \ + V(kCall, "call", kHloOpcodeIsVariadic) \ + V(kCeil, "ceil") \ + V(kClamp, "clamp") \ + V(kComplex, "complex") \ + V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ + V(kConstant, "constant") \ + V(kConvert, "convert") \ + V(kConvolution, "convolution") \ + V(kCopy, "copy") \ + V(kCos, "cosine") \ + V(kCrossReplicaSum, "cross-replica-sum") \ + V(kCustomCall, "custom-call") \ + V(kDivide, "divide") \ + V(kDot, "dot") \ + V(kDynamicSlice, "dynamic-slice") \ + V(kDynamicUpdateSlice, "dynamic-update-slice") \ + V(kEq, "equal-to", kHloOpcodeIsComparison) \ + V(kExp, "exponential") \ + V(kFloor, "floor") \ + V(kFusion, "fusion", kHloOpcodeIsVariadic) \ + V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ + V(kGetTupleElement, "get-tuple-element") \ + V(kGt, "greater-than", kHloOpcodeIsComparison) \ + V(kImag, "imag") \ + V(kInfeed, "infeed") \ + V(kIsFinite, "is-finite") \ + V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \ + V(kLog, "log") \ + V(kAnd, "and") \ + V(kNot, "not") \ + V(kOr, "or") \ + V(kLt, "less-than", kHloOpcodeIsComparison) \ + V(kMap, "map", kHloOpcodeIsVariadic) \ + V(kMaximum, "maximum") \ + V(kMinimum, "minimum") \ + V(kMultiply, "multiply") \ + V(kNe, "not-equal-to", kHloOpcodeIsComparison) \ + V(kNegate, "negate") \ + V(kOutfeed, "outfeed") \ + V(kPad, "pad") \ + V(kParameter, "parameter") \ + V(kPower, "power") \ + V(kReal, "real") \ + V(kRecv, "recv") \ + V(kReduce, "reduce") \ + V(kReducePrecision, "reduce-precision") \ + V(kReduceWindow, "reduce-window") \ + V(kRemainder, "remainder") \ + V(kReshape, "reshape") \ + V(kReverse, "reverse") \ + V(kRng, "rng") \ + V(kRoundNearestAfz, "round-nearest-afz") \ + V(kSelect, "select") \ + V(kSelectAndScatter, "select-and-scatter") \ + V(kSend, "send") \ + V(kShiftLeft, "shift-left") \ + V(kShiftRightArithmetic, "shift-right-arithmetic") \ + V(kShiftRightLogical, "shift-right-logical") \ + V(kSign, "sign") \ + V(kSin, "sine") \ + V(kSlice, "slice") \ + V(kSort, "sort") \ + V(kSubtract, "subtract") \ + V(kTanh, "tanh") \ + V(kTrace, "trace") \ + V(kTranspose, "transpose") \ + V(kTuple, "tuple", kHloOpcodeIsVariadic) \ + V(kWhile, "while") + enum class HloOpcode { - kAbs, - kAdd, - kAtan2, - kBatchNormGrad, - kBatchNormInference, - kBatchNormTraining, - kBitcast, - kBroadcast, - kCall, - kCeil, - kClamp, - kComplex, - kConcatenate, - kConstant, - kConvert, - kConvolution, - kCopy, - kCos, - kCrossReplicaSum, - kCustomCall, - kDivide, - kDot, - kDynamicSlice, - kDynamicUpdateSlice, - kEq, - kExp, - kFloor, - kFusion, - kGe, - kGetTupleElement, - kGt, - kImag, - kInfeed, - kIsFinite, - kLe, - kLog, - kAnd, - kNot, - kOr, - kLt, - kMap, - kMaximum, - kMinimum, - kMultiply, - kNe, - kNegate, - kOutfeed, - kPad, - kParameter, - kPower, - kReal, - kRecv, - kReduce, - kReducePrecision, - kReduceWindow, - kRemainder, - kReshape, - kReverse, - kRng, - kRoundNearestAfz, - kSelect, - kSelectAndScatter, - kSend, - kShiftLeft, - kShiftRightArithmetic, - kShiftRightLogical, - kSign, - kSin, - kSlice, - kSort, - kSubtract, - kTanh, - kTrace, - kTranspose, - kTuple, - kWhile, +#define DECLARE_ENUM(enum_name, opcode_name, ...) enum_name, + HLO_OPCODE_LIST(DECLARE_ENUM) +#undef DECLARE_ENUM +}; + +// List of properties associated with opcodes. +// Properties are defined as increasing powers of two, so that we can use +// bitwise-or to combine properties, and bitwise-and to test for them. +enum HloOpcodeProperty { + kHloOpcodeIsComparison = 1 << 0, + kHloOpcodeIsVariadic = 1 << 1, }; // Returns a string representation of the opcode. @@ -125,7 +154,9 @@ bool HloOpcodeIsVariadic(HloOpcode opcode); // Returns the number of HloOpcode values. inline const uint32_t HloOpcodeCount() { - return static_cast(HloOpcode::kWhile) + 1; +#define HLO_COUNT_ONE(...) +1 +#define HLO_XLIST_LENGTH(list) list(HLO_COUNT_ONE) + return HLO_XLIST_LENGTH(HLO_OPCODE_LIST); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index 892c89f9df..cd2ce5c69f 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -26,5 +26,46 @@ TEST(HloOpcodeTest, StringifyMultiply) { ASSERT_EQ("multiply", HloOpcodeString(HloOpcode::kMultiply)); } +TEST(HloOpcodeTest, OpcodeProperties) { + // Test counting macro. +#define SOME_LIST(X) \ + X(One) \ + X(Two) \ + X(Three) + EXPECT_EQ(3, HLO_XLIST_LENGTH(SOME_LIST)); +#undef SOME_LIST + + for (int i = 0; i < HloOpcodeCount(); ++i) { + auto opcode = static_cast(i); + // Test round-trip conversion to and from string. + EXPECT_EQ(opcode, StringToHloOpcode(HloOpcodeString(opcode)).ValueOrDie()); + + // Test some properties. + switch (opcode) { + case HloOpcode::kEq: + case HloOpcode::kNe: + case HloOpcode::kGt: + case HloOpcode::kLt: + case HloOpcode::kGe: + case HloOpcode::kLe: + EXPECT_TRUE(HloOpcodeIsComparison(opcode)); + break; + default: + EXPECT_FALSE(HloOpcodeIsComparison(opcode)); + } + switch (opcode) { + case HloOpcode::kCall: + case HloOpcode::kConcatenate: + case HloOpcode::kFusion: + case HloOpcode::kMap: + case HloOpcode::kTuple: + EXPECT_TRUE(HloOpcodeIsVariadic(opcode)); + break; + default: + EXPECT_FALSE(HloOpcodeIsVariadic(opcode)); + } + } +} + } // namespace } // namespace xla -- cgit v1.2.3