aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-11 10:58:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-11 11:42:59 -0700
commit1f69227b32d9505439aa132667922091be5fca7d (patch)
tree37f4607b9dd8d7b241ca0d102a142f1d6ebdac9b /tensorflow/core/graph/graph.cc
parent8c28f1a2f97858ea949764588d62a02e1ca7a112 (diff)
This change reduces the CPU time spent adding nodes to a graph. For an example large graph (13k nodes, 20k edges), this change reduces the CPU time spent loading the graph by 5%.
The existing code uses a long sequence of string comparisons and tests, whenever a node is added. The CHECK(class_ == NC_UNITIALIZED) statement can never actually test anything, because all of the string comparisons (except those against empty strings, which serve no purpose) test against a disjoint set of strings, so no collisions are possible. PiperOrigin-RevId: 155768893
Diffstat (limited to 'tensorflow/core/graph/graph.cc')
-rw-r--r--tensorflow/core/graph/graph.cc75
1 files changed, 40 insertions, 35 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index a68a8f2509..dabc5a7849 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -30,6 +30,45 @@ const int Graph::kControlSlot = -1;
// Node
+#define REF_CLASS(key, value) \
+ {key, value}, { "Ref" key, value }
+
+const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
+ *new std::unordered_map<string, Node::NodeClass>({
+ // Keep in same order as NodeClass values
+ REF_CLASS("Switch", NC_SWITCH),
+ REF_CLASS("Merge", NC_MERGE),
+ REF_CLASS("Enter", NC_ENTER),
+ REF_CLASS("Exit", NC_EXIT),
+ REF_CLASS("NextIteration", NC_NEXT_ITERATION),
+ {"LoopCond", NC_LOOP_COND},
+ {"ControlTrigger", NC_CONTROL_TRIGGER},
+ {"_Send", NC_SEND},
+ {"_HostSend", NC_HOST_SEND},
+ {"_Recv", NC_RECV},
+ {"_HostRecv", NC_HOST_RECV},
+ {"Const", NC_CONSTANT},
+ {"HostConst", NC_CONSTANT},
+ {"Variable", NC_VARIABLE},
+ {"VariableV2", NC_VARIABLE},
+ REF_CLASS("Identity", NC_IDENTITY),
+ {"GetSessionHandle", NC_GET_SESSION_HANDLE},
+ {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
+ {"GetSessionTensor", NC_GET_SESSION_TENSOR},
+ {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
+ });
+
+#undef REF_CLASS
+
+Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
+ auto it = kNodeClassTable.find(ts);
+ if (it != kNodeClassTable.end()) {
+ return it->second;
+ } else {
+ return NC_OTHER;
+ }
+}
+
string Node::DebugString() const {
string ret = strings::StrCat("{name:'", name(), "' id:", id_);
if (IsSource()) {
@@ -70,41 +109,7 @@ void Node::Initialize(int id, int cost_id, Properties* props) {
}
props_ = props;
// Initialize the class_ based on the type string
- const string& ts = this->type_string();
- class_ = NC_UNINITIALIZED;
-
-#define SET_CLASS(enum_val, ts, str1, str2) \
- do { \
- if ((((ts) == (str1)) || ((ts) == (str2)))) { \
- /* Cannot be member of more than one class*/ \
- CHECK(class_ == NC_UNINITIALIZED); \
- class_ = (enum_val); \
- } \
- } while (0)
-
- SET_CLASS(NC_SWITCH, ts, "Switch", "RefSwitch");
- SET_CLASS(NC_MERGE, ts, "Merge", "RefMerge");
- SET_CLASS(NC_ENTER, ts, "Enter", "RefEnter");
- SET_CLASS(NC_EXIT, ts, "Exit", "RefExit");
- SET_CLASS(NC_NEXT_ITERATION, ts, "NextIteration", "RefNextIteration");
- SET_CLASS(NC_LOOP_COND, ts, "LoopCond", "");
- SET_CLASS(NC_CONTROL_TRIGGER, ts, "ControlTrigger", "");
- SET_CLASS(NC_SEND, ts, "_Send", "");
- SET_CLASS(NC_HOST_SEND, ts, "_HostSend", "");
- SET_CLASS(NC_RECV, ts, "_Recv", "");
- SET_CLASS(NC_HOST_RECV, ts, "_HostRecv", "");
- SET_CLASS(NC_CONSTANT, ts, "Const", "HostConst");
- SET_CLASS(NC_VARIABLE, ts, "Variable", "");
- SET_CLASS(NC_VARIABLE, ts, "VariableV2", "");
- SET_CLASS(NC_IDENTITY, ts, "Identity", "RefIdentity");
- SET_CLASS(NC_GET_SESSION_HANDLE, ts, "GetSessionHandle", "");
- SET_CLASS(NC_GET_SESSION_HANDLE, ts, "GetSessionHandleV2", "");
- SET_CLASS(NC_GET_SESSION_TENSOR, ts, "GetSessionTensor", "");
- SET_CLASS(NC_DELETE_SESSION_TENSOR, ts, "DeleteSessionTensor", "");
- if (class_ == NC_UNINITIALIZED) {
- class_ = NC_OTHER; // Catch all
- }
-#undef SET_CLASS
+ class_ = GetNodeClassForOp(props->node_def_.op());
}
void Node::Clear() {