diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/segment/segment.cc')
-rw-r--r-- | tensorflow/contrib/tensorrt/segment/segment.cc | 47 |
1 files changed, 28 insertions, 19 deletions
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index 008fffc954..b43f1b190f 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -414,10 +414,10 @@ tensorflow::Status SegmentGraph( } for (const SimpleNode* node : order) { // All output nodes of 'node' have been visited... - VLOG(2) << "Trying node " << node->name() << " id=" << node->id(); + VLOG(3) << "Trying node " << node->name() << " id=" << node->id(); // 'node' must be a TRT candidate... if (node_segments[node->id()].Value() == nullptr) { - VLOG(2) << "... not a TRT candidate"; + VLOG(3) << "... not a TRT candidate"; continue; } // Contract output edges to combine 'node' with output @@ -426,22 +426,22 @@ tensorflow::Status SegmentGraph( while (true) { std::set<const SimpleEdge*> contract_edges; for (const SimpleEdge* out_edge : node->out_edges()) { - VLOG(2) << "... out node " << out_edge->dst()->name() << " ( " + VLOG(3) << "... out node " << out_edge->dst()->name() << " ( " << out_edge->dst()->id() << " <- " << node->id() << " )"; if (out_edge->IsControlEdge()) { - VLOG(2) << "... ... Control Edge, Skipping"; + VLOG(3) << "... ... Control Edge, Skipping"; continue; } // Out node must be TRT candidate... if (node_segments[out_edge->dst()->id()].Value() == nullptr) { - VLOG(2) << "... ... not a TRT candidate"; + VLOG(3) << "... ... not a TRT candidate"; continue; } if (CanContractEdge(out_edge, graph)) { - VLOG(2) << "... ... can contract"; + VLOG(3) << "... ... can contract"; contract_edges.insert(out_edge); } else { - VLOG(2) << "... ... cannot contract, would form cycle"; + VLOG(3) << "... ... cannot contract, would form cycle"; } } if (contract_edges.empty()) { @@ -454,7 +454,7 @@ tensorflow::Status SegmentGraph( const SimpleNode* src = contract_edge->src(); const SimpleNode* dst = contract_edge->dst(); - VLOG(2) << "Merge " << src->name() << " <- " << dst->name() << " (" + VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " (" << src->id() << " <- " << dst->id(); node_segments[src->id()].Merge(&node_segments[dst->id()]); @@ -478,7 +478,7 @@ tensorflow::Status SegmentGraph( // A map from the segment identifier (currently the name of the root node of // the segment tree) to the segment nodes set. - std::unordered_map<string, std::set<const tensorflow::Node*>> sg_map; + std::map<string, std::set<const tensorflow::Node*>> sg_map; // A map from the segment identifier (currently the name of the root node of // the segment tree) to the device names that the nodes in the segment are @@ -558,27 +558,36 @@ tensorflow::Status SegmentGraph( // then after doing this operation the resulting subgraph will keep the // same properties 1 and 2. // - // For simplicity we use heuristics: for input nodes remove all its - // input, for output nodes remove all its output. In this way, for common - // cases the number of removed nodes should be minimum. + // For simplicity we use heuristics: for input and const output nodes + // remove all their inputs, and for non-const output nodes remove all + // their outputs. In this way, for common cases the number of removed + // nodes should be minimum. auto remove_nodes = [&segment_nodes]( bool is_input_nodes, std::deque<const tensorflow::Node*>* que) { // Run a BFS on the queue to find all the input/output nodes. std::set<const tensorflow::Node*> visited; + std::set<const tensorflow::Node*> logged(que->begin(), que->end()); while (!que->empty()) { auto node = que->front(); que->pop_front(); if (!visited.insert(node).second) continue; segment_nodes.erase(node); - for (auto in : - is_input_nodes ? node->in_nodes() : node->out_nodes()) { + for (auto in : (is_input_nodes || node->type_string() == "Const") + ? node->in_nodes() + : node->out_nodes()) { if (segment_nodes.count(in)) { que->push_back(in); - VLOG(2) << "Need to remove node " << in->name() - << " because one of its " - << (is_input_nodes ? "output" : "input") - << " nodes in the graph was removed: " << node->name(); + if (VLOG_IS_ON(2)) { + if (!logged.count(in)) { + VLOG(2) << "----> Need to remove node " << in->name() + << " because one of its " + << (is_input_nodes ? "output" : "input") + << " nodes in the graph was removed: " + << node->name(); + logged.insert(in); + } + } } } } @@ -594,7 +603,7 @@ tensorflow::Status SegmentGraph( for (const auto& itr : sg_map) { const std::set<const tensorflow::Node*>& segment_nodes = itr.second; if (VLOG_IS_ON(1)) { - string s; + string s = "parent=" + itr.first + ":"; for (auto node : segment_nodes) s += " " + node->name(); VLOG(1) << "Segment " << segments->size() << ": " << s; } |