aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/segment/segment.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/segment/segment.cc')
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.cc55
1 files changed, 37 insertions, 18 deletions
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
index 6193f0b0a1..8fc4697c51 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -80,13 +80,20 @@ void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
std::vector<const tensorflow::Edge*> in_edges(dst->in_edges().begin(),
dst->in_edges().end());
for (const tensorflow::Edge* in_edge : in_edges) {
- if (in_edge->src() != src) {
- tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge);
- if (e->src() == graph->source_node()) {
- graph->AddEdge(e->src(), e->src_output(), src,
- tensorflow::Graph::kControlSlot);
- } else {
- graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */);
+ if (in_edge->IsControlEdge()) {
+ if (in_edge->src() != src) {
+ tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge);
+ graph->AddControlEdge(e->src(), src);
+ }
+ } else {
+ if (in_edge->src() != src) {
+ tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge);
+ if (e->src() == graph->source_node()) {
+ graph->AddEdge(e->src(), e->src_output(), src,
+ tensorflow::Graph::kControlSlot);
+ } else {
+ graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */);
+ }
}
}
}
@@ -94,12 +101,19 @@ void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
std::vector<const tensorflow::Edge*> out_edges(dst->out_edges().begin(),
dst->out_edges().end());
for (const tensorflow::Edge* out_edge : out_edges) {
- tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge);
- if (e->dst() == graph->sink_node()) {
- graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(),
- e->dst_input());
+ if (out_edge->IsControlEdge()) {
+ tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge);
+ graph->AddControlEdge(src, e->dst());
} else {
- graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input());
+ tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge);
+ if (e->dst() == graph->sink_node()) {
+ VLOG(1) << " edge to sink node " << src->name() << " -> "
+ << e->dst()->name();
+ graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(),
+ e->dst_input());
+ } else {
+ graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input());
+ }
}
}
@@ -118,7 +132,7 @@ void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
tensorflow::Status SegmentGraph(
const tensorflow::GraphDef& gdef,
- const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn,
+ const std::function<bool(const tensorflow::Node*)>& candidate_fn,
const SegmentOptions& options, SegmentNodesVector* segments) {
// Create a Graph representation of the GraphDef.
tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
@@ -136,7 +150,7 @@ tensorflow::Status SegmentGraph(
for (int i = 0; i < graph.num_node_ids(); ++i) {
tensorflow::Node* node = graph.FindNodeId(i);
if (options.exclude_node_list.count(node->name()) != 0 ||
- !candidate_fn(node->def())) {
+ !candidate_fn(node)) {
node = nullptr;
}
node_segments.emplace_back(node);
@@ -155,7 +169,7 @@ tensorflow::Status SegmentGraph(
for (const tensorflow::Node* node : order) {
// All output nodes of 'node' have been visited...
- VLOG(2) << "Trying node " << node->name();
+ VLOG(2) << "Trying node " << node->name() << " id=" << node->id();
// 'node' must be a TRT candidate...
if (node_segments[node->id()].Value() == nullptr) {
@@ -169,8 +183,12 @@ tensorflow::Status SegmentGraph(
while (true) {
std::set<const tensorflow::Edge*> contract_edges;
for (const tensorflow::Edge* out_edge : node->out_edges()) {
- VLOG(2) << "... out node " << out_edge->dst()->name();
-
+ VLOG(2) << "... out node " << out_edge->dst()->name() << " ( "
+ << out_edge->dst()->id() << " <- " << node->id() << " )";
+ if (out_edge->IsControlEdge()) {
+ VLOG(2) << "... ... Control Edge, Skipping";
+ continue;
+ }
// Out node must be TRT candidate...
if (node_segments[out_edge->dst()->id()].Value() == nullptr) {
VLOG(2) << "... ... not a TRT candidate";
@@ -196,7 +214,8 @@ tensorflow::Status SegmentGraph(
const tensorflow::Node* src = contract_edge->src();
const tensorflow::Node* dst = contract_edge->dst();
- VLOG(2) << "Merge " << src->name() << " <- " << dst->name();
+ VLOG(2) << "Merge " << src->name() << " <- " << dst->name() << " ("
+ << src->id() << " <- " << dst->id();
node_segments[src->id()].Merge(&node_segments[dst->id()]);
// Contracting the edge leaves disconnected graph edges.