diff options
author | Mark Heffernan <meheff@google.com> | 2017-10-13 16:24:56 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-13 16:29:20 -0700 |
commit | 5dd569cf026bae92330a194c8f2895d0f48149d9 (patch) | |
tree | 96dbce8d2992fb1f14aa0a1265904eb57eaf2273 /tensorflow/compiler/xla/service/hlo_opcode.cc | |
parent | d426d3029727785676d1a7fbb7973a3a6ceb4842 (diff) |
Make the HLO proto representation (hlo.proto) full fidelity. Hlo modules can be serialized to HLO protos and deserialized without any information loss.
As part of this change, a bug is fixed in NameUniquer. Previously, passing names with numeric suffixes could result in name collisions.
PiperOrigin-RevId: 172161360
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_opcode.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_opcode.cc | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index e98012ec0c..db3abeab22 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -178,6 +180,89 @@ string HloOpcodeString(HloOpcode opcode) { } } +StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) { + static auto* opcode_map = new tensorflow::gtl::FlatMap<string, HloOpcode>( + {{"abs", HloOpcode::kAbs}, + {"add", HloOpcode::kAdd}, + {"batch-norm-training", HloOpcode::kBatchNormTraining}, + {"batch-norm-inference", HloOpcode::kBatchNormInference}, + {"batch-norm-grad", HloOpcode::kBatchNormGrad}, + {"bitcast", HloOpcode::kBitcast}, + {"broadcast", HloOpcode::kBroadcast}, + {"call", HloOpcode::kCall}, + {"clamp", HloOpcode::kClamp}, + {"concatenate", HloOpcode::kConcatenate}, + {"constant", HloOpcode::kConstant}, + {"convert", HloOpcode::kConvert}, + {"convolution", HloOpcode::kConvolution}, + {"cosine", HloOpcode::kCos}, + {"cross-replica-sum", HloOpcode::kCrossReplicaSum}, + {"custom-call", HloOpcode::kCustomCall}, + {"copy", HloOpcode::kCopy}, + {"divide", HloOpcode::kDivide}, + {"dot", HloOpcode::kDot}, + {"dynamic-slice", HloOpcode::kDynamicSlice}, + {"dynamic-update-slice", HloOpcode::kDynamicUpdateSlice}, + {"equal-to", HloOpcode::kEq}, + {"exponential", HloOpcode::kExp}, + {"floor", HloOpcode::kFloor}, + {"ceil", HloOpcode::kCeil}, + {"fusion", HloOpcode::kFusion}, + {"greater-than-or-equal-to", HloOpcode::kGe}, + {"get-tuple-element", HloOpcode::kGetTupleElement}, + {"greater-than", HloOpcode::kGt}, + {"index", HloOpcode::kIndex}, + {"infeed", HloOpcode::kInfeed}, + {"is-finite", HloOpcode::kIsFinite}, + {"less-than-or-equal-to", HloOpcode::kLe}, + {"log", HloOpcode::kLog}, + {"and", HloOpcode::kAnd}, + {"or", HloOpcode::kOr}, + {"not", HloOpcode::kNot}, + {"less-than", HloOpcode::kLt}, + {"map", HloOpcode::kMap}, + {"maximum", HloOpcode::kMaximum}, + {"minimum", HloOpcode::kMinimum}, + {"multiply", HloOpcode::kMultiply}, + {"not-equal-to", HloOpcode::kNe}, + {"negate", HloOpcode::kNegate}, + {"outfeed", HloOpcode::kOutfeed}, + {"pad", HloOpcode::kPad}, + {"parameter", HloOpcode::kParameter}, + {"power", HloOpcode::kPower}, + {"recv", HloOpcode::kRecv}, + {"reduce", HloOpcode::kReduce}, + {"reduce-precision", HloOpcode::kReducePrecision}, + {"reduce-window", HloOpcode::kReduceWindow}, + {"remainder", HloOpcode::kRemainder}, + {"reshape", HloOpcode::kReshape}, + {"reverse", HloOpcode::kReverse}, + {"rng", HloOpcode::kRng}, + {"round-nearest-afz", HloOpcode::kRoundNearestAfz}, + {"select-and-scatter", HloOpcode::kSelectAndScatter}, + {"select", HloOpcode::kSelect}, + {"send", HloOpcode::kSend}, + {"shift-left", HloOpcode::kShiftLeft}, + {"shift-right-arithmetic", HloOpcode::kShiftRightArithmetic}, + {"shift-right-logical", HloOpcode::kShiftRightLogical}, + {"sign", HloOpcode::kSign}, + {"sine", HloOpcode::kSin}, + {"slice", HloOpcode::kSlice}, + {"sort", HloOpcode::kSort}, + {"subtract", HloOpcode::kSubtract}, + {"tanh", HloOpcode::kTanh}, + {"trace", HloOpcode::kTrace}, + {"transpose", HloOpcode::kTranspose}, + {"tuple", HloOpcode::kTuple}, + {"update", HloOpcode::kUpdate}, + {"while", HloOpcode::kWhile}}); + auto it = opcode_map->find(opcode_name); + if (it == opcode_map->end()) { + return InvalidArgument("Unknown opcode: %s", opcode_name.c_str()); + } + return it->second; +} + bool HloOpcodeIsComparison(HloOpcode opcode) { switch (opcode) { case HloOpcode::kGe: |