diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-03 12:36:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-03 12:36:54 -0700 |
commit | 401a8d4c6562b9ab591068e4e5418d8967543ef6 (patch) | |
tree | f9e3ee2b3a41e769eb11b8f10ddf40592bf6dd7e /tensorflow/contrib/tensorrt | |
parent | 69b4ee4c78c2551d027fd39e81fc71e1a6698f31 (diff) | |
parent | 26d52994cd3bf16b765799494b1f1c1070231b8c (diff) |
Merge pull request #21138 from samikama:WiringUpdate
PiperOrigin-RevId: 207306967
Diffstat (limited to 'tensorflow/contrib/tensorrt')
27 files changed, 1083 insertions, 328 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index d69d44a454..fc0d22d112 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -85,11 +85,12 @@ cc_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":test_utils", ":trt_allocator", + ":trt_conversion", ":trt_logging", ":trt_plugins", ":trt_resources", - ":trt_conversion", ":utils", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib_proto_parsing", @@ -194,6 +195,7 @@ tf_py_wrap_cc( "//tensorflow/python:platform/base.i", ], deps = [ + ":test_utils", ":trt_conversion", ":trt_engine_op_kernel", "//third_party/python_runtime:headers", @@ -266,6 +268,7 @@ tf_cuda_library( ], deps = [ ":segment", + ":test_utils", ":trt_allocator", ":trt_plugins", ":trt_logging", @@ -417,3 +420,13 @@ cc_library( "//tensorflow/core:lib", ], ) + +cc_library( + name = "test_utils", + srcs = ["test/utils.cc"], + hdrs = ["test/utils.h"], + deps = [ + "//tensorflow/core:lib", + "@com_googlesource_code_re2//:re2", + ], +) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 896968647e..21ec8b0b30 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -20,6 +20,7 @@ limitations under the License. #include <map> #include <set> #include <unordered_map> +#include <unordered_set> #include <utility> #include <vector> @@ -29,6 +30,7 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/contrib/tensorrt/segment/segment.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -285,11 +287,10 @@ tensorflow::Status GetEngineInfo( const std::unordered_map<string, tensorflow::Node*>& node_map, const std::vector<tensorflow::Node*>& reverse_topo_order, EngineInfo* info) { - std::vector<int> subgraph_node_ids; + std::vector<int> subgraph_node_ids; // Topologically sorted node ids. + std::set<string> subgraph_node_names = segment_nodes; std::set<int> added_const_node_ids; // Used to prevent double insertion. std::set<string> segment_devices; - int input_port = 0; - int output_port = 0; // Map from src_node_name+port to the unique port numbers of the TRT op, where // the src_node_name is the name of the source node of the input/output @@ -297,13 +298,12 @@ tensorflow::Status GetEngineInfo( // input/output edges must be in different split of the graph. // TODO(aaroey): consider using node id and port instead. // TODO(aaroey): using topo order instead of reverting reverse topo order. - std::unordered_map<string, int> created_edges; + std::unordered_map<string, int> input_to_engine_port, output_to_engine_port; for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend(); ++it) { const auto& node_name = (*it)->name(); - if (segment_nodes.count(node_name) == 0) continue; - auto node = node_map.at(node_name); + auto node = *it; auto node_device = node->requested_device(); if (!node_device.empty()) { segment_devices.insert(node_device); @@ -316,64 +316,93 @@ tensorflow::Status GetEngineInfo( } } const int node_id = node->id(); + subgraph_node_ids.push_back(node_id); + // Create input connections. for (const auto edge : node->in_edges()) { auto input_node = edge->src(); - if (segment_nodes.count(input_node->name()) == 0 && - !edge->IsControlEdge() && !input_node->IsSource()) { - // Add constant input node into the segment. We don't care if it has - // other output edges going into other engines or TF nodes. Since we add - // it only to the subsegment node list, not the subsegment itself, it - // won't be removed from the graph. If it doesn't have any edges, TF - // will prune it out. - if (input_node->type_string() == "Const") { - if (added_const_node_ids.count(input_node->id()) == 0) { - added_const_node_ids.insert(input_node->id()); - subgraph_node_ids.push_back(input_node->id()); - } + if (input_node->IsSource() || segment_nodes.count(input_node->name())) { + continue; + } + if (edge->IsControlEdge()) { + // Control input. + info->connections.emplace_back(input_node->name(), input_node->id(), + node_name, node_id, + /*input_edge=*/true); + } else if (input_node->type_string() == "Const") { + // Add constant data input nodes into the segment graphdef (thus also in + // the engine). We don't care if it has other output edges going into + // other engines or TF nodes. Since we add it only to the segment + // graphdef, not the segment itself, it won't be removed from the graph. + // If it doesn't have any edges, TF will prune it out. + // + // Note that the segmenter already ensure that the constant data input + // is valid and suppported by the engine. + if (!added_const_node_ids.insert(input_node->id()).second) { + // Already added before. + continue; + } + VLOG(1) << "Adding const node " << input_node->name(); + QCHECK(subgraph_node_names.insert(input_node->name()).second); + // Since we already add (duplicate) the const input node to the segment + // graphdef, it's now not a data dependency any more, but to make the + // dependency correct we still add a control dependency. + info->connections.emplace_back(input_node->name(), input_node->id(), + node_name, node_id, + /*input_edge=*/true); + } else { + // Non-const data input. + int port = Graph::kControlSlot - 1; + // Use the source non-segment node name/port as key. + const string s = StrCat(input_node->name(), ":", edge->src_output()); + VLOG(1) << "Input edge = " << s; + if (input_to_engine_port.count(s)) { + port = input_to_engine_port.at(s); } else { - string s(input_node->name()); - StrAppend(&s, ":", edge->src_output()); - VLOG(1) << "Input edge = " << s; - int port = input_port; - if (created_edges.count(s)) { - port = created_edges.at(s); - } else { - created_edges.insert({s, port}); - input_port++; - } - info->connections.emplace_back(input_node->name(), input_node->id(), - edge->src_output(), node_name, node_id, - edge->dst_input(), true, port); + port = input_to_engine_port.size(); + input_to_engine_port.insert({s, port}); } + info->connections.emplace_back( + input_node->name(), input_node->id(), edge->src_output(), node_name, + node_id, edge->dst_input(), /*input_edge=*/true, port); } } - // We need to add possible const input nodes before adding this node in - // order to keep the topological order. - subgraph_node_ids.push_back(node_id); + // Create output connections. for (const auto edge : node->out_edges()) { auto output_node = edge->dst(); - if (segment_nodes.count(output_node->name()) == 0 && - !edge->IsControlEdge() && !output_node->IsSink()) { - string s(node_name); - StrAppend(&s, ":", edge->src_output()); + if (output_node->IsSink() || segment_nodes.count(output_node->name())) { + continue; + } + if (edge->IsControlEdge()) { + // Control output. + info->connections.emplace_back(output_node->name(), output_node->id(), + node_name, node_id, + /*input_edge=*/false); + } else { + // Data output. + int port = Graph::kControlSlot - 1; + // Use the source segment node name/port as key. + const string s = StrCat(node_name, ":", edge->src_output()); VLOG(1) << "Output edge = " << s; - int port = output_port; - if (created_edges.count(s)) { - port = created_edges.at(s); + if (output_to_engine_port.count(s)) { + port = output_to_engine_port.at(s); } else { - created_edges.insert({s, port}); - output_port++; + port = output_to_engine_port.size(); + output_to_engine_port.insert({s, port}); } - info->connections.emplace_back(output_node->name(), output_node->id(), - edge->dst_input(), node_name, node_id, - edge->src_output(), false, port); + info->connections.emplace_back( + output_node->name(), output_node->id(), edge->dst_input(), + node_name, node_id, edge->src_output(), /*input_edge=*/false, port); } } - } + } // For each segment node in topological order. + // Construct the const nodes first. + subgraph_node_ids.insert(subgraph_node_ids.begin(), + added_const_node_ids.begin(), + added_const_node_ids.end()); TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef( - g, graph_properties, subgraph_node_ids, &info->connections, - &info->segment_graph_def, &info->engine_name)); + g, graph_properties, subgraph_node_names, subgraph_node_ids, + &info->connections, &info->segment_graph_def, &info->engine_name)); // TODO(sami): This should not happen once segmenter is updated. if (segment_devices.size() == 1) { info->device = *segment_devices.begin(); @@ -383,94 +412,137 @@ tensorflow::Status GetEngineInfo( << "but this shouldn't have happened"; info->device = *segment_devices.begin(); } else { - VLOG(1) << "Segment devices size is 0"; + LOG(ERROR) << "Can't find a device placement for the op!"; } return Status::OK(); } -// Function to insert a TRT node into the graph. The graph is not modified if -// the returned status is not ok. -// 'alloc' is only used for creating static engine. -tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, - const std::vector<EngineInfo>& infos, int pos, +// Helper function to update edge connection from the removed node to the +// engine node. If an outside node is gone, it must have been absorbed into +// an engine node. Find the engine node. +void UpdateToEngineNode(const std::vector<EngineInfo>& infos, + const size_t my_engine_id, + const std::vector<Node*>& engine_nodes, + const bool is_input_edge, const string& node_name, + tensorflow::Node** node, int* port) { + for (size_t t = 0; t < infos.size(); ++t) { + if (t == my_engine_id) { + continue; + } + const auto& info = infos.at(t); + for (const auto& eng_conn : info.connections) { + // If the connection being updated is an input connection, the source of + // the connection must be an output connection of another engine. And vise + // versa. + if (is_input_edge == eng_conn.is_input_edge) continue; + if (eng_conn.inside_node_name == node_name && + eng_conn.inside_port == *port) { + *node = CHECK_NOTNULL(engine_nodes[t]); + QCHECK_EQ(info.engine_name, (**node).name()) + << "Engine name mismatch: " << info.engine_name << " vs " + << (**node).name(); + *port = eng_conn.port_number; + return; + } + } + } + LOG(FATAL) << "Node " << (**node).name() << " not found in any engine."; +} + +// Function to insert a TRT engine node into the graph. +// Create engine nodes in the following way: +// 1. Each invocation of CreateTRTNode creates an engine node for infos[pos] +// 2. When an engine node is created, add it into the graph with necessary +// re-wiring. +// 2.1. If the outside connected node is existing, connect the engine +// node to it. +// 2.2. If the outside connected node is gone, it must have been absorted +// into another engine node (which was processed before the processing +// one). Connect to the pre-existing engine node instead. +// 3. In this way, we ensure the graph is topologically sort-able after each +// invocation of CreateTRTNode(). +tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos, + int max_batch_size, tensorflow::Graph* graph, nvinfer1::IGpuAllocator* alloc, - int max_batch_size) { + std::vector<Node*>* engine_nodes) { const auto& info = infos.at(pos); + TRT_RETURN_IF_TEST_VALUE(StrCat(info.engine_name, ":CreateTRTNode"), "fail"); std::vector<tensorflow::TensorShapeProto> output_shape_protos; std::vector<tensorflow::TensorShapeProto> input_shape_protos; std::vector<tensorflow::PartialTensorShape> input_shapes; std::vector<tensorflow::NodeDefBuilder::NodeOut> inputs; + std::vector<tensorflow::Node*> input_nodes; + std::vector<tensorflow::Node*> control_input_nodes; + std::unordered_set<string> control_input_names; std::vector<tensorflow::DataType> out_types; - VLOG(1) << "Processing " << info.engine_name; - // Update the shape and data types of input/output nodes, and find all unique - // inputs. + VLOG(1) << "Processing " << info.engine_name; + // Collect needed info for creating the engine node in the graph for (const auto& conn : info.connections) { - if (!conn.is_input_edge) { - // Set the shapes and data types of output edge. - tensorflow::TensorShapeProto out_shape; - // shape of the output node inside segment - conn.inside_shape.AsProto(&out_shape); - if (output_shape_protos.size() <= conn.port_number) { - output_shape_protos.resize(conn.port_number + 1); - out_types.resize(conn.port_number + 1); + // Control edges + if (conn.is_control_edge()) { + // Skip control outputs for now. control output info are not needed for + // node creation and will be processed later. + if (!conn.is_input_edge) continue; + + // Rewrire control input if it's not found in original graph. + tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); + int port = tensorflow::Graph::kControlSlot; + if (!input_node) { + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true, + conn.outside_node_name, &input_node, &port); + QCHECK_EQ(Graph::kControlSlot, port); } - output_shape_protos.at(conn.port_number) = out_shape; - out_types.at(conn.port_number) = conn.connection_type; - continue; - } - - // Set the shapes and data types of input edge. - tensorflow::TensorShapeProto in_shape; - conn.outside_shape.AsProto(&in_shape); - if (input_shape_protos.size() <= conn.port_number) { - input_shape_protos.resize(conn.port_number + 1); - input_shapes.resize(conn.port_number + 1); - } - input_shape_protos.at(conn.port_number) = in_shape; - input_shapes.at(conn.port_number) = conn.outside_shape; - - string input_node = conn.outside_node_name; - int input_port = conn.outside_port; - bool found_engine = false; - // Rewire the inputs to other engines if they contain original input node. - // Note that we use the information of the engine here, not the information - // of the created TRT nodes, so we're able to find all the connections to - // any other engines beforehand. - for (size_t t = 0; t < infos.size(); ++t) { - if (t == pos) continue; - auto& engine_info = infos.at(t); - for (const auto& eng_conn : engine_info.connections) { - if (eng_conn.is_input_edge) continue; - if (eng_conn.inside_node_name == input_node) { - input_node = engine_info.engine_name; - if (eng_conn.inside_port == input_port) { - input_port = eng_conn.port_number; - found_engine = true; - break; - } - } + if (!control_input_names.insert(input_node->name()).second) { + continue; } - if (found_engine) break; - } - VLOG(1) << "Engine Input " << input_node << ":" << input_port << " -> " - << info.engine_name << ":" << inputs.size(); - // Skip duplicate inputs. - // TODO(aaroey): use std::find instead. GetEngineInfo already remove - // duplicate connections, so here we should never find any duplicate? - bool new_input = true; - for (const auto& inp : inputs) { - if (inp.node == input_node && inp.index == input_port) { - new_input = false; - break; + control_input_nodes.push_back(input_node); + VLOG(1) << "Engine Control Input " << input_node->name() << " -> " + << info.engine_name; + } else { + // Data edges + if (!conn.is_input_edge) { + // Set the shapes and data types of output edge. + tensorflow::TensorShapeProto out_shape; + // shape of the output node inside segment + conn.inside_shape.AsProto(&out_shape); + if (output_shape_protos.size() <= conn.port_number) { + output_shape_protos.resize(conn.port_number + 1); + out_types.resize(conn.port_number + 1); + } + output_shape_protos.at(conn.port_number) = out_shape; + out_types.at(conn.port_number) = conn.connection_type; + } else { + // Set the shapes and data types of input edge. + tensorflow::TensorShapeProto in_shape; + conn.outside_shape.AsProto(&in_shape); + if (input_shape_protos.size() <= conn.port_number) { + input_shape_protos.resize(conn.port_number + 1); + input_shapes.resize(conn.port_number + 1); + } + input_shape_protos.at(conn.port_number) = in_shape; + input_shapes.at(conn.port_number) = conn.outside_shape; + + // Rewrire data input if it's not found in original graph. + tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); + int port = conn.outside_port; + if (!input_node) { + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true, + conn.outside_node_name, &input_node, &port); + } + if (std::find_if( + std::begin(inputs), std::end(inputs), + [input_node, &port](const NodeDefBuilder::NodeOut& inp) { + return inp.node == input_node->name() && inp.index == port; + }) == std::end(inputs)) { + inputs.emplace_back(input_node->name(), port, conn.connection_type); + input_nodes.push_back(CHECK_NOTNULL(input_node)); + VLOG(1) << "Engine Input " << input_node->name() << ":" << port + << " -> " << info.engine_name << ":" << inputs.size() - 1; + } } } - if (new_input) { - inputs.emplace_back(input_node, input_port, conn.connection_type); - } } - - // Build the engine and get its serialized representation. string segment_string; if (info.engine_type == EngineInfo::EngineType::TRTStatic || info.precision_mode == INT8MODE) { @@ -517,6 +589,10 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, VLOG(1) << ins; } node_builder.Input(inputs); + for (const string& c : control_input_names) { + node_builder.ControlInput(c); + } + if (info.engine_type == EngineInfo::EngineType::TRTStatic && info.cached_engine_batches.size()) { LOG(WARNING) << "Cached engine batches are ignored for static engines"; @@ -545,34 +621,55 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, // Up until this point, graph is not modified. If we return !status.ok() from // here, this segment will be skipped + // TODO(aaroey): let it return proper error status for the following logic + // instead of checking fail. tensorflow::Node* engine_node = graph->AddNode(trt_node, &status); + (*engine_nodes)[pos] = engine_node; if (!status.ok()) { LOG(ERROR) << "Adding node failed " << status; return status; } + // Add control input and input edges to the engine node. + for (const auto in : control_input_nodes) { + VLOG(1) << "Connecting control edge from " << in->name() << " to " + << engine_node->name(); + graph->AddControlEdge(in, engine_node); + } + VLOG(1) << "input_nodes size = " << input_nodes.size(); + for (int i = 0; i < input_nodes.size(); ++i) { + Node* n = CHECK_NOTNULL(input_nodes[i]); + const auto& in = inputs[i]; + VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index + << " to " << engine_node->name() << ":" << i; + graph->AddEdge(n, in.index, engine_node, i); + } + // Updates the inputs of output edges destination nodes, and point them to the // engine node. for (auto& conn : info.connections) { - if (conn.is_input_edge) continue; - VLOG(1) << " Updating DBG " << engine_node->name() << " out_port " - << conn.port_number << " out_id " << conn.outside_id - << " name=" << conn.outside_node_name; - auto dst_node = graph->FindNodeId(conn.outside_id); - // dst_node can only be removed if it is an input node of another engine. - // In this case, other engines input edge is updated in nodedef to point to - // this engine. Even though edge doesn't exists in the graph, when it is - // deserialized again, correct edges will be constructed. This is a problem - // of graph->AddNode(). - if (!dst_node) continue; + if (conn.is_input_edge) { + continue; + } + tensorflow::Node* output_node = graph->FindNodeId(conn.outside_id); + int port = conn.outside_port; + if (!output_node) { + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false, + conn.outside_node_name, &output_node, &port); + } VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number - << " to " << dst_node->name() << ":" << conn.outside_port; - auto new_edge = graph->AddEdge(engine_node, conn.port_number, dst_node, - conn.outside_port); - CHECK(new_edge) << "Adding a new edge failed " << engine_node->name() << ":" - << conn.port_number << " -> " << dst_node->name() << ":" - << conn.outside_port; + << " to " << output_node->name() << ":" << port; + if (conn.is_control_edge()) { + QCHECK_EQ(Graph::kControlSlot, port); + graph->AddControlEdge(engine_node, output_node); + } else { + auto new_edge = + graph->AddEdge(engine_node, conn.port_number, output_node, port); + QCHECK(new_edge) << "Adding a new edge failed " << engine_node->name() + << ":" << conn.port_number << " -> " + << output_node->name() << ":" << conn.outside_port; + } } - return status; + return Status::OK(); } // Function to construct a funcdef from the segment and add it to the graph. @@ -794,6 +891,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err); } VLOG(1) << "Current cuda device is " << old_cuda_device; + std::vector<Node*> engine_nodes; + engine_nodes.resize(engine_segments.size()); for (int i = 0; i < engine_segments.size(); ++i) { auto& engine = engine_segments.at(i); // Partition the workspace size by the average of node ratio and segment @@ -817,19 +916,21 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { LOG(WARNING) << "Can't identify the cuda device. Running on device 0 "; } cudaSetDevice(cuda_device_id); - auto status = CreateTRTNode(&graph, engine_segments, i, alloc.get(), - params.max_batch_size); + auto status = CreateTRTNode(engine_segments, i, params.max_batch_size, + &graph, alloc.get(), &engine_nodes); // If status is ok, we successfully added the node to the graph and can // remove segment ops. Otherwise graph is not modified. + const string msg = StrCat("Engine ", engine.engine_name, + " creation for segment ", i, ", composed of ", + converted_segments.at(i).first.size(), " nodes"); if (status.ok()) { + LOG(INFO) << msg << " succeeded."; for (auto node_name : converted_segments.at(i).first) { graph.RemoveNode(node_map.at(node_name)); } } else { // Graph is not modified. - LOG(WARNING) << "Engine creation for segment " << i << ", composed of " - << converted_segments.at(i).first.size() - << " nodes failed: " << status << ". Skipping..."; + LOG(WARNING) << msg << " failed: " << status << ". Skipping..."; } } cudaSetDevice(old_cuda_device); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 451d6fe698..35fa590254 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -22,6 +22,7 @@ limitations under the License. #include <memory> #include <set> #include <unordered_map> +#include <unordered_set> #include <utility> #include <vector> @@ -2690,7 +2691,7 @@ tensorflow::Status ConvertGraphDefToEngine( // Graph nodes are already topologically sorted during construction for (const auto& node_def : gdef.node()) { string node_name = node_def.name(); - VLOG(1) << "Converting op name=" << node_name << ", op=" << node_def.op(); + VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op(); if (tensorflow::str_util::StartsWith(node_name, kInputPHName) && (node_def.op() == "Placeholder")) { nvinfer1::DimsCHW input_dim_pseudo_chw; @@ -2788,6 +2789,7 @@ tensorflow::Status ConvertGraphDefToEngine( tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, + const std::set<string>& subgraph_node_names, const std::vector<int>& subgraph_node_ids, // In topological order std::vector<EngineConnection>* connections, tensorflow::GraphDef* segment_def, string* common_scope) { @@ -2796,6 +2798,7 @@ tensorflow::Status ConvertSegmentToGraphDef( // nodes in the segment graphdef. for (size_t i = 0; i < connections->size(); ++i) { auto& connection = connections->at(i); + if (connection.is_control_edge()) continue; auto outside_node = graph->FindNodeId(connection.outside_id); if (!outside_node) { // This should never happen, unless the original graph is problematic. @@ -2809,13 +2812,13 @@ tensorflow::Status ConvertSegmentToGraphDef( GetInputProperties(graph_properties, graph->FindNodeId(connection.outside_id), connection.outside_port, &partial_shape, &dtype); - + connection.outside_shape = partial_shape; } else { GetOutputProperties(graph_properties, graph->FindNodeId(connection.outside_id), connection.outside_port, &partial_shape, &dtype); + connection.inside_shape = partial_shape; } - connection.outside_shape = partial_shape; connection.connection_type = dtype; // Add dummy input/output nodes to the segment graphdef. @@ -2868,12 +2871,12 @@ tensorflow::Status ConvertSegmentToGraphDef( old_to_new_id_map[node_id] = segment_def->node_size(); auto snode = segment_def->add_node(); snode->CopyFrom(node->def()); - VLOG(1) << "Copying " << snode->name() << " to subgraph"; + VLOG(2) << "Copying " << snode->name() << " to subgraph"; } // Update the inputs of the new input nodes to point to placeholder nodes. for (int i = 0; i < connections->size(); ++i) { auto& connection = connections->at(i); - if (!connection.is_input_edge) continue; + if (connection.is_control_edge() || !connection.is_input_edge) continue; auto snode = segment_def->mutable_node(old_to_new_id_map[connection.inside_id]); const string placeholder_name = @@ -2883,6 +2886,39 @@ tensorflow::Status ConvertSegmentToGraphDef( << placeholder_name; snode->set_input(connection.inside_port, placeholder_name); } + // Remove control inputs that are not inside the segment. + for (int i = 0; i < segment_def->node_size(); ++i) { + auto snode = segment_def->mutable_node(i); + const int input_size = snode->input_size(); + int input_idx = 0; + int actual_input_idx = 0; + while (input_idx < input_size) { + TensorId input = ParseTensorName(snode->input(input_idx)); + if (!subgraph_node_names.count( + string(input.first.data(), input.first.size())) && + !str_util::StartsWith(input.first, kInputPHName)) { + if (input.second == Graph::kControlSlot) { + VLOG(1) << "... removing control inputs " << input.first + << " from subgraph."; + ++input_idx; + continue; + } else { + return tensorflow::errors::InvalidArgument( + "Found non control input outside the segment that is not an " + "engine connection to ", + snode->name(), ": ", input.first); + } + } + if (actual_input_idx != input_idx) { + snode->set_input(actual_input_idx, snode->input(input_idx)); + } + ++input_idx; + ++actual_input_idx; + } + for (int remove = input_size - actual_input_idx; remove > 0; --remove) { + snode->mutable_input()->RemoveLast(); + } + } *common_scope = local_scope; VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph"; return tensorflow::Status::OK(); @@ -2897,12 +2933,12 @@ bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const { nvinfer1::DataType trt_dtype; Status status = ValidateInputProperties(shape, dtype, &trt_dtype); if (!status.ok()) { - VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name() + VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name() << ": " << status; return false; } if (shape.dims() < 3 && in_edge->src()->type_string() != "Const") { - VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name() + VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name() << " which has an input at port " << in_edge->dst_input() << " with #dim<3 and is not a const: " << shape; return false; @@ -2913,7 +2949,7 @@ bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const { bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const { if (out_edge->IsControlEdge()) return true; if (out_edge->src()->type_string() == "Const") { - VLOG(2) << "--> Need to remove output node " << out_edge->src()->name() + VLOG(1) << "--> Need to remove output node " << out_edge->src()->name() << " which is a Const."; return false; } diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 6a63c9f82f..a60253740f 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -36,11 +36,12 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -static const char* kInputPHName = "InputPH_"; -static const char* kOutputPHName = "OutputPH_"; +static const char* kInputPHName = "TensorRTInputPH_"; +static const char* kOutputPHName = "TensorRTOutputPH_"; namespace convert { struct EngineConnection { + // Constructs a non-control edge. EngineConnection(const string& outside, int out_id, int out_port, const string& inside, int in_id, int in_port, bool input_edge, int port) @@ -53,21 +54,35 @@ struct EngineConnection { is_input_edge(input_edge), port_number(port) {} + // Constructs a control edge. + EngineConnection(const string& outside, int out_id, const string& inside, + int in_id, bool input_edge) + : outside_node_name(outside), + outside_id(out_id), + outside_port(Graph::kControlSlot), + inside_node_name(inside), + inside_id(in_id), + inside_port(Graph::kControlSlot), + is_input_edge(input_edge), + port_number(Graph::kControlSlot) {} + + bool is_control_edge() const { return port_number == Graph::kControlSlot; } + const string outside_node_name; const int outside_id; const int outside_port; - tensorflow::PartialTensorShape outside_shape; + tensorflow::PartialTensorShape outside_shape; // Only set for input edge. const string inside_node_name; const int inside_id; const int inside_port; - tensorflow::PartialTensorShape inside_shape; + tensorflow::PartialTensorShape inside_shape; // Only set for output edge. tensorflow::DataType connection_type; - bool is_input_edge; + const bool is_input_edge; - // The port number of the TRT node connecting to this edge. - int port_number; + // The port number of the TRT node connected with this edge. + const int port_number; }; struct EngineInfo { @@ -80,7 +95,9 @@ struct EngineInfo { string device; tensorflow::GraphDef segment_graph_def; - // The segment nodes that are on one side of the edges are topological sorted. + // Non-control input connections inside this vector are sorted in a way such + // that, the segment nodes connecting to them are topological sorted. + // In addition, for non-control connections, there must be no duplicates. std::vector<EngineConnection> connections; enum class EngineType { TRTStatic = 0, TRTDynamic = 1 }; @@ -96,6 +113,7 @@ struct EngineInfo { // (OutputPH_*). This function needs to be called before TensorRT nodes // inserted in order to correctly get sizes from the original graph. // +// - subgraph_node_names: the node names of the subgraph. // - subgraph_node_ids: the node ids of the subgraph, must be sorted in // topological order. // - segment_def: the output GraphDef, whose non-input/output nodedefs will be @@ -105,6 +123,7 @@ struct EngineInfo { tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, + const std::set<string>& subgraph_node_names, const std::vector<int>& subgraph_node_ids, std::vector<EngineConnection>* connections, tensorflow::GraphDef* segment_def, string* common_scope); diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index 044c736c03..f33f2cc4d6 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stacktrace.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -189,9 +190,6 @@ tensorflow::Status TRTOptimizationPass::Optimize( tensorflow::grappler::Cluster* cluster, const tensorflow::grappler::GrapplerItem& item, GraphDef* optimized_graph) { VLOG(1) << "Called TRTOptimization Pass " << name_; - if (VLOG_IS_ON(1)) { - PrintDebugInfo(cluster, item); - } // This is a hack to workaround optimizer issue. MetaOptimizer calls // optimization passes on function objects as well, we should not modify // generated funcdefs! This is fragile but we don't have any other option @@ -203,6 +201,10 @@ tensorflow::Status TRTOptimizationPass::Optimize( *optimized_graph = item.graph; return tensorflow::Status::OK(); } + if (VLOG_IS_ON(1)) { + VLOG(2) << CurrentStackTrace(); + PrintDebugInfo(cluster, item); + } int max_dim = -1; if (item.feed.size()) { for (const auto& f : item.feed) { diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 6851f79ef6..2b42d81f47 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -173,7 +174,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, helper->Ref(); // Increment count for calculating native graph VLOG(1) << "Executing native segment " << name(); lib->Run(opts, native_func_, inputs, outputs, - [ctx, outputs, helper](const tensorflow::Status& s) { + [this, ctx, outputs, helper](const tensorflow::Status& s) { tensorflow::core::ScopedUnref sc(helper); VLOG(1) << "Native Segment completed"; if (!s.ok()) { @@ -183,6 +184,8 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, for (size_t t = 0; t < outputs->size(); ++t) { ctx->set_output(t, outputs->at(t)); } + test::AddTestValue(StrCat(this->name(), ":ExecuteNativeSegment"), + "done"); delete outputs; }); } @@ -228,6 +231,7 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, ->implementation() ->GpuStreamMemberHack())); calib_res->calibrator_->setBatch(input_data, *stream); + test::AddTestValue(StrCat(name(), ":ExecuteCalibration"), "done"); VLOG(2) << "Passed calibration data"; ExecuteNativeSegment(ctx, helper); } @@ -252,7 +256,7 @@ int TRTEngineOp::GetEngineBatch(OpKernelContext* ctx) { StrCat("Engine buffer is full. buffer limit=", max_cached_engines_, ", current entries="); for (auto i : cached_engine_batches_) StrAppend(&msg, i, ","); - StrAppend(&msg, "Requested batch=", num_batch); + StrAppend(&msg, " requested batch=", num_batch); LOG(WARNING) << msg; return -1; } @@ -270,7 +274,8 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, } const int smallest_engine = GetEngineBatch(ctx); if (smallest_engine < 0) { - LOG(WARNING) << "Failed to get engine batch, running native segment"; + LOG(WARNING) << "Failed to get engine batch, running native segment for " + << name(); ExecuteNativeSegment(ctx, helper); return; } @@ -280,14 +285,15 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, auto& trt_engine_ptr = engine_ctx_pair.first; if (!trt_engine_ptr) { LOG(WARNING) << "Engine retrieval for batch size " << num_batch - << " failed. Running native segment"; + << " failed. Running native segment for " << name(); ExecuteNativeSegment(ctx, helper); return; } const bool retry = ExecuteTrtEngine(ctx, num_batch, trt_engine_ptr.get(), engine_ctx_pair.second.get()); if (retry) { - LOG(WARNING) << "Failed to execute engine, retrying with native segment"; + LOG(WARNING) << "Failed to execute engine, " + << "retrying with native segment for " << name(); ExecuteNativeSegment(ctx, helper); return; } @@ -406,6 +412,7 @@ bool TRTEngineOp::ExecuteTrtEngine( LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name(); return kRetry; } + test::AddTestValue(StrCat(name(), ":ExecuteTrtEngine"), "done"); // Synchronization will be done by TF. return !kRetry; } diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h index 59b744e6d3..8fe0675891 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h @@ -35,7 +35,7 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -class TRTInt8Calibrator; +struct TRTInt8Calibrator; class TRTCalibrationResource; class AsyncHelper; // TODO(Sami): Remove this file? diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py index fe4fa166a1..7cdfe2b1a6 100644 --- a/tensorflow/contrib/tensorrt/python/__init__.py +++ b/tensorflow/contrib/tensorrt/python/__init__.py @@ -20,7 +20,11 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +from tensorflow.contrib.tensorrt.python.trt_convert import add_test_value from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph +from tensorflow.contrib.tensorrt.python.trt_convert import clear_test_values from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph +from tensorflow.contrib.tensorrt.python.trt_convert import enable_test_value +from tensorflow.contrib.tensorrt.python.trt_convert import get_test_value from tensorflow.contrib.tensorrt.python.trt_convert import is_tensorrt_enabled # pylint: enable=unused-import,line-too-long diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index c696a8b1f0..4116f2fe30 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -20,9 +20,13 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long import six as _six +from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert +from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values +from tensorflow.contrib.tensorrt.wrap_conversion import enable_test_value from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version +from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index 008fffc954..b43f1b190f 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -414,10 +414,10 @@ tensorflow::Status SegmentGraph( } for (const SimpleNode* node : order) { // All output nodes of 'node' have been visited... - VLOG(2) << "Trying node " << node->name() << " id=" << node->id(); + VLOG(3) << "Trying node " << node->name() << " id=" << node->id(); // 'node' must be a TRT candidate... if (node_segments[node->id()].Value() == nullptr) { - VLOG(2) << "... not a TRT candidate"; + VLOG(3) << "... not a TRT candidate"; continue; } // Contract output edges to combine 'node' with output @@ -426,22 +426,22 @@ tensorflow::Status SegmentGraph( while (true) { std::set<const SimpleEdge*> contract_edges; for (const SimpleEdge* out_edge : node->out_edges()) { - VLOG(2) << "... out node " << out_edge->dst()->name() << " ( " + VLOG(3) << "... out node " << out_edge->dst()->name() << " ( " << out_edge->dst()->id() << " <- " << node->id() << " )"; if (out_edge->IsControlEdge()) { - VLOG(2) << "... ... Control Edge, Skipping"; + VLOG(3) << "... ... Control Edge, Skipping"; continue; } // Out node must be TRT candidate... if (node_segments[out_edge->dst()->id()].Value() == nullptr) { - VLOG(2) << "... ... not a TRT candidate"; + VLOG(3) << "... ... not a TRT candidate"; continue; } if (CanContractEdge(out_edge, graph)) { - VLOG(2) << "... ... can contract"; + VLOG(3) << "... ... can contract"; contract_edges.insert(out_edge); } else { - VLOG(2) << "... ... cannot contract, would form cycle"; + VLOG(3) << "... ... cannot contract, would form cycle"; } } if (contract_edges.empty()) { @@ -454,7 +454,7 @@ tensorflow::Status SegmentGraph( const SimpleNode* src = contract_edge->src(); const SimpleNode* dst = contract_edge->dst(); - VLOG(2) << "Merge " << src->name() << " <- " << dst->name() << " (" + VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " (" << src->id() << " <- " << dst->id(); node_segments[src->id()].Merge(&node_segments[dst->id()]); @@ -478,7 +478,7 @@ tensorflow::Status SegmentGraph( // A map from the segment identifier (currently the name of the root node of // the segment tree) to the segment nodes set. - std::unordered_map<string, std::set<const tensorflow::Node*>> sg_map; + std::map<string, std::set<const tensorflow::Node*>> sg_map; // A map from the segment identifier (currently the name of the root node of // the segment tree) to the device names that the nodes in the segment are @@ -558,27 +558,36 @@ tensorflow::Status SegmentGraph( // then after doing this operation the resulting subgraph will keep the // same properties 1 and 2. // - // For simplicity we use heuristics: for input nodes remove all its - // input, for output nodes remove all its output. In this way, for common - // cases the number of removed nodes should be minimum. + // For simplicity we use heuristics: for input and const output nodes + // remove all their inputs, and for non-const output nodes remove all + // their outputs. In this way, for common cases the number of removed + // nodes should be minimum. auto remove_nodes = [&segment_nodes]( bool is_input_nodes, std::deque<const tensorflow::Node*>* que) { // Run a BFS on the queue to find all the input/output nodes. std::set<const tensorflow::Node*> visited; + std::set<const tensorflow::Node*> logged(que->begin(), que->end()); while (!que->empty()) { auto node = que->front(); que->pop_front(); if (!visited.insert(node).second) continue; segment_nodes.erase(node); - for (auto in : - is_input_nodes ? node->in_nodes() : node->out_nodes()) { + for (auto in : (is_input_nodes || node->type_string() == "Const") + ? node->in_nodes() + : node->out_nodes()) { if (segment_nodes.count(in)) { que->push_back(in); - VLOG(2) << "Need to remove node " << in->name() - << " because one of its " - << (is_input_nodes ? "output" : "input") - << " nodes in the graph was removed: " << node->name(); + if (VLOG_IS_ON(2)) { + if (!logged.count(in)) { + VLOG(2) << "----> Need to remove node " << in->name() + << " because one of its " + << (is_input_nodes ? "output" : "input") + << " nodes in the graph was removed: " + << node->name(); + logged.insert(in); + } + } } } } @@ -594,7 +603,7 @@ tensorflow::Status SegmentGraph( for (const auto& itr : sg_map) { const std::set<const tensorflow::Node*>& segment_nodes = itr.second; if (VLOG_IS_ON(1)) { - string s; + string s = "parent=" + itr.first + ":"; for (auto node : segment_nodes) s += " " + node->name(); VLOG(1) << "Segment " << segments->size() << ": " << s; } diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index 432e7b1c04..5937fa8259 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -206,7 +206,7 @@ TEST_F(SegmentTest, Multiple) { // Make add5 not a TRT candidate, and we expect two segments. auto without_add5 = all_adds - "add5"; RunTest(&g, without_add5, without_add5, without_add5, - {{"add6", "add8"}, {"add0", "add1", "add2", "add3"}}); + {{"add0", "add1", "add2", "add3"}, {"add6", "add8"}}); // Make add8 not a candidate and add6 not an input candidate, then all direct // and indirect inputs of add6 will be removed from the segment. @@ -252,7 +252,7 @@ TEST_F(SegmentTest, BigIfElse) { const std::set<string> all_adds = {"add0", "add1", "add2", "add3", "add4", "add5", "add6", "add7"}; RunTest(&g, all_adds - "add2", all_adds, all_adds, - {{"add3", "add4", "add5", "add6", "add7"}, {"add0", "add1"}}); + {{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}}); } } // namespace test diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py index edd30ad7a9..8ea5a63735 100644 --- a/tensorflow/contrib/tensorrt/test/base_test.py +++ b/tensorflow/contrib/tensorrt/test/base_test.py @@ -20,17 +20,19 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.tensorrt.python import trt_convert from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test -class SimpleSingleEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase): +class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase): def GetParams(self): """Create a graph containing single segment.""" @@ -65,13 +67,17 @@ class SimpleSingleEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=1, + # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which + # breaks the connection check, fix it. + # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add", + # "relu", "identity", "max_pool"] + expected_engines=["my_trt_op_0"], expected_output_dims=(100, 6, 6, 6), allclose_atol=1.e-03, allclose_rtol=1.e-03) -class SimpleMultiEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase): +class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): def GetParams(self): """Create a graph containing multiple segment.""" @@ -95,32 +101,246 @@ class SimpleMultiEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase): padding="SAME", name="conv") c1 = constant_op.constant( - np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype) - p = conv * c1 + np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c1") + p = math_ops.mul(conv, c1, name="mul") c2 = constant_op.constant( - np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype) - q = conv / c2 + np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c2") + q = math_ops.div(conv, c2, name="div") - edge = self.trt_incompatible_op(q) - edge /= edge - r = edge + edge + edge = self.trt_incompatible_op(q, name="incompatible") + edge = math_ops.div(edge, edge, name="div1") + r = math_ops.add(edge, edge, name="add") - p -= edge - q *= edge - s = p + q - s -= r + p = math_ops.sub(p, edge, name="sub") + q = math_ops.mul(q, edge, name="mul1") + s = math_ops.add(p, q, name="add1") + s = math_ops.sub(s, r, name="sub1") array_ops.squeeze(s, name=self.output_name) return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=2, + # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which + # breaks the connection check, fix it. + # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1", + # "add", "sub1"]; + # - my_trt_op_1 should have ["weights","conv", "div"] + expected_engines=["my_trt_op_0", "my_trt_op_1"], expected_output_dims=(100, 12, 12, 6), allclose_atol=1.e-03, allclose_rtol=1.e-03) -# TODO(aaroey): add a large complex graph to test. +class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): + + def setUp(self): + """Setup method.""" + super(PartiallyConvertedTestA, self).setUp() + # Let it fail to build the second engine. + trt_convert.add_test_value("my_trt_op_1:CreateTRTNode", "fail") + + def GetParams(self): + """Create a graph containing two segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + for i in range(2): + c = constant_op.constant(1.0, name="c%d" % i) + n = math_ops.add(n, c, name="add%d" % i) + n = math_ops.mul(n, n, name="mul%d" % i) + edge = self.trt_incompatible_op(n, name="incompatible") + with g.control_dependencies([edge]): + c = constant_op.constant(1.0, name="c2") + n = math_ops.add(n, c, name="add2") + n = math_ops.mul(n, n, name="mul2") + c = constant_op.constant(1.0, name="c3") + n = math_ops.add(n, c, name="add3") + n = math_ops.mul(n, n, name="mul3") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + # Only the first engine is built. + "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class PartiallyConvertedTestB(PartiallyConvertedTestA): + + def setUp(self): + """Setup method.""" + super(PartiallyConvertedTestB, self).setUp() + # Let it fail to build the first engine. + trt_convert.clear_test_values("") + trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail") + + def GetParams(self): + """Create a graph containing two segment.""" + return super(PartiallyConvertedTestB, self).GetParams()._replace( + expected_engines={ + # Only the second engine is built. + "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"] + }) + + +class ConstInputTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing multiple segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + c = constant_op.constant(1.0, name="c") + # Adds control dependency from the constant op to a trt incompatible op, + # and adds control dependency from the trt incompatible op to all other + # ops, to make sure the constant op cannot be contracted with any trt + # segment that depends on it. + with g.control_dependencies([c]): + d = self.trt_incompatible_op(n, name="incompatible") + with g.control_dependencies([d]): + n = math_ops.add(n, c, name="add") + n = math_ops.mul(n, n, name="mul") + n = math_ops.add(n, n, name="add1") + n = self.trt_incompatible_op(n, name="incompatible1") + with g.control_dependencies([d]): + n = math_ops.add(n, c, name="add2") + n = math_ops.mul(n, n, name="mul1") + n = math_ops.add(n, n, name="add3") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + "my_trt_op_0": ["add", "add1", "mul"], + "my_trt_op_1": ["add2", "add3", "mul1"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing single segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + c = constant_op.constant(1.0, name="c") + n = math_ops.add(n, c, name="add") + n = math_ops.mul(n, n, name="mul") + n = math_ops.add(n, n, name="add1") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={"my_trt_op_0": ["c", "add", "add1", "mul"]}, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing multiple segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + c = constant_op.constant(1.0, name="c") + n = math_ops.add(n, c, name="add") + n = math_ops.mul(n, n, name="mul") + n = math_ops.add(n, n, name="add1") + n = self.trt_incompatible_op(n, name="incompatible1") + n = math_ops.add(n, c, name="add2") + n = math_ops.mul(n, n, name="mul1") + n = math_ops.add(n, n, name="add3") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + "my_trt_op_0": ["add2", "add3", "mul1"], + # Why segment ["add", "add1", "mul"] was assigned segment id 1 + # instead of 0: the parent node of this segment is actually const + # node 'c', but it's removed later since it's const output of the + # segment which is not allowed. + "my_trt_op_1": ["add", "add1", "mul"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing multiple segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + c1 = constant_op.constant(1.0, name="c1") + c2 = constant_op.constant(1.0, name="c2") + d1 = constant_op.constant(1.0, name="d1") + d2 = self.trt_incompatible_op(inp, name="d2") + with g.control_dependencies([d1, d2]): + add = math_ops.add(inp, c1, name="add") + with g.control_dependencies([d1, d2]): + mul = math_ops.mul(add, add, name="mul") + with g.control_dependencies([d1, d2]): + add1 = math_ops.add(mul, mul, name="add1") + edge = self.trt_incompatible_op(add1, name="incompatible") + with g.control_dependencies([d1, d2, add, mul]): + add2 = math_ops.add(edge, c2, name="add2") + with g.control_dependencies([d1, d2, add1, mul]): + mul1 = math_ops.mul(add2, add2, name="mul1") + with g.control_dependencies([d1, d2, add, add1]): + add3 = math_ops.add(mul1, mul1, name="add3") + array_ops.squeeze(add3, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + "my_trt_op_0": ["c1", "add", "add1", "mul"], + "my_trt_op_1": ["c2", "add2", "add3", "mul1"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py index 730b6843fb..2e1107e303 100644 --- a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py @@ -66,7 +66,7 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name, w1_name, w2_name], input_dims=[input_dims, w1_dims, w2_dims], - num_expected_engines=1, + expected_engines=["my_trt_op_0"], expected_output_dims=(12, 5, 8, 7), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py index 0c03a10b64..8be32f59b4 100644 --- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py @@ -102,7 +102,10 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=7, + expected_engines=[ + "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3", + "my_trt_op_4", "my_trt_op_5", "my_trt_op_6" + ], expected_output_dims=(48, 89), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py index dd673463a5..9316b14da0 100644 --- a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py +++ b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py @@ -109,7 +109,24 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=16, + expected_engines=[ + "my_trt_op_0", + "my_trt_op_1", + "my_trt_op_2", + "my_trt_op_3", + "my_trt_op_4", + "my_trt_op_5", + "my_trt_op_6", + "my_trt_op_7", + "my_trt_op_8", + "my_trt_op_9", + "my_trt_op_10", + "my_trt_op_11", + "my_trt_op_12", + "my_trt_op_13", + "my_trt_op_14", + "my_trt_op_15", + ], expected_output_dims=(5, 23040), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/contrib/tensorrt/test/concatenation_test.py index 8c51c45b0a..1874b9dd45 100644 --- a/tensorflow/contrib/tensorrt/test/concatenation_test.py +++ b/tensorflow/contrib/tensorrt/test/concatenation_test.py @@ -73,7 +73,7 @@ class ConcatenationTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=1, + expected_engines=["my_trt_op_0"], expected_output_dims=(2, 126), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py index 97b29bf05d..8c59000b70 100644 --- a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py +++ b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py @@ -58,7 +58,7 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=1, + expected_engines=['my_trt_op_0'], expected_output_dims=(5, 12, 12, 1), allclose_atol=1.e-02, allclose_rtol=1.e-02) diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py index 3dd95c6f62..66eb6be757 100644 --- a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py +++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py @@ -62,7 +62,7 @@ class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=1, + expected_engines=["my_trt_op_0"], expected_output_dims=(2, 15, 15, 10), allclose_atol=1.e-02, allclose_rtol=1.e-02) diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py index 734ccf6345..fd55b8cd99 100644 --- a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py +++ b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py @@ -77,7 +77,7 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=2, + expected_engines=["my_trt_op_0", "my_trt_op_1"], expected_output_dims=(2, 4, 5, 4), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py index 50265c0845..51c905a50b 100644 --- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py +++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py @@ -25,7 +25,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.platform import test @@ -51,15 +51,18 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase): name="conv") b = constant_op.constant( np.random.normal(1.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype) - t = conv * b - e = gen_math_ops.tan(conv) - t = t - e + t = math_ops.mul(conv, b, name="mul") + e = self.trt_incompatible_op(conv, name="incompatible") + t = math_ops.sub(t, e, name="sub") array_ops.squeeze(t, name=self.output_name) return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=2, + expected_engines={ + "my_trt_op_0": ["bias", "mul", "sub"], + "my_trt_op_1": ["weights", "conv"] + }, expected_output_dims=(2, 4, 5, 4), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py index bb7f5a77f0..6f85ada464 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py @@ -20,6 +20,7 @@ from __future__ import print_function from collections import namedtuple import itertools +import os import warnings import numpy as np import six @@ -30,6 +31,7 @@ from tensorflow.contrib.tensorrt.python.ops import trt_engine_op # pylint: enable=unused-import from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.framework import graph_io from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -37,10 +39,14 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging as logging TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [ - "gdef", "input_names", "input_dims", "num_expected_engines", + "gdef", "input_names", "input_dims", "expected_engines", "expected_output_dims", "allclose_atol", "allclose_rtol" ]) +RunParams = namedtuple( + "RunParams", + ["use_optimizer", "precision_mode", "dynamic_engine", "test_name"]) + PRECISION_MODES = ["FP32", "FP16", "INT8"] @@ -48,6 +54,12 @@ def _IsQuantizationMode(mode): return mode == "INT8" +class GraphState(object): + ORIGINAL = 0 + CALIBRATE = 1 + INFERENCE = 2 + + class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): """Class to test Tensorflow-TensorRT integration.""" @@ -63,45 +75,90 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def precision_modes(self): return ["FP32", "FP16", "INT8"] + # str is bytes in py2, but unicode in py3. + def _ToUnicode(self, s): + if six.PY2: + if isinstance(s, unicode): + return s + return s.decode("utf-8") + else: + if isinstance(s, str): + return s + return s.decode("utf-8") + def _ToBytes(self, s): if six.PY2: + if isinstance(s, unicode): + return s.encode("utf-8") return s else: - return s.encode("utf-8") + if isinstance(s, str): + return s.encode("utf-8") + return s def _ToString(self, s): if six.PY2: + if isinstance(s, unicode): + return s.encode("utf-8") return s else: + if isinstance(s, str): + return s return s.decode("utf-8") + @classmethod + def setUpClass(cls): + """Setup method for the module.""" + super(TfTrtIntegrationTestBase, cls).setUpClass() + trt_convert.enable_test_value() + def setUp(self): """Setup method.""" super(TfTrtIntegrationTestBase, self).setUp() warnings.simplefilter("always") + trt_convert.clear_test_values("") def GetParams(self): """Return a TfTrtIntegrationTestParams for test, implemented by subclass.""" raise NotImplementedError() - def _GetConfigProto(self, - params, - use_optimizer, - precision_mode=None, - is_dynamic_op=None): + def _PrepareRun(self, params, graph_state): + """Set up necessary testing environment before calling sess.run().""" + # Clear test values added by TRTEngineOp. + trt_convert.clear_test_values("my_trt_op_.*:ExecuteTrtEngine") + trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration") + trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment") + + def _VerifyRun(self, params, graph_state): + """Verify the state after sess.run().""" + for engine_name in params.expected_engines: + if graph_state == GraphState.ORIGINAL: + self._ExpectCalibration(engine_name, "") + self._ExpectNativeSegment(engine_name, "") + self._ExpectTrtEngine(engine_name, "") + elif graph_state == GraphState.CALIBRATE: + self._ExpectCalibration(engine_name, "done") + self._ExpectNativeSegment(engine_name, "done") + self._ExpectTrtEngine(engine_name, "") + elif graph_state == GraphState.INFERENCE: + self._ExpectCalibration(engine_name, "") + self._ExpectNativeSegment(engine_name, "") + self._ExpectTrtEngine(engine_name, "done") + + def _GetConfigProto(self, params, run_params, graph_state): """Get config proto based on specific settings.""" - if use_optimizer: + if graph_state != GraphState.ORIGINAL and run_params.use_optimizer: rewriter_cfg = rewriter_config_pb2.RewriterConfig() rewriter_cfg.optimizers.extend(["constfold", "layout"]) custom_op = rewriter_cfg.custom_optimizers.add() custom_op.name = "TensorRTOptimizer" - custom_op.parameter_map["minimum_segment_size"].i = 3 + custom_op.parameter_map["minimum_segment_size"].i = 2 custom_op.parameter_map["max_batch_size"].i = max( [dims[0] for dims in params.input_dims]) - custom_op.parameter_map["is_dynamic_op"].b = is_dynamic_op + custom_op.parameter_map["is_dynamic_op"].b = run_params.dynamic_engine custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25 custom_op.parameter_map["precision_mode"].s = self._ToBytes( - precision_mode) + run_params.precision_mode) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) else: graph_options = config_pb2.GraphOptions() @@ -115,7 +172,26 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): gpu_options=gpu_options, graph_options=graph_options) return config - def _RunGraph(self, params, gdef, input_data, config, num_runs=2): + def _ExpectTestValue(self, engine_name, method, expected_value): + label = "%s:%s" % (engine_name, method) + actual_value = trt_convert.get_test_value(label) + self.assertEqual( + expected_value, + actual_value, + msg="Unexpected test value with label %s. Actual: %s; expected: %s" % + (label, actual_value, expected_value)) + + def _ExpectCalibration(self, engine_name, value): + self._ExpectTestValue(engine_name, "ExecuteCalibration", value) + + def _ExpectTrtEngine(self, engine_name, value): + self._ExpectTestValue(engine_name, "ExecuteTrtEngine", value) + + def _ExpectNativeSegment(self, engine_name, value): + self._ExpectTestValue(engine_name, "ExecuteNativeSegment", value) + + def _RunGraph(self, params, gdef, input_data, config, graph_state, + num_runs=2): """Run given graphdef multiple times.""" assert len(params.input_names) == len(input_data) g = ops.Graph() @@ -132,93 +208,170 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): val = None # Defaults to 2 runs to verify result across multiple runs is same. for _ in range(num_runs): + self._PrepareRun(params, graph_state) new_val = sess.run(out, {inp[i]: input_data[i] for i in range(len(inp))}) self.assertEqual(params.expected_output_dims, new_val.shape) if val is not None: self.assertAllEqual(val, new_val) val = new_val + self._VerifyRun(params, graph_state) return val # Use real data that is representative of the inference dataset # for calibration. For this test script it is random data. def _RunCalibration(self, params, gdef, input_data, config): """Run calibration on given graph.""" - return self._RunGraph(params, gdef, input_data, config, 30) + return self._RunGraph( + params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5) - def _GetTrtGraphDef(self, params, gdef, precision_mode, is_dynamic_op): + def _GetTrtGraphDef(self, params, run_params, gdef): """Return trt converted graphdef.""" return trt_convert.create_inference_graph( input_graph_def=gdef, outputs=[self.output_name], max_batch_size=max([dims[0] for dims in params.input_dims]), max_workspace_size_bytes=1 << 25, - precision_mode=precision_mode, + precision_mode=run_params.precision_mode, minimum_segment_size=2, - is_dynamic_op=is_dynamic_op) - - def _VerifyGraphDef(self, - params, - gdef, - precision_mode=None, - is_calibrated=None, - dynamic_engine=None): + is_dynamic_op=run_params.dynamic_engine) + + def _WriteGraph(self, params, run_params, gdef, graph_state): + if graph_state == GraphState.ORIGINAL: + label = "Original" + elif graph_state == GraphState.CALIBRATE: + label = "CalibEngine" + elif graph_state == GraphState.INFERENCE: + label = "InferEngine" + graph_name = ( + self.__class__.__name__ + "_" + run_params.test_name + "_" + label + + ".pbtxt") + temp_dir = os.getenv("TRT_TEST_TMPDIR", self.get_temp_dir()) + logging.info("Writing graph to %s/%s", temp_dir, graph_name) + graph_io.write_graph(gdef, temp_dir, graph_name) + + def _VerifyConnections(self, params, converted_gdef): + old_to_new_node_map = { + self._ToString(node.name): self._ToString(node.name) + for node in params.gdef.node + } + for engine_name, node_names in params.expected_engines.items(): + for node_name in node_names: + old_to_new_node_map[node_name] = engine_name + name_to_node_map = { + self._ToString(node.name): node for node in params.gdef.node + } + + def _InputName(inp): + inp = self._ToString(inp) + prefix = "" + if inp[0] == "^": + prefix = "^" + inp = inp[1:] + parts = inp.split(":") + if len(parts) > 1 and parts[-1].isdigit(): + inp = inp[:-len(parts[-1]) - 1] + return (prefix, inp) + + expected_input_map = {} + for node in params.gdef.node: + name_str = self._ToString(node.name) + target_node_name = old_to_new_node_map[name_str] + is_engine_op = (target_node_name != name_str) + if target_node_name not in expected_input_map: + expected_input_map[target_node_name] = set() + input_set = expected_input_map[target_node_name] + for inp in node.input: + (prefix, inp_name) = _InputName(inp) + # Add the input only if it's outside the segment (note that it could be + # in a different engine). + if (not is_engine_op or + old_to_new_node_map[inp_name] != target_node_name): + if is_engine_op and name_to_node_map[inp_name].op == "Const": + # Const data input nodes to the segment has been copied to the + # segment graphdef and the engine, and the dependency has been + # converted to control dependendy. + input_set.add("^" + old_to_new_node_map[inp_name]) + else: + input_set.add(prefix + old_to_new_node_map[inp_name]) + + actual_input_map = {} + for node in converted_gdef.node: + name_str = self._ToString(node.name) + actual_input_map[name_str] = set() + input_set = actual_input_map[name_str] + for inp in node.input: + (prefix, node_name) = _InputName(inp) + input_set.add(prefix + node_name) + + self.assertEqual( + expected_input_map, + actual_input_map, + msg="expected:\n%s\nvs actual:\n%s" % (sorted( + expected_input_map.items()), sorted(actual_input_map.items()))) + + def _VerifyGraphDef(self, params, run_params, gdef, graph_state): + self._WriteGraph(params, run_params, gdef, graph_state) + num_engines = 0 - for n in gdef.node: - # TODO(jie): we should have coverage for failed conversion (TF fallback). - # where the conversion will fail and we shouldn't count this engine as the - # converted engines. - if n.op == "TRTEngineOp": + for node in gdef.node: + if node.op == "TRTEngineOp": num_engines += 1 - self.assertNotEqual(self._ToBytes(""), n.attr["serialized_segment"].s) - self.assertNotEqual(self._ToBytes(""), n.attr["segment_funcdef_name"].s) + self.assertTrue(node.name in params.expected_engines) + self.assertTrue(len(node.attr["serialized_segment"].s)) + self.assertTrue(len(node.attr["segment_funcdef_name"].s)) self.assertEqual( - self._ToBytes(precision_mode), n.attr["precision_mode"].s) - self.assertEqual(not dynamic_engine, n.attr["static_engine"].b) - if _IsQuantizationMode(precision_mode) and is_calibrated: - self.assertNotEqual(self._ToBytes(""), n.attr["calibration_data"].s) + self._ToBytes(run_params.precision_mode), + node.attr["precision_mode"].s) + + is_dynamic_engine = not node.attr["static_engine"].b + self.assertEqual(run_params.dynamic_engine, is_dynamic_engine) + + has_calibration_data = len(node.attr["calibration_data"].s) + if (_IsQuantizationMode(run_params.precision_mode) and + graph_state == GraphState.INFERENCE): + self.assertTrue(has_calibration_data) else: - self.assertEqual(self._ToBytes(""), n.attr["calibration_data"].s) - if precision_mode is None: # This means gdef is the original GraphDef. + self.assertFalse(has_calibration_data) + if graph_state == GraphState.ORIGINAL: self.assertEqual(0, num_engines) else: - self.assertEqual(num_engines, params.num_expected_engines) + self.assertEqual(num_engines, len(params.expected_engines)) + if isinstance(params.expected_engines, dict): + self._VerifyConnections(params, gdef) + # TODO(aaroey): consider verifying the corresponding TF function. - def RunTest(self, params, use_optimizer, precision_mode, - dynamic_infer_engine, dynamic_calib_engine): - assert precision_mode in PRECISION_MODES + def RunTest(self, params, run_params): + assert run_params.precision_mode in PRECISION_MODES input_data = [np.random.random_sample(dims) for dims in params.input_dims] input_gdef = params.gdef - self._VerifyGraphDef(params, input_gdef) + self._VerifyGraphDef(params, run_params, input_gdef, GraphState.ORIGINAL) # Get reference result without running trt. - config_no_trt = self._GetConfigProto(params, False) + config_no_trt = self._GetConfigProto(params, run_params, + GraphState.ORIGINAL) logging.info("Running original graph w/o trt, config:\n%s", str(config_no_trt)) - ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt) + ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt, + GraphState.ORIGINAL) # Run calibration if necessary. - if _IsQuantizationMode(precision_mode): + if _IsQuantizationMode(run_params.precision_mode): - calib_config = self._GetConfigProto(params, use_optimizer, precision_mode, - dynamic_calib_engine) + calib_config = self._GetConfigProto(params, run_params, + GraphState.CALIBRATE) logging.info("Running calibration graph, config:\n%s", str(calib_config)) - if use_optimizer: - self.assertTrue(False) - # TODO(aaroey): uncomment this and get infer_gdef when this mode is - # supported. - # result = self._RunCalibration(params, input_gdef, input_data, - # calib_config) + if run_params.use_optimizer: + result = self._RunCalibration(params, input_gdef, input_data, + calib_config) else: - calib_gdef = self._GetTrtGraphDef(params, input_gdef, precision_mode, - dynamic_calib_engine) - self._VerifyGraphDef(params, calib_gdef, precision_mode, False, - dynamic_calib_engine) + calib_gdef = self._GetTrtGraphDef(params, run_params, input_gdef) + self._VerifyGraphDef(params, run_params, calib_gdef, + GraphState.CALIBRATE) result = self._RunCalibration(params, calib_gdef, input_data, calib_config) - infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef) - self._VerifyGraphDef(params, infer_gdef, precision_mode, True, - dynamic_calib_engine) + infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef) + self._VerifyGraphDef(params, run_params, infer_gdef, GraphState.INFERENCE) self.assertAllClose( ref_result, @@ -229,18 +382,19 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): infer_gdef = input_gdef # Run inference. - infer_config = self._GetConfigProto(params, use_optimizer, precision_mode, - dynamic_infer_engine) + infer_config = self._GetConfigProto(params, run_params, + GraphState.INFERENCE) logging.info("Running final inference graph, config:\n%s", str(infer_config)) - if use_optimizer: - result = self._RunGraph(params, infer_gdef, input_data, infer_config) + if run_params.use_optimizer: + result = self._RunGraph(params, infer_gdef, input_data, infer_config, + GraphState.INFERENCE) else: - trt_infer_gdef = self._GetTrtGraphDef(params, infer_gdef, precision_mode, - dynamic_infer_engine) - self._VerifyGraphDef(params, trt_infer_gdef, precision_mode, True, - dynamic_infer_engine) - result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config) + trt_infer_gdef = self._GetTrtGraphDef(params, run_params, infer_gdef) + self._VerifyGraphDef(params, run_params, trt_infer_gdef, + GraphState.INFERENCE) + result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config, + GraphState.INFERENCE) self.assertAllClose( ref_result, @@ -263,66 +417,44 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def _AddTests(test_class): """Adds test methods to TfTrtIntegrationTestBase.""" - def _GetTest(use_optimizer, precision_mode, dynamic_infer_engine, - dynamic_calib_engine): + def _GetTest(run_params): """Gets a single test method based on the parameters.""" def _Test(self): params = self.GetParams() logging.info( - "Running test with parameters: use_optimizer=%s, precision_mode=%s, " - "dynamic_infer_engine=%s, dynamic_calib_engine=%s", use_optimizer, - precision_mode, dynamic_infer_engine, dynamic_calib_engine) - self.RunTest(params, use_optimizer, precision_mode, dynamic_infer_engine, - dynamic_calib_engine) + "Running test %s with parameters: use_optimizer=%s, " + "precision_mode=%s, dynamic_engine=%s", + "testTfTrt_" + run_params.test_name, run_params.use_optimizer, + run_params.precision_mode, run_params.dynamic_engine) + self.RunTest(params, run_params) return _Test use_optimizer_options = [False, True] - dynamic_infer_engine_options = [False, True] - dynamic_calib_engine_options = [False, True] - for (use_optimizer, precision_mode, - dynamic_infer_engine, dynamic_calib_engine) in itertools.product( - use_optimizer_options, PRECISION_MODES, dynamic_infer_engine_options, - dynamic_calib_engine_options): + dynamic_engine_options = [False, True] + for (use_optimizer, precision_mode, dynamic_engine) in itertools.product( + use_optimizer_options, PRECISION_MODES, dynamic_engine_options): if _IsQuantizationMode(precision_mode): - if not dynamic_calib_engine and dynamic_infer_engine: - # TODO(aaroey): test this case, the conversion from static calibration - # engine to dynamic inference engine should be a noop. - continue if use_optimizer: # TODO(aaroey): if use_optimizer is True we need to get the inference # graphdef using custom python wrapper class, which is not currently # supported yet. continue - if not dynamic_calib_engine: + if not dynamic_engine: # TODO(aaroey): construction of static calibration engine is not # supported yet. continue - if dynamic_calib_engine and not dynamic_infer_engine: - # TODO(aaroey): construction of static inference engine using dynamic - # calibration engine is not supported yet. - continue - else: # In non int8 mode. - if dynamic_calib_engine: - # dynamic_calib_engine doesn't affect non-int8 modes, so just let - # related tests run once on dynamic_calib_engine=False. - continue conversion = "OptimizerConversion" if use_optimizer else "ToolConversion" - infer_engine_type = ("DynamicInferEngine" - if dynamic_infer_engine else "StaticInferEngine") - calib_engine_type = "" - if precision_mode == "INT8": - calib_engine_type = ("DynamicCalibEngine" - if dynamic_calib_engine else "StaticCalibEngine") - test_name = "%s_%s_%s%s" % (conversion, precision_mode, infer_engine_type, - ("_" + calib_engine_type) - if len(calib_engine_type) else "") - setattr( - test_class, "testTfTRT_" + test_name, - _GetTest(use_optimizer, precision_mode, dynamic_infer_engine, - dynamic_calib_engine)) + engine_type = ("DynamicEngine" if dynamic_engine else "StaticEngine") + test_name = "%s_%s_%s" % (conversion, precision_mode, engine_type) + run_params = RunParams( + use_optimizer=use_optimizer, + precision_mode=precision_mode, + dynamic_engine=dynamic_engine, + test_name=test_name) + setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params)) if trt_convert.is_tensorrt_enabled(): diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py index b9e977cf67..500057a36d 100644 --- a/tensorflow/contrib/tensorrt/test/unary_test.py +++ b/tensorflow/contrib/tensorrt/test/unary_test.py @@ -100,7 +100,10 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name, input2_name], input_dims=[input_dims, input2_dims], - num_expected_engines=5, + expected_engines=[ + "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3", + "my_trt_op_4" + ], expected_output_dims=(12, 5, 8, 12), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/utils.cc b/tensorflow/contrib/tensorrt/test/utils.cc new file mode 100644 index 0000000000..276308b3a0 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/utils.cc @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/test/utils.h" + +#include <unordered_map> +#include <vector> + +#include "re2/re2.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace tensorrt { +namespace test { + +// TODO(aaroey): make this class thread-safe. +class TestValueManager { + public: + static TestValueManager* singleton() { + static TestValueManager* manager = new TestValueManager(); + return manager; + } + + void Enable() { + VLOG(1) << "Enabling test value"; + enabled_ = true; + } + + void Add(const string& label, const string& value) { + if (TF_PREDICT_FALSE(enabled_)) { + QCHECK_NE("", value); + VLOG(1) << "Adding test value: " << label << " -> " << value; + values_.insert({label, value}); + } + } + + string Get(const string& label) { + if (TF_PREDICT_FALSE(enabled_)) { + VLOG(1) << "Getting test value by " << label; + auto itr = values_.find(label); + if (itr == values_.end()) return ""; + return itr->second; + } + return ""; + } + + void Clear(const string& pattern) { + if (TF_PREDICT_FALSE(enabled_)) { + VLOG(1) << "Clearing test values"; + if (pattern.empty()) { + values_.clear(); + return; + } + std::vector<string> keys_to_clear; + for (const auto& kv : values_) { + if (RE2::FullMatch(kv.first, pattern)) { + keys_to_clear.push_back(kv.first); + } + } + for (const string& key : keys_to_clear) { + values_.erase(key); + } + } + } + + private: + TestValueManager() : enabled_(false) {} + + bool enabled_; + std::unordered_map<string, string> values_; +}; + +void EnableTestValue() { TestValueManager::singleton()->Enable(); } + +void ClearTestValues(const string& pattern) { + TestValueManager::singleton()->Clear(pattern); +} + +void AddTestValue(const string& label, const string& value) { + TestValueManager::singleton()->Add(label, value); +} + +string GetTestValue(const string& label) { + return TestValueManager::singleton()->Get(label); +} + +} // namespace test +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/test/utils.h b/tensorflow/contrib/tensorrt/test/utils.h new file mode 100644 index 0000000000..4bb4120206 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/utils.h @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace tensorrt { +namespace test { + +// Helper methods to inject values used by testing tools. +void EnableTestValue(); +void ClearTestValues(const string& pattern); +void AddTestValue(const string& label, const string& value); +string GetTestValue(const string& label); + +#define TRT_RETURN_IF_TEST_VALUE(label, value_to_return) \ + do { \ + if (::tensorflow::tensorrt::test::GetTestValue(label) == \ + value_to_return) { \ + return errors::Internal("Injected manually"); \ + } \ + } while (0) + +} // namespace test +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py index 2b134c3bce..ab4d224db4 100644 --- a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py +++ b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py @@ -72,7 +72,7 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=1, + expected_engines=["my_trt_op_0"], expected_output_dims=(5, 6, 2, 2), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py index bec2f23eff..56bdf848ea 100644 --- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py +++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py @@ -63,7 +63,7 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=1, + expected_engines=["my_trt_op_0"], expected_output_dims=(5, 2, 2, 6), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index 3b1a18f8ac..6ea15fb8ef 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -101,6 +101,7 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); #include "tensorflow/core/util/stat_summarizer.h" #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" #include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" %} %ignoreall @@ -109,6 +110,10 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); %unignore get_linked_tensorrt_version; %unignore get_loaded_tensorrt_version; %unignore is_tensorrt_enabled; +%unignore enable_test_value; +%unignore clear_test_values; +%unignore add_test_value; +%unignore get_test_value; %{ @@ -186,6 +191,34 @@ bool is_tensorrt_enabled() { return tensorflow::tensorrt::IsGoogleTensorRTEnabled(); } +void enable_test_value() { + tensorflow::tensorrt::test::EnableTestValue(); +} + +#if PY_MAJOR_VERSION < 3 +#define TRT_PY_TO_CPP_STRING PyString_AsString +#define TRT_CPP_TO_PY_STRING PyString_FromString +#else +#define TRT_PY_TO_CPP_STRING PyUnicode_AsUTF8 +#define TRT_CPP_TO_PY_STRING PyUnicode_FromString +#endif + +void clear_test_values(PyObject* pattern) { + tensorflow::tensorrt::test::ClearTestValues( + string(TRT_PY_TO_CPP_STRING(pattern))); +} + +void add_test_value(PyObject* label, PyObject* value) { + tensorflow::tensorrt::test::AddTestValue( + string(TRT_PY_TO_CPP_STRING(label)), string(TRT_PY_TO_CPP_STRING(value))); +} + +PyObject* get_test_value(PyObject* label) { + string value = tensorflow::tensorrt::test::GetTestValue( + string(TRT_PY_TO_CPP_STRING(label))); + return TRT_CPP_TO_PY_STRING(value.c_str()); +} + %} std::pair<string, string> calib_convert( @@ -193,5 +226,9 @@ std::pair<string, string> calib_convert( version_struct get_linked_tensorrt_version(); version_struct get_loaded_tensorrt_version(); bool is_tensorrt_enabled(); +void enable_test_value(); +void clear_test_values(PyObject* pattern); +void add_test_value(PyObject* label, PyObject* value); +PyObject* get_test_value(PyObject* label); %unignoreall |