aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/segment/segment.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/segment/segment.cc')
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.cc47
1 files changed, 28 insertions, 19 deletions
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
index 008fffc954..b43f1b190f 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -414,10 +414,10 @@ tensorflow::Status SegmentGraph(
}
for (const SimpleNode* node : order) {
// All output nodes of 'node' have been visited...
- VLOG(2) << "Trying node " << node->name() << " id=" << node->id();
+ VLOG(3) << "Trying node " << node->name() << " id=" << node->id();
// 'node' must be a TRT candidate...
if (node_segments[node->id()].Value() == nullptr) {
- VLOG(2) << "... not a TRT candidate";
+ VLOG(3) << "... not a TRT candidate";
continue;
}
// Contract output edges to combine 'node' with output
@@ -426,22 +426,22 @@ tensorflow::Status SegmentGraph(
while (true) {
std::set<const SimpleEdge*> contract_edges;
for (const SimpleEdge* out_edge : node->out_edges()) {
- VLOG(2) << "... out node " << out_edge->dst()->name() << " ( "
+ VLOG(3) << "... out node " << out_edge->dst()->name() << " ( "
<< out_edge->dst()->id() << " <- " << node->id() << " )";
if (out_edge->IsControlEdge()) {
- VLOG(2) << "... ... Control Edge, Skipping";
+ VLOG(3) << "... ... Control Edge, Skipping";
continue;
}
// Out node must be TRT candidate...
if (node_segments[out_edge->dst()->id()].Value() == nullptr) {
- VLOG(2) << "... ... not a TRT candidate";
+ VLOG(3) << "... ... not a TRT candidate";
continue;
}
if (CanContractEdge(out_edge, graph)) {
- VLOG(2) << "... ... can contract";
+ VLOG(3) << "... ... can contract";
contract_edges.insert(out_edge);
} else {
- VLOG(2) << "... ... cannot contract, would form cycle";
+ VLOG(3) << "... ... cannot contract, would form cycle";
}
}
if (contract_edges.empty()) {
@@ -454,7 +454,7 @@ tensorflow::Status SegmentGraph(
const SimpleNode* src = contract_edge->src();
const SimpleNode* dst = contract_edge->dst();
- VLOG(2) << "Merge " << src->name() << " <- " << dst->name() << " ("
+ VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " ("
<< src->id() << " <- " << dst->id();
node_segments[src->id()].Merge(&node_segments[dst->id()]);
@@ -478,7 +478,7 @@ tensorflow::Status SegmentGraph(
// A map from the segment identifier (currently the name of the root node of
// the segment tree) to the segment nodes set.
- std::unordered_map<string, std::set<const tensorflow::Node*>> sg_map;
+ std::map<string, std::set<const tensorflow::Node*>> sg_map;
// A map from the segment identifier (currently the name of the root node of
// the segment tree) to the device names that the nodes in the segment are
@@ -558,27 +558,36 @@ tensorflow::Status SegmentGraph(
// then after doing this operation the resulting subgraph will keep the
// same properties 1 and 2.
//
- // For simplicity we use heuristics: for input nodes remove all its
- // input, for output nodes remove all its output. In this way, for common
- // cases the number of removed nodes should be minimum.
+ // For simplicity we use heuristics: for input and const output nodes
+ // remove all their inputs, and for non-const output nodes remove all
+ // their outputs. In this way, for common cases the number of removed
+ // nodes should be minimum.
auto remove_nodes = [&segment_nodes](
bool is_input_nodes,
std::deque<const tensorflow::Node*>* que) {
// Run a BFS on the queue to find all the input/output nodes.
std::set<const tensorflow::Node*> visited;
+ std::set<const tensorflow::Node*> logged(que->begin(), que->end());
while (!que->empty()) {
auto node = que->front();
que->pop_front();
if (!visited.insert(node).second) continue;
segment_nodes.erase(node);
- for (auto in :
- is_input_nodes ? node->in_nodes() : node->out_nodes()) {
+ for (auto in : (is_input_nodes || node->type_string() == "Const")
+ ? node->in_nodes()
+ : node->out_nodes()) {
if (segment_nodes.count(in)) {
que->push_back(in);
- VLOG(2) << "Need to remove node " << in->name()
- << " because one of its "
- << (is_input_nodes ? "output" : "input")
- << " nodes in the graph was removed: " << node->name();
+ if (VLOG_IS_ON(2)) {
+ if (!logged.count(in)) {
+ VLOG(2) << "----> Need to remove node " << in->name()
+ << " because one of its "
+ << (is_input_nodes ? "output" : "input")
+ << " nodes in the graph was removed: "
+ << node->name();
+ logged.insert(in);
+ }
+ }
}
}
}
@@ -594,7 +603,7 @@ tensorflow::Status SegmentGraph(
for (const auto& itr : sg_map) {
const std::set<const tensorflow::Node*>& segment_nodes = itr.second;
if (VLOG_IS_ON(1)) {
- string s;
+ string s = "parent=" + itr.first + ":";
for (auto node : segment_nodes) s += " " + node->name();
VLOG(1) << "Segment " << segments->size() << ": " << s;
}