aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/segment/segment.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/segment/segment.cc')
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.cc55
1 files changed, 18 insertions, 37 deletions
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
index 8fc4697c51..6193f0b0a1 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -80,20 +80,13 @@ void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
std::vector<const tensorflow::Edge*> in_edges(dst->in_edges().begin(),
dst->in_edges().end());
for (const tensorflow::Edge* in_edge : in_edges) {
- if (in_edge->IsControlEdge()) {
- if (in_edge->src() != src) {
- tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge);
- graph->AddControlEdge(e->src(), src);
- }
- } else {
- if (in_edge->src() != src) {
- tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge);
- if (e->src() == graph->source_node()) {
- graph->AddEdge(e->src(), e->src_output(), src,
- tensorflow::Graph::kControlSlot);
- } else {
- graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */);
- }
+ if (in_edge->src() != src) {
+ tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge);
+ if (e->src() == graph->source_node()) {
+ graph->AddEdge(e->src(), e->src_output(), src,
+ tensorflow::Graph::kControlSlot);
+ } else {
+ graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */);
}
}
}
@@ -101,19 +94,12 @@ void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
std::vector<const tensorflow::Edge*> out_edges(dst->out_edges().begin(),
dst->out_edges().end());
for (const tensorflow::Edge* out_edge : out_edges) {
- if (out_edge->IsControlEdge()) {
- tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge);
- graph->AddControlEdge(src, e->dst());
+ tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge);
+ if (e->dst() == graph->sink_node()) {
+ graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(),
+ e->dst_input());
} else {
- tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge);
- if (e->dst() == graph->sink_node()) {
- VLOG(1) << " edge to sink node " << src->name() << " -> "
- << e->dst()->name();
- graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(),
- e->dst_input());
- } else {
- graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input());
- }
+ graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input());
}
}
@@ -132,7 +118,7 @@ void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
tensorflow::Status SegmentGraph(
const tensorflow::GraphDef& gdef,
- const std::function<bool(const tensorflow::Node*)>& candidate_fn,
+ const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn,
const SegmentOptions& options, SegmentNodesVector* segments) {
// Create a Graph representation of the GraphDef.
tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
@@ -150,7 +136,7 @@ tensorflow::Status SegmentGraph(
for (int i = 0; i < graph.num_node_ids(); ++i) {
tensorflow::Node* node = graph.FindNodeId(i);
if (options.exclude_node_list.count(node->name()) != 0 ||
- !candidate_fn(node)) {
+ !candidate_fn(node->def())) {
node = nullptr;
}
node_segments.emplace_back(node);
@@ -169,7 +155,7 @@ tensorflow::Status SegmentGraph(
for (const tensorflow::Node* node : order) {
// All output nodes of 'node' have been visited...
- VLOG(2) << "Trying node " << node->name() << " id=" << node->id();
+ VLOG(2) << "Trying node " << node->name();
// 'node' must be a TRT candidate...
if (node_segments[node->id()].Value() == nullptr) {
@@ -183,12 +169,8 @@ tensorflow::Status SegmentGraph(
while (true) {
std::set<const tensorflow::Edge*> contract_edges;
for (const tensorflow::Edge* out_edge : node->out_edges()) {
- VLOG(2) << "... out node " << out_edge->dst()->name() << " ( "
- << out_edge->dst()->id() << " <- " << node->id() << " )";
- if (out_edge->IsControlEdge()) {
- VLOG(2) << "... ... Control Edge, Skipping";
- continue;
- }
+ VLOG(2) << "... out node " << out_edge->dst()->name();
+
// Out node must be TRT candidate...
if (node_segments[out_edge->dst()->id()].Value() == nullptr) {
VLOG(2) << "... ... not a TRT candidate";
@@ -214,8 +196,7 @@ tensorflow::Status SegmentGraph(
const tensorflow::Node* src = contract_edge->src();
const tensorflow::Node* dst = contract_edge->dst();
- VLOG(2) << "Merge " << src->name() << " <- " << dst->name() << " ("
- << src->id() << " <- " << dst->id();
+ VLOG(2) << "Merge " << src->name() << " <- " << dst->name();
node_segments[src->id()].Merge(&node_segments[dst->id()]);
// Contracting the edge leaves disconnected graph edges.