diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert/convert_graph.cc')
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_graph.cc | 66 |
1 files changed, 39 insertions, 27 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index b7b26cfb1c..da4dd5a14c 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -91,8 +91,11 @@ void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, if (!subgraph_node_ids.count(edge->src()->id()) && !edge->src()->IsSource() && !edge->IsControlEdge()) { incoming_edges->insert(edge); + VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name() + << " Y, "; } else { - VLOG(2) << node->name() << " -> " << edge->src()->name() << " N, "; + VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name() + << " N, "; } } } @@ -106,10 +109,12 @@ void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph, for (const tensorflow::Edge* edge : node->out_edges()) { if (!subgraph_node_ids.count(edge->dst()->id()) && !edge->dst()->IsSink() && !edge->IsControlEdge()) { - VLOG(2) << node->name() << " -> " << edge->dst()->name() << " Y, "; + VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name() + << " Y, "; outgoing_edges->insert(edge); } else { - VLOG(2) << node->name() << " -> " << edge->dst()->name() << " N, "; + VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name() + << " N, "; } } } @@ -181,29 +186,27 @@ struct ConvertGraphParams { static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) { GetSubGraphIncomingEdges(p->graph, p->subgraph_node_ids, &p->subgraph_incoming_edges); + + std::set<std::pair<int, int>> unique_tensors; + // Add only unique input source nodes. If output of an outside node is shared + // between multiple nodes inside the engine, only one edge should be created for (const tensorflow::Edge* edge : p->subgraph_incoming_edges) { - p->subgraph_inputs.push_back({edge->src()->id(), edge->src_output()}); - } - auto output_name_to_index_map = BuildTensorNameMap(p->output_names); - std::set<std::pair<int, int>> subgraph_outputs_set; - // Collect outputs referenced from output_names - for (int node_id : p->subgraph_node_ids) { - tensorflow::Node* node = p->graph.FindNodeId(node_id); - if (output_name_to_index_map.count(node->name())) { - for (int index : output_name_to_index_map.at(node->name())) { - subgraph_outputs_set.insert({node_id, index}); - } - } + unique_tensors.insert({edge->src()->id(), edge->src_output()}); } + p->subgraph_inputs.insert(p->subgraph_inputs.begin(), unique_tensors.begin(), + unique_tensors.end()); GetSubGraphOutgoingEdges(p->graph, p->subgraph_node_ids, &p->subgraph_outgoing_edges); + unique_tensors.clear(); + // Similar to above, if multiple ouside nodes are sharing the output of an + // internal node only one output port should be created and shared between + // outputs for (const tensorflow::Edge* edge : p->subgraph_outgoing_edges) { - subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()}); + unique_tensors.insert({edge->src()->id(), edge->src_output()}); } - p->subgraph_outputs.reserve(subgraph_outputs_set.size()); + p->subgraph_outputs.reserve(unique_tensors.size()); p->subgraph_outputs.insert(p->subgraph_outputs.begin(), - subgraph_outputs_set.begin(), - subgraph_outputs_set.end()); + unique_tensors.begin(), unique_tensors.end()); return tensorflow::Status::OK(); } @@ -225,7 +228,6 @@ tensorflow::Status GetCalibNode(ConvertGraphParams* params) { for (auto in_edge : params->subgraph_incoming_edges) { // loop over incoming edges and // attach them to calib node - // tensorflow::Node* src_node = in_edge->src(); auto src_output = in_edge->src_output(); auto dst_node = in_edge->dst(); auto dst_input = in_edge->dst_input(); @@ -257,19 +259,24 @@ tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) { for (size_t i = 0; i < params->subgraph_inputs.size(); ++i) { subgraph_edge_to_input_map.insert({params->subgraph_inputs.at(i), i}); } + std::set<std::pair<int, int>> unique_tensors; for (const tensorflow::Edge* edge : params->subgraph_incoming_edges) { std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()}; + if (unique_tensors.count(old_src)) continue; + unique_tensors.insert(old_src); int new_src_output = subgraph_edge_to_input_map.at(old_src); params->graph.AddEdge(edge->src(), edge->src_output(), trt_node, new_src_output); + VLOG(1) << "Wire " << edge->src()->name() << ":" << edge->src_output() + << " -> " << trt_node->name() << ":" << new_src_output; params->graph.RemoveEdge(edge); } - - VLOG(2) << "new wiring edges: " << trt_node->in_edges().size(); - for (const tensorflow::Edge* edge : trt_node->in_edges()) { - VLOG(2) << edge->src()->name() << " port: " << edge->src_output(); + if (VLOG_IS_ON(2)) { + VLOG(2) << "new edge count: " << trt_node->in_edges().size(); + for (const tensorflow::Edge* edge : trt_node->in_edges()) { + VLOG(2) << edge->src()->name() << " port: " << edge->src_output(); + } } - TF_RETURN_IF_ERROR(status); // Re-map outgoing edges to use the new TRT node instead of the orig subgraph @@ -283,6 +290,8 @@ tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) { int new_src_output = subgraph_edge_to_output_map.at(old_src); TF_RETURN_IF_ERROR(params->graph.UpdateEdge( trt_node, new_src_output, edge->dst(), edge->dst_input())); + VLOG(1) << "Wire " << trt_node->name() << ":" << new_src_output << " -> " + << edge->dst()->name() << ":" << edge->dst_input(); } // Remove the original subgraph for (int node_id : params->subgraph_node_ids) { @@ -317,9 +326,12 @@ tensorflow::Status ConvertCalibGraphToInferGraph( tensorflow::GraphConstructorOptions(), graph_def, &graph)); // get calib nodes std::vector<tensorflow::Node*> calib_nodes; - for (auto node : graph.op_nodes()) { + std::vector<tensorflow::Node*> topo_order; + tensorflow::GetPostOrder(graph, &topo_order); + for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) { + auto node = *rit; if (node->type_string() == "TRTCalibOp") { - VLOG(1) << "Found Calib Node"; + VLOG(1) << "Found Calib Node " << node->name(); calib_nodes.push_back(node); } } |