diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert/convert_graph.cc')
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_graph.cc | 385 |
1 files changed, 243 insertions, 142 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 896968647e..21ec8b0b30 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -20,6 +20,7 @@ limitations under the License. #include <map> #include <set> #include <unordered_map> +#include <unordered_set> #include <utility> #include <vector> @@ -29,6 +30,7 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/contrib/tensorrt/segment/segment.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -285,11 +287,10 @@ tensorflow::Status GetEngineInfo( const std::unordered_map<string, tensorflow::Node*>& node_map, const std::vector<tensorflow::Node*>& reverse_topo_order, EngineInfo* info) { - std::vector<int> subgraph_node_ids; + std::vector<int> subgraph_node_ids; // Topologically sorted node ids. + std::set<string> subgraph_node_names = segment_nodes; std::set<int> added_const_node_ids; // Used to prevent double insertion. std::set<string> segment_devices; - int input_port = 0; - int output_port = 0; // Map from src_node_name+port to the unique port numbers of the TRT op, where // the src_node_name is the name of the source node of the input/output @@ -297,13 +298,12 @@ tensorflow::Status GetEngineInfo( // input/output edges must be in different split of the graph. // TODO(aaroey): consider using node id and port instead. // TODO(aaroey): using topo order instead of reverting reverse topo order. - std::unordered_map<string, int> created_edges; + std::unordered_map<string, int> input_to_engine_port, output_to_engine_port; for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend(); ++it) { const auto& node_name = (*it)->name(); - if (segment_nodes.count(node_name) == 0) continue; - auto node = node_map.at(node_name); + auto node = *it; auto node_device = node->requested_device(); if (!node_device.empty()) { segment_devices.insert(node_device); @@ -316,64 +316,93 @@ tensorflow::Status GetEngineInfo( } } const int node_id = node->id(); + subgraph_node_ids.push_back(node_id); + // Create input connections. for (const auto edge : node->in_edges()) { auto input_node = edge->src(); - if (segment_nodes.count(input_node->name()) == 0 && - !edge->IsControlEdge() && !input_node->IsSource()) { - // Add constant input node into the segment. We don't care if it has - // other output edges going into other engines or TF nodes. Since we add - // it only to the subsegment node list, not the subsegment itself, it - // won't be removed from the graph. If it doesn't have any edges, TF - // will prune it out. - if (input_node->type_string() == "Const") { - if (added_const_node_ids.count(input_node->id()) == 0) { - added_const_node_ids.insert(input_node->id()); - subgraph_node_ids.push_back(input_node->id()); - } + if (input_node->IsSource() || segment_nodes.count(input_node->name())) { + continue; + } + if (edge->IsControlEdge()) { + // Control input. + info->connections.emplace_back(input_node->name(), input_node->id(), + node_name, node_id, + /*input_edge=*/true); + } else if (input_node->type_string() == "Const") { + // Add constant data input nodes into the segment graphdef (thus also in + // the engine). We don't care if it has other output edges going into + // other engines or TF nodes. Since we add it only to the segment + // graphdef, not the segment itself, it won't be removed from the graph. + // If it doesn't have any edges, TF will prune it out. + // + // Note that the segmenter already ensure that the constant data input + // is valid and suppported by the engine. + if (!added_const_node_ids.insert(input_node->id()).second) { + // Already added before. + continue; + } + VLOG(1) << "Adding const node " << input_node->name(); + QCHECK(subgraph_node_names.insert(input_node->name()).second); + // Since we already add (duplicate) the const input node to the segment + // graphdef, it's now not a data dependency any more, but to make the + // dependency correct we still add a control dependency. + info->connections.emplace_back(input_node->name(), input_node->id(), + node_name, node_id, + /*input_edge=*/true); + } else { + // Non-const data input. + int port = Graph::kControlSlot - 1; + // Use the source non-segment node name/port as key. + const string s = StrCat(input_node->name(), ":", edge->src_output()); + VLOG(1) << "Input edge = " << s; + if (input_to_engine_port.count(s)) { + port = input_to_engine_port.at(s); } else { - string s(input_node->name()); - StrAppend(&s, ":", edge->src_output()); - VLOG(1) << "Input edge = " << s; - int port = input_port; - if (created_edges.count(s)) { - port = created_edges.at(s); - } else { - created_edges.insert({s, port}); - input_port++; - } - info->connections.emplace_back(input_node->name(), input_node->id(), - edge->src_output(), node_name, node_id, - edge->dst_input(), true, port); + port = input_to_engine_port.size(); + input_to_engine_port.insert({s, port}); } + info->connections.emplace_back( + input_node->name(), input_node->id(), edge->src_output(), node_name, + node_id, edge->dst_input(), /*input_edge=*/true, port); } } - // We need to add possible const input nodes before adding this node in - // order to keep the topological order. - subgraph_node_ids.push_back(node_id); + // Create output connections. for (const auto edge : node->out_edges()) { auto output_node = edge->dst(); - if (segment_nodes.count(output_node->name()) == 0 && - !edge->IsControlEdge() && !output_node->IsSink()) { - string s(node_name); - StrAppend(&s, ":", edge->src_output()); + if (output_node->IsSink() || segment_nodes.count(output_node->name())) { + continue; + } + if (edge->IsControlEdge()) { + // Control output. + info->connections.emplace_back(output_node->name(), output_node->id(), + node_name, node_id, + /*input_edge=*/false); + } else { + // Data output. + int port = Graph::kControlSlot - 1; + // Use the source segment node name/port as key. + const string s = StrCat(node_name, ":", edge->src_output()); VLOG(1) << "Output edge = " << s; - int port = output_port; - if (created_edges.count(s)) { - port = created_edges.at(s); + if (output_to_engine_port.count(s)) { + port = output_to_engine_port.at(s); } else { - created_edges.insert({s, port}); - output_port++; + port = output_to_engine_port.size(); + output_to_engine_port.insert({s, port}); } - info->connections.emplace_back(output_node->name(), output_node->id(), - edge->dst_input(), node_name, node_id, - edge->src_output(), false, port); + info->connections.emplace_back( + output_node->name(), output_node->id(), edge->dst_input(), + node_name, node_id, edge->src_output(), /*input_edge=*/false, port); } } - } + } // For each segment node in topological order. + // Construct the const nodes first. + subgraph_node_ids.insert(subgraph_node_ids.begin(), + added_const_node_ids.begin(), + added_const_node_ids.end()); TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef( - g, graph_properties, subgraph_node_ids, &info->connections, - &info->segment_graph_def, &info->engine_name)); + g, graph_properties, subgraph_node_names, subgraph_node_ids, + &info->connections, &info->segment_graph_def, &info->engine_name)); // TODO(sami): This should not happen once segmenter is updated. if (segment_devices.size() == 1) { info->device = *segment_devices.begin(); @@ -383,94 +412,137 @@ tensorflow::Status GetEngineInfo( << "but this shouldn't have happened"; info->device = *segment_devices.begin(); } else { - VLOG(1) << "Segment devices size is 0"; + LOG(ERROR) << "Can't find a device placement for the op!"; } return Status::OK(); } -// Function to insert a TRT node into the graph. The graph is not modified if -// the returned status is not ok. -// 'alloc' is only used for creating static engine. -tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, - const std::vector<EngineInfo>& infos, int pos, +// Helper function to update edge connection from the removed node to the +// engine node. If an outside node is gone, it must have been absorbed into +// an engine node. Find the engine node. +void UpdateToEngineNode(const std::vector<EngineInfo>& infos, + const size_t my_engine_id, + const std::vector<Node*>& engine_nodes, + const bool is_input_edge, const string& node_name, + tensorflow::Node** node, int* port) { + for (size_t t = 0; t < infos.size(); ++t) { + if (t == my_engine_id) { + continue; + } + const auto& info = infos.at(t); + for (const auto& eng_conn : info.connections) { + // If the connection being updated is an input connection, the source of + // the connection must be an output connection of another engine. And vise + // versa. + if (is_input_edge == eng_conn.is_input_edge) continue; + if (eng_conn.inside_node_name == node_name && + eng_conn.inside_port == *port) { + *node = CHECK_NOTNULL(engine_nodes[t]); + QCHECK_EQ(info.engine_name, (**node).name()) + << "Engine name mismatch: " << info.engine_name << " vs " + << (**node).name(); + *port = eng_conn.port_number; + return; + } + } + } + LOG(FATAL) << "Node " << (**node).name() << " not found in any engine."; +} + +// Function to insert a TRT engine node into the graph. +// Create engine nodes in the following way: +// 1. Each invocation of CreateTRTNode creates an engine node for infos[pos] +// 2. When an engine node is created, add it into the graph with necessary +// re-wiring. +// 2.1. If the outside connected node is existing, connect the engine +// node to it. +// 2.2. If the outside connected node is gone, it must have been absorted +// into another engine node (which was processed before the processing +// one). Connect to the pre-existing engine node instead. +// 3. In this way, we ensure the graph is topologically sort-able after each +// invocation of CreateTRTNode(). +tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos, + int max_batch_size, tensorflow::Graph* graph, nvinfer1::IGpuAllocator* alloc, - int max_batch_size) { + std::vector<Node*>* engine_nodes) { const auto& info = infos.at(pos); + TRT_RETURN_IF_TEST_VALUE(StrCat(info.engine_name, ":CreateTRTNode"), "fail"); std::vector<tensorflow::TensorShapeProto> output_shape_protos; std::vector<tensorflow::TensorShapeProto> input_shape_protos; std::vector<tensorflow::PartialTensorShape> input_shapes; std::vector<tensorflow::NodeDefBuilder::NodeOut> inputs; + std::vector<tensorflow::Node*> input_nodes; + std::vector<tensorflow::Node*> control_input_nodes; + std::unordered_set<string> control_input_names; std::vector<tensorflow::DataType> out_types; - VLOG(1) << "Processing " << info.engine_name; - // Update the shape and data types of input/output nodes, and find all unique - // inputs. + VLOG(1) << "Processing " << info.engine_name; + // Collect needed info for creating the engine node in the graph for (const auto& conn : info.connections) { - if (!conn.is_input_edge) { - // Set the shapes and data types of output edge. - tensorflow::TensorShapeProto out_shape; - // shape of the output node inside segment - conn.inside_shape.AsProto(&out_shape); - if (output_shape_protos.size() <= conn.port_number) { - output_shape_protos.resize(conn.port_number + 1); - out_types.resize(conn.port_number + 1); + // Control edges + if (conn.is_control_edge()) { + // Skip control outputs for now. control output info are not needed for + // node creation and will be processed later. + if (!conn.is_input_edge) continue; + + // Rewrire control input if it's not found in original graph. + tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); + int port = tensorflow::Graph::kControlSlot; + if (!input_node) { + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true, + conn.outside_node_name, &input_node, &port); + QCHECK_EQ(Graph::kControlSlot, port); } - output_shape_protos.at(conn.port_number) = out_shape; - out_types.at(conn.port_number) = conn.connection_type; - continue; - } - - // Set the shapes and data types of input edge. - tensorflow::TensorShapeProto in_shape; - conn.outside_shape.AsProto(&in_shape); - if (input_shape_protos.size() <= conn.port_number) { - input_shape_protos.resize(conn.port_number + 1); - input_shapes.resize(conn.port_number + 1); - } - input_shape_protos.at(conn.port_number) = in_shape; - input_shapes.at(conn.port_number) = conn.outside_shape; - - string input_node = conn.outside_node_name; - int input_port = conn.outside_port; - bool found_engine = false; - // Rewire the inputs to other engines if they contain original input node. - // Note that we use the information of the engine here, not the information - // of the created TRT nodes, so we're able to find all the connections to - // any other engines beforehand. - for (size_t t = 0; t < infos.size(); ++t) { - if (t == pos) continue; - auto& engine_info = infos.at(t); - for (const auto& eng_conn : engine_info.connections) { - if (eng_conn.is_input_edge) continue; - if (eng_conn.inside_node_name == input_node) { - input_node = engine_info.engine_name; - if (eng_conn.inside_port == input_port) { - input_port = eng_conn.port_number; - found_engine = true; - break; - } - } + if (!control_input_names.insert(input_node->name()).second) { + continue; } - if (found_engine) break; - } - VLOG(1) << "Engine Input " << input_node << ":" << input_port << " -> " - << info.engine_name << ":" << inputs.size(); - // Skip duplicate inputs. - // TODO(aaroey): use std::find instead. GetEngineInfo already remove - // duplicate connections, so here we should never find any duplicate? - bool new_input = true; - for (const auto& inp : inputs) { - if (inp.node == input_node && inp.index == input_port) { - new_input = false; - break; + control_input_nodes.push_back(input_node); + VLOG(1) << "Engine Control Input " << input_node->name() << " -> " + << info.engine_name; + } else { + // Data edges + if (!conn.is_input_edge) { + // Set the shapes and data types of output edge. + tensorflow::TensorShapeProto out_shape; + // shape of the output node inside segment + conn.inside_shape.AsProto(&out_shape); + if (output_shape_protos.size() <= conn.port_number) { + output_shape_protos.resize(conn.port_number + 1); + out_types.resize(conn.port_number + 1); + } + output_shape_protos.at(conn.port_number) = out_shape; + out_types.at(conn.port_number) = conn.connection_type; + } else { + // Set the shapes and data types of input edge. + tensorflow::TensorShapeProto in_shape; + conn.outside_shape.AsProto(&in_shape); + if (input_shape_protos.size() <= conn.port_number) { + input_shape_protos.resize(conn.port_number + 1); + input_shapes.resize(conn.port_number + 1); + } + input_shape_protos.at(conn.port_number) = in_shape; + input_shapes.at(conn.port_number) = conn.outside_shape; + + // Rewrire data input if it's not found in original graph. + tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); + int port = conn.outside_port; + if (!input_node) { + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true, + conn.outside_node_name, &input_node, &port); + } + if (std::find_if( + std::begin(inputs), std::end(inputs), + [input_node, &port](const NodeDefBuilder::NodeOut& inp) { + return inp.node == input_node->name() && inp.index == port; + }) == std::end(inputs)) { + inputs.emplace_back(input_node->name(), port, conn.connection_type); + input_nodes.push_back(CHECK_NOTNULL(input_node)); + VLOG(1) << "Engine Input " << input_node->name() << ":" << port + << " -> " << info.engine_name << ":" << inputs.size() - 1; + } } } - if (new_input) { - inputs.emplace_back(input_node, input_port, conn.connection_type); - } } - - // Build the engine and get its serialized representation. string segment_string; if (info.engine_type == EngineInfo::EngineType::TRTStatic || info.precision_mode == INT8MODE) { @@ -517,6 +589,10 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, VLOG(1) << ins; } node_builder.Input(inputs); + for (const string& c : control_input_names) { + node_builder.ControlInput(c); + } + if (info.engine_type == EngineInfo::EngineType::TRTStatic && info.cached_engine_batches.size()) { LOG(WARNING) << "Cached engine batches are ignored for static engines"; @@ -545,34 +621,55 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, // Up until this point, graph is not modified. If we return !status.ok() from // here, this segment will be skipped + // TODO(aaroey): let it return proper error status for the following logic + // instead of checking fail. tensorflow::Node* engine_node = graph->AddNode(trt_node, &status); + (*engine_nodes)[pos] = engine_node; if (!status.ok()) { LOG(ERROR) << "Adding node failed " << status; return status; } + // Add control input and input edges to the engine node. + for (const auto in : control_input_nodes) { + VLOG(1) << "Connecting control edge from " << in->name() << " to " + << engine_node->name(); + graph->AddControlEdge(in, engine_node); + } + VLOG(1) << "input_nodes size = " << input_nodes.size(); + for (int i = 0; i < input_nodes.size(); ++i) { + Node* n = CHECK_NOTNULL(input_nodes[i]); + const auto& in = inputs[i]; + VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index + << " to " << engine_node->name() << ":" << i; + graph->AddEdge(n, in.index, engine_node, i); + } + // Updates the inputs of output edges destination nodes, and point them to the // engine node. for (auto& conn : info.connections) { - if (conn.is_input_edge) continue; - VLOG(1) << " Updating DBG " << engine_node->name() << " out_port " - << conn.port_number << " out_id " << conn.outside_id - << " name=" << conn.outside_node_name; - auto dst_node = graph->FindNodeId(conn.outside_id); - // dst_node can only be removed if it is an input node of another engine. - // In this case, other engines input edge is updated in nodedef to point to - // this engine. Even though edge doesn't exists in the graph, when it is - // deserialized again, correct edges will be constructed. This is a problem - // of graph->AddNode(). - if (!dst_node) continue; + if (conn.is_input_edge) { + continue; + } + tensorflow::Node* output_node = graph->FindNodeId(conn.outside_id); + int port = conn.outside_port; + if (!output_node) { + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false, + conn.outside_node_name, &output_node, &port); + } VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number - << " to " << dst_node->name() << ":" << conn.outside_port; - auto new_edge = graph->AddEdge(engine_node, conn.port_number, dst_node, - conn.outside_port); - CHECK(new_edge) << "Adding a new edge failed " << engine_node->name() << ":" - << conn.port_number << " -> " << dst_node->name() << ":" - << conn.outside_port; + << " to " << output_node->name() << ":" << port; + if (conn.is_control_edge()) { + QCHECK_EQ(Graph::kControlSlot, port); + graph->AddControlEdge(engine_node, output_node); + } else { + auto new_edge = + graph->AddEdge(engine_node, conn.port_number, output_node, port); + QCHECK(new_edge) << "Adding a new edge failed " << engine_node->name() + << ":" << conn.port_number << " -> " + << output_node->name() << ":" << conn.outside_port; + } } - return status; + return Status::OK(); } // Function to construct a funcdef from the segment and add it to the graph. @@ -794,6 +891,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err); } VLOG(1) << "Current cuda device is " << old_cuda_device; + std::vector<Node*> engine_nodes; + engine_nodes.resize(engine_segments.size()); for (int i = 0; i < engine_segments.size(); ++i) { auto& engine = engine_segments.at(i); // Partition the workspace size by the average of node ratio and segment @@ -817,19 +916,21 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { LOG(WARNING) << "Can't identify the cuda device. Running on device 0 "; } cudaSetDevice(cuda_device_id); - auto status = CreateTRTNode(&graph, engine_segments, i, alloc.get(), - params.max_batch_size); + auto status = CreateTRTNode(engine_segments, i, params.max_batch_size, + &graph, alloc.get(), &engine_nodes); // If status is ok, we successfully added the node to the graph and can // remove segment ops. Otherwise graph is not modified. + const string msg = StrCat("Engine ", engine.engine_name, + " creation for segment ", i, ", composed of ", + converted_segments.at(i).first.size(), " nodes"); if (status.ok()) { + LOG(INFO) << msg << " succeeded."; for (auto node_name : converted_segments.at(i).first) { graph.RemoveNode(node_map.at(node_name)); } } else { // Graph is not modified. - LOG(WARNING) << "Engine creation for segment " << i << ", composed of " - << converted_segments.at(i).first.size() - << " nodes failed: " << status << ". Skipping..."; + LOG(WARNING) << msg << " failed: " << status << ". Skipping..."; } } cudaSetDevice(old_cuda_device); |