about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/tensorrt/convert/convert_graph.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert/convert_graph.cc')
-rw-r--r-- tensorflow/contrib/tensorrt/convert/convert_graph.cc 66
1 files changed, 39 insertions, 27 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index b7b26cfb1c..da4dd5a14c 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -91,8 +91,11 @@ void GetSubGraphIncomingEdges(const tensorflow::Graph& graph,
if (!subgraph_node_ids.count(edge->src()->id()) &&
!edge->src()->IsSource() && !edge->IsControlEdge()) {
incoming_edges->insert(edge);
+ VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name()
+ << " Y, ";
} else {
- VLOG(2) << node->name() << " -> " << edge->src()->name() << " N, ";
+ VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name()
+ << " N, ";
}
}
}
@@ -106,10 +109,12 @@ void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph,
for (const tensorflow::Edge* edge : node->out_edges()) {
if (!subgraph_node_ids.count(edge->dst()->id()) &&
!edge->dst()->IsSink() && !edge->IsControlEdge()) {
- VLOG(2) << node->name() << " -> " << edge->dst()->name() << " Y, ";
+ VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name()
+ << " Y, ";
outgoing_edges->insert(edge);
} else {
- VLOG(2) << node->name() << " -> " << edge->dst()->name() << " N, ";
+ VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name()
+ << " N, ";
}
}
}
@@ -181,29 +186,27 @@ struct ConvertGraphParams {
static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) {
GetSubGraphIncomingEdges(p->graph, p->subgraph_node_ids,
&p->subgraph_incoming_edges);
+
+ std::set<std::pair<int, int>> unique_tensors;
+ // Add only unique input source nodes. If the output of an outside node is
+ // shared between multiple nodes inside the engine, only one edge is created.
for (const tensorflow::Edge* edge : p->subgraph_incoming_edges) {
- p->subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
- }
- auto output_name_to_index_map = BuildTensorNameMap(p->output_names);
- std::set<std::pair<int, int>> subgraph_outputs_set;
- // Collect outputs referenced from output_names
- for (int node_id : p->subgraph_node_ids) {
- tensorflow::Node* node = p->graph.FindNodeId(node_id);
- if (output_name_to_index_map.count(node->name())) {
- for (int index : output_name_to_index_map.at(node->name())) {
- subgraph_outputs_set.insert({node_id, index});
- }
- }
+ unique_tensors.insert({edge->src()->id(), edge->src_output()});
}
+ p->subgraph_inputs.insert(p->subgraph_inputs.begin(), unique_tensors.begin(),
+ unique_tensors.end());
GetSubGraphOutgoingEdges(p->graph, p->subgraph_node_ids,
&p->subgraph_outgoing_edges);
+ unique_tensors.clear();
+ // Similar to above, if multiple outside nodes are sharing the output of an
+ // internal node only one output port should be created and shared between
+ // outputs
for (const tensorflow::Edge* edge : p->subgraph_outgoing_edges) {
- subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
+ unique_tensors.insert({edge->src()->id(), edge->src_output()});
}
- p->subgraph_outputs.reserve(subgraph_outputs_set.size());
+ p->subgraph_outputs.reserve(unique_tensors.size());
p->subgraph_outputs.insert(p->subgraph_outputs.begin(),
- subgraph_outputs_set.begin(),
- subgraph_outputs_set.end());
+ unique_tensors.begin(), unique_tensors.end());
return tensorflow::Status::OK();
}
@@ -225,7 +228,6 @@ tensorflow::Status GetCalibNode(ConvertGraphParams* params) {
for (auto in_edge :
params->subgraph_incoming_edges) { // loop over incoming edges and
// attach them to calib node
- // tensorflow::Node* src_node = in_edge->src();
auto src_output = in_edge->src_output();
auto dst_node = in_edge->dst();
auto dst_input = in_edge->dst_input();
@@ -257,19 +259,24 @@ tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) {
for (size_t i = 0; i < params->subgraph_inputs.size(); ++i) {
subgraph_edge_to_input_map.insert({params->subgraph_inputs.at(i), i});
}
+ std::set<std::pair<int, int>> unique_tensors;
for (const tensorflow::Edge* edge : params->subgraph_incoming_edges) {
std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
+ if (unique_tensors.count(old_src)) continue;
+ unique_tensors.insert(old_src);
int new_src_output = subgraph_edge_to_input_map.at(old_src);
params->graph.AddEdge(edge->src(), edge->src_output(), trt_node,
new_src_output);
+ VLOG(1) << "Wire " << edge->src()->name() << ":" << edge->src_output()
+ << " -> " << trt_node->name() << ":" << new_src_output;
params->graph.RemoveEdge(edge);
}
-
- VLOG(2) << "new wiring edges: " << trt_node->in_edges().size();
- for (const tensorflow::Edge* edge : trt_node->in_edges()) {
- VLOG(2) << edge->src()->name() << " port: " << edge->src_output();
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "new edge count: " << trt_node->in_edges().size();
+ for (const tensorflow::Edge* edge : trt_node->in_edges()) {
+ VLOG(2) << edge->src()->name() << " port: " << edge->src_output();
+ }
}
-
TF_RETURN_IF_ERROR(status);
// Re-map outgoing edges to use the new TRT node instead of the orig subgraph
@@ -283,6 +290,8 @@ tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) {
int new_src_output = subgraph_edge_to_output_map.at(old_src);
TF_RETURN_IF_ERROR(params->graph.UpdateEdge(
trt_node, new_src_output, edge->dst(), edge->dst_input()));
+ VLOG(1) << "Wire " << trt_node->name() << ":" << new_src_output << " -> "
+ << edge->dst()->name() << ":" << edge->dst_input();
}
// Remove the original subgraph
for (int node_id : params->subgraph_node_ids) {
@@ -317,9 +326,12 @@ tensorflow::Status ConvertCalibGraphToInferGraph(
tensorflow::GraphConstructorOptions(), graph_def, &graph));
// get calib nodes
std::vector<tensorflow::Node*> calib_nodes;
- for (auto node : graph.op_nodes()) {
+ std::vector<tensorflow::Node*> topo_order;
+ tensorflow::GetPostOrder(graph, &topo_order);
+ for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) {
+ auto node = *rit;
if (node->type_string() == "TRTCalibOp") {
- VLOG(1) << "Found Calib Node";
+ VLOG(1) << "Found Calib Node " << node->name();
calib_nodes.push_back(node);
}
}