diff options
author | 2018-03-28 16:52:39 -0700 | |
---|---|---|
committer | 2018-03-28 16:55:15 -0700 | |
commit | 108178da2a20ea2d3899417ee932d46ba1a5c652 (patch) | |
tree | 313bd8cec176f8c9ef67b25c6484a650d1f2092a /tensorflow/contrib/tensorrt/segment | |
parent | 390e19ab990f5656e09d98624c92b3c80e52937d (diff) |
Automated g4 rollback of changelist 190835392
PiperOrigin-RevId: 190858242
Diffstat (limited to 'tensorflow/contrib/tensorrt/segment')
-rw-r--r-- | tensorflow/contrib/tensorrt/segment/segment.cc | 55 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/segment/segment.h | 4 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/segment/segment_test.cc | 8 |
3 files changed, 23 insertions, 44 deletions
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index 8fc4697c51..6193f0b0a1 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -80,20 +80,13 @@ void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph, std::vector<const tensorflow::Edge*> in_edges(dst->in_edges().begin(), dst->in_edges().end()); for (const tensorflow::Edge* in_edge : in_edges) { - if (in_edge->IsControlEdge()) { - if (in_edge->src() != src) { - tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge); - graph->AddControlEdge(e->src(), src); - } - } else { - if (in_edge->src() != src) { - tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge); - if (e->src() == graph->source_node()) { - graph->AddEdge(e->src(), e->src_output(), src, - tensorflow::Graph::kControlSlot); - } else { - graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */); - } + if (in_edge->src() != src) { + tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge); + if (e->src() == graph->source_node()) { + graph->AddEdge(e->src(), e->src_output(), src, + tensorflow::Graph::kControlSlot); + } else { + graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */); } } } @@ -101,19 +94,12 @@ void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph, std::vector<const tensorflow::Edge*> out_edges(dst->out_edges().begin(), dst->out_edges().end()); for (const tensorflow::Edge* out_edge : out_edges) { - if (out_edge->IsControlEdge()) { - tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge); - graph->AddControlEdge(src, e->dst()); + tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge); + if (e->dst() == graph->sink_node()) { + graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(), + e->dst_input()); } else { - tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge); - if (e->dst() == graph->sink_node()) { - VLOG(1) << " edge to sink node " << src->name() << " -> " - << e->dst()->name(); - graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(), - e->dst_input()); - } else { - graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input()); - } + graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input()); } } @@ -132,7 +118,7 @@ void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph, tensorflow::Status SegmentGraph( const tensorflow::GraphDef& gdef, - const std::function<bool(const tensorflow::Node*)>& candidate_fn, + const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments) { // Create a Graph representation of the GraphDef. tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), @@ -150,7 +136,7 @@ tensorflow::Status SegmentGraph( for (int i = 0; i < graph.num_node_ids(); ++i) { tensorflow::Node* node = graph.FindNodeId(i); if (options.exclude_node_list.count(node->name()) != 0 || - !candidate_fn(node)) { + !candidate_fn(node->def())) { node = nullptr; } node_segments.emplace_back(node); @@ -169,7 +155,7 @@ tensorflow::Status SegmentGraph( for (const tensorflow::Node* node : order) { // All output nodes of 'node' have been visited... - VLOG(2) << "Trying node " << node->name() << " id=" << node->id(); + VLOG(2) << "Trying node " << node->name(); // 'node' must be a TRT candidate... if (node_segments[node->id()].Value() == nullptr) { @@ -183,12 +169,8 @@ tensorflow::Status SegmentGraph( while (true) { std::set<const tensorflow::Edge*> contract_edges; for (const tensorflow::Edge* out_edge : node->out_edges()) { - VLOG(2) << "... out node " << out_edge->dst()->name() << " ( " - << out_edge->dst()->id() << " <- " << node->id() << " )"; - if (out_edge->IsControlEdge()) { - VLOG(2) << "... ... Control Edge, Skipping"; - continue; - } + VLOG(2) << "... out node " << out_edge->dst()->name(); + // Out node must be TRT candidate... if (node_segments[out_edge->dst()->id()].Value() == nullptr) { VLOG(2) << "... ... not a TRT candidate"; @@ -214,8 +196,7 @@ tensorflow::Status SegmentGraph( const tensorflow::Node* src = contract_edge->src(); const tensorflow::Node* dst = contract_edge->dst(); - VLOG(2) << "Merge " << src->name() << " <- " << dst->name() << " (" - << src->id() << " <- " << dst->id(); + VLOG(2) << "Merge " << src->name() << " <- " << dst->name(); node_segments[src->id()].Merge(&node_segments[dst->id()]); // Contracting the edge leaves disconnected graph edges. diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h index 7e8685f44a..ee6e2b3ed2 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/contrib/tensorrt/segment/segment.h @@ -20,12 +20,10 @@ limitations under the License. #include <vector> #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { - namespace tensorrt { namespace segment { @@ -48,7 +46,7 @@ struct SegmentOptions { // @return the status. tensorflow::Status SegmentGraph( const tensorflow::GraphDef& gdef, - const std::function<bool(const tensorflow::Node*)>& candidate_fn, + const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments); } // namespace segment diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index 7ddabec268..74cbc5f2b3 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -35,7 +35,7 @@ class SegmentTest : public ::testing::Test { TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name); - std::function<bool(const Node*)> MakeCandidateFn( + std::function<bool(const NodeDef&)> MakeCandidateFn( const std::set<string>& node_names); protected: @@ -60,10 +60,10 @@ bool SegmentTest::GetGraphDef(TF_Graph* graph, return ret; } -std::function<bool(const Node*)> SegmentTest::MakeCandidateFn( +std::function<bool(const NodeDef&)> SegmentTest::MakeCandidateFn( const std::set<string>& node_names) { - return [node_names](const Node* node) -> bool { - return node_names.find(node->name()) != node_names.end(); + return [node_names](const NodeDef& node) -> bool { + return node_names.find(node.name()) != node_names.end(); }; } |