diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-05-11 10:58:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-11 11:42:59 -0700 |
commit | 1f69227b32d9505439aa132667922091be5fca7d (patch) | |
tree | 37f4607b9dd8d7b241ca0d102a142f1d6ebdac9b /tensorflow/core/graph/graph.cc | |
parent | 8c28f1a2f97858ea949764588d62a02e1ca7a112 (diff) |
This change reduces the CPU time spent adding nodes to a graph. For an example large graph (13k nodes, 20k edges), this change reduces the CPU time spent loading the graph by 5%.
The existing code uses a long sequence of string comparisons and tests, whenever a node is added. The CHECK(class_ == NC_UNITIALIZED) statement can never actually test anything, because all of the string comparisons (except those against empty strings, which serve no purpose) test against a disjoint set of strings, so no collisions are possible.
PiperOrigin-RevId: 155768893
Diffstat (limited to 'tensorflow/core/graph/graph.cc')
-rw-r--r-- | tensorflow/core/graph/graph.cc | 75 |
1 files changed, 40 insertions, 35 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index a68a8f2509..dabc5a7849 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -30,6 +30,45 @@ const int Graph::kControlSlot = -1; // Node +#define REF_CLASS(key, value) \ + {key, value}, { "Ref" key, value } + +const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable = + *new std::unordered_map<string, Node::NodeClass>({ + // Keep in same order as NodeClass values + REF_CLASS("Switch", NC_SWITCH), + REF_CLASS("Merge", NC_MERGE), + REF_CLASS("Enter", NC_ENTER), + REF_CLASS("Exit", NC_EXIT), + REF_CLASS("NextIteration", NC_NEXT_ITERATION), + {"LoopCond", NC_LOOP_COND}, + {"ControlTrigger", NC_CONTROL_TRIGGER}, + {"_Send", NC_SEND}, + {"_HostSend", NC_HOST_SEND}, + {"_Recv", NC_RECV}, + {"_HostRecv", NC_HOST_RECV}, + {"Const", NC_CONSTANT}, + {"HostConst", NC_CONSTANT}, + {"Variable", NC_VARIABLE}, + {"VariableV2", NC_VARIABLE}, + REF_CLASS("Identity", NC_IDENTITY), + {"GetSessionHandle", NC_GET_SESSION_HANDLE}, + {"GetSessionHandleV2", NC_GET_SESSION_HANDLE}, + {"GetSessionTensor", NC_GET_SESSION_TENSOR}, + {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR}, + }); + +#undef REF_CLASS + +Node::NodeClass Node::GetNodeClassForOp(const string& ts) { + auto it = kNodeClassTable.find(ts); + if (it != kNodeClassTable.end()) { + return it->second; + } else { + return NC_OTHER; + } +} + string Node::DebugString() const { string ret = strings::StrCat("{name:'", name(), "' id:", id_); if (IsSource()) { @@ -70,41 +109,7 @@ void Node::Initialize(int id, int cost_id, Properties* props) { } props_ = props; // Initialize the class_ based on the type string - const string& ts = this->type_string(); - class_ = NC_UNINITIALIZED; - -#define SET_CLASS(enum_val, ts, str1, str2) \ - do { \ - if ((((ts) == (str1)) || ((ts) == (str2)))) { \ - /* Cannot be member of more than one class*/ \ - CHECK(class_ == NC_UNINITIALIZED); \ - class_ = (enum_val); \ - } \ - } while (0) - - SET_CLASS(NC_SWITCH, ts, "Switch", "RefSwitch"); - SET_CLASS(NC_MERGE, ts, "Merge", "RefMerge"); - SET_CLASS(NC_ENTER, ts, "Enter", "RefEnter"); - SET_CLASS(NC_EXIT, ts, "Exit", "RefExit"); - SET_CLASS(NC_NEXT_ITERATION, ts, "NextIteration", "RefNextIteration"); - SET_CLASS(NC_LOOP_COND, ts, "LoopCond", ""); - SET_CLASS(NC_CONTROL_TRIGGER, ts, "ControlTrigger", ""); - SET_CLASS(NC_SEND, ts, "_Send", ""); - SET_CLASS(NC_HOST_SEND, ts, "_HostSend", ""); - SET_CLASS(NC_RECV, ts, "_Recv", ""); - SET_CLASS(NC_HOST_RECV, ts, "_HostRecv", ""); - SET_CLASS(NC_CONSTANT, ts, "Const", "HostConst"); - SET_CLASS(NC_VARIABLE, ts, "Variable", ""); - SET_CLASS(NC_VARIABLE, ts, "VariableV2", ""); - SET_CLASS(NC_IDENTITY, ts, "Identity", "RefIdentity"); - SET_CLASS(NC_GET_SESSION_HANDLE, ts, "GetSessionHandle", ""); - SET_CLASS(NC_GET_SESSION_HANDLE, ts, "GetSessionHandleV2", ""); - SET_CLASS(NC_GET_SESSION_TENSOR, ts, "GetSessionTensor", ""); - SET_CLASS(NC_DELETE_SESSION_TENSOR, ts, "DeleteSessionTensor", ""); - if (class_ == NC_UNINITIALIZED) { - class_ = NC_OTHER; // Catch all - } -#undef SET_CLASS + class_ = GetNodeClassForOp(props->node_def_.op()); } void Node::Clear() { |