author    | gracehoney <31743510+aaroey@users.noreply.github.com> | 2018-07-30 00:27:58 -0700
committer | gracehoney <31743510+aaroey@users.noreply.github.com> | 2018-07-30 00:27:58 -0700
commit    | 1009f9de414365d0f2401c51b6e023374ad11ad6 (patch)
tree      | 81b51f02196562a9f9f7e4075643c972aff93a0f /tensorflow/contrib/tensorrt
parent    | b1e7f284443b6e0220ffd1d5ba728340c768649f (diff)
Fix control dependency problems and add corresponding tests.
Diffstat (limited to 'tensorflow/contrib/tensorrt')
22 files changed, 876 insertions, 457 deletions
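At the heart of the diff below, `EngineConnection` (in `convert/convert_nodes.h`) now encodes control edges by setting all of its ports to `Graph::kControlSlot` and exposing an `is_control_edge()` predicate, instead of carrying a separate boolean flag. The following is a minimal sketch of that representation, simplified from the struct in this commit; node ids, shape, and data-type fields are omitted, and `kControlSlot` stands in for `tensorflow::Graph::kControlSlot`:

```cpp
#include <string>
#include <utility>

// Stand-in for tensorflow::Graph::kControlSlot (-1 in TensorFlow).
constexpr int kControlSlot = -1;

// Trimmed-down EngineConnection: a control edge is encoded by collapsing
// every port to kControlSlot, so no extra is_control_edge flag is needed.
struct EngineConnection {
  // Data edge: real source/destination ports plus an engine port number.
  EngineConnection(std::string outside, int out_port, std::string inside,
                   int in_port, bool input_edge, int engine_port)
      : outside_node_name(std::move(outside)),
        outside_port(out_port),
        inside_node_name(std::move(inside)),
        inside_port(in_port),
        is_input_edge(input_edge),
        port_number(engine_port) {}

  // Control edge: all ports collapse to kControlSlot.
  EngineConnection(std::string outside, std::string inside, bool input_edge)
      : outside_node_name(std::move(outside)),
        outside_port(kControlSlot),
        inside_node_name(std::move(inside)),
        inside_port(kControlSlot),
        is_input_edge(input_edge),
        port_number(kControlSlot) {}

  bool is_control_edge() const { return port_number == kControlSlot; }

  const std::string outside_node_name;
  const int outside_port;
  const std::string inside_node_name;
  const int inside_port;
  const bool is_input_edge;
  const int port_number;  // Engine port, or kControlSlot for control edges.
};
```

With this sketch, a control input from node `c` into a segment node `conv` would be recorded as `EngineConnection("c", "conv", /*input_edge=*/true)`; in the commit itself, such connections are later re-wired with `graph->AddControlEdge(...)` when the engine node is inserted, as shown in `convert_graph.cc` below.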
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 033d5207f6..a1071d6749 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -85,11 +85,12 @@ cc_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":test_utils", ":trt_allocator", + ":trt_conversion", ":trt_logging", ":trt_plugins", ":trt_resources", - ":trt_conversion", ":utils", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib_proto_parsing", @@ -192,6 +193,7 @@ tf_py_wrap_cc( "//tensorflow/python:platform/base.i", ], deps = [ + ":test_utils", ":trt_conversion", ":trt_engine_op_kernel", "//third_party/python_runtime:headers", @@ -264,6 +266,7 @@ tf_cuda_library( ], deps = [ ":segment", + ":test_utils", ":trt_allocator", ":trt_plugins", ":trt_logging", @@ -412,3 +415,12 @@ cc_library( hdrs = ["convert/utils.h"], copts = tf_copts(), ) + +cc_library( + name = "test_utils", + srcs = ["test/utils.cc"], + hdrs = ["test/utils.h"], + deps = [ + "//tensorflow/core:lib", + ], +) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 22909a199d..1e6300578d 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -20,6 +20,7 @@ limitations under the License. #include <map> #include <set> #include <unordered_map> +#include <unordered_set> #include <utility> #include <vector> @@ -29,6 +30,7 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/contrib/tensorrt/segment/segment.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" @@ -49,9 +51,9 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/protobuf/config.pb.h" // NOLINT +#include "tensorflow/core/protobuf/config.pb.h" // NOLINT #include "tensorflow/core/protobuf/device_properties.pb.h" // NOLINT -#include "tensorflow/core/protobuf/rewriter_config.pb.h" // NOLINT +#include "tensorflow/core/protobuf/rewriter_config.pb.h" // NOLINT #include "tensorflow/core/util/device_name_utils.h" #if GOOGLE_CUDA @@ -260,63 +262,6 @@ tensorflow::Status ConvertGraphDefToTensorRT( return ConvertAfterShapes(cp); } -bool IsUniformTensorValue(const tensorflow::TensorProto& tensor) { - using tensorflow::DataType; - switch (tensor.dtype()) { - case DataType::DT_HALF: // fall-through - case DataType::DT_BFLOAT16: - return tensor.half_val_size() == 1; - case DataType::DT_FLOAT: - return tensor.float_val_size() == 1; - case DataType::DT_DOUBLE: - return tensor.double_val_size() == 1; - case DataType::DT_INT32: // fall-through - case DataType::DT_INT16: // fall-through - case DataType::DT_INT8: // fall-through - case DataType::DT_UINT8: - return tensor.int_val_size() == 1; - case DataType::DT_STRING: - return tensor.string_val_size() == 1; - case DataType::DT_COMPLEX64: - return tensor.scomplex_val_size() == 1; - case DataType::DT_INT64: - return tensor.int64_val_size() == 1; - case DataType::DT_BOOL: - return tensor.bool_val_size() == 1; - case DataType::DT_COMPLEX128: - return tensor.dcomplex_val_size() == 1; - case DataType::DT_RESOURCE: - return tensor.resource_handle_val_size() == 1; - case DataType::DT_VARIANT: - return tensor.variant_val_size() == 1; - case DataType::DT_UINT32: - return tensor.uint32_val_size() == 1; - case DataType::DT_UINT64: - return tensor.uint64_val_size() == 1; - default: - return false; - } -} - -std::unordered_set<int> GetAttributeInputs(const tensorflow::Node* node) { - typedef std::unordered_map<string, std::unordered_set<int>> InputMap; - static const InputMap attribute_inputs = { - {"Concat", {0}}, {"ConcatV2", {-1}}, {"Reshape", {1}}}; - auto iter = attribute_inputs.find(node->type_string()); - if (iter != attribute_inputs.end()) { - // Apply reverse indexing - std::unordered_set<int> result; - for (int idx : iter->second) { - if (idx < 0) { - idx += node->num_inputs(); - } - result.insert(idx); - } - return result; - } - return {}; -} - // Function to get subsegment information structure. tensorflow::Status GetEngineInfo( const tensorflow::Graph* g, @@ -325,13 +270,10 @@ tensorflow::Status GetEngineInfo( const std::unordered_map<string, tensorflow::Node*>& node_map, const std::vector<tensorflow::Node*>& reverse_topo_order, EngineInfo* info) { - std::vector<int> subgraph_node_ids; + std::vector<int> subgraph_node_ids; // Topologically sorted node ids. + std::set<string> subgraph_node_names = segment_nodes; std::set<int> added_const_node_ids; // Used to prevent double insertion. std::set<string> segment_devices; - std::unordered_set<string> segment_consts; - std::vector<int> const_node_ids; - int input_port = 0; - int output_port = 0; // Map from src_node_name+port to the unique port numbers of the TRT op, where // the src_node_name is the name of the source node of the input/output @@ -339,7 +281,7 @@ tensorflow::Status GetEngineInfo( // input/output edges must be in different split of the graph. // TODO(aaroey): consider using node id and port instead. // TODO(aaroey): using topo order instead of reverting reverse topo order. 
- std::unordered_map<string, int> created_edges; + std::unordered_map<string, int> input_to_engine_port, output_to_engine_port; for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend(); ++it) { const auto& node_name = (*it)->name(); @@ -358,133 +300,114 @@ tensorflow::Status GetEngineInfo( } const int node_id = node->id(); subgraph_node_ids.push_back(node_id); + // Create input connections. for (const auto edge : node->in_edges()) { auto input_node = edge->src(); - if (input_node->IsSource()) continue; - if (segment_nodes.count(input_node->name()) == 0) { - // Add constant input node into the segment. We don't care if it has - // other output edges going into other engines or TF nodes. Since we add - // it only to the subsegment node list, not the subsegment itself, it - // won't be removed from the graph. If it doesn't have any edges, TF - // will prune it out. - if (input_node->type_string() == "Const") { - bool is_supported = input_node->output_type(0) == DT_FLOAT || - input_node->output_type(0) == DT_INT32; - bool is_attribute_input = - GetAttributeInputs(node).count(edge->dst_input()) != 0; - const tensorflow::TensorProto& tensor_proto = - input_node->def().attr().at("value").tensor(); - bool is_uniform = IsUniformTensorValue(tensor_proto); - - // Const can be absorbed - if (is_supported && is_attribute_input && is_uniform) { - if (segment_consts.count(input_node->name()) != 0) { - continue; // skip if already added - } - VLOG(0) << "Adding const node " << input_node->name(); - const_node_ids.push_back(input_node->id()); - segment_consts.insert(input_node->name()); - int conn_count = 0; - for (auto cinp_e : - input_node->in_edges()) { // must be Control edges - if (!cinp_e->IsControlEdge() || cinp_e->src()->IsSource()) { - conn_count++; - continue; - } - VLOG(0) << info->engine_name << ": Control edge " << conn_count - << " from node " << input_node->name() - << " edge= " << cinp_e->src()->name(); - auto cinp = cinp_e->src(); - EngineConnection ec(cinp->name(), cinp->id(), - cinp_e->src_output(), input_node->name(), - input_node->id(), cinp_e->dst_input(), true, - -1, true); - info->connections.emplace_back(std::move(ec)); - } - continue; - } + if (input_node->IsSource() || segment_nodes.count(input_node->name())) { + continue; + } + if (edge->IsControlEdge()) { + // Control input. + info->connections.emplace_back( + input_node->name(), input_node->id(), node_name, node_id, + /*input_edge=*/true); + } else if (input_node->type_string() == "Const") { + // Add constant data input nodes into the segment graphdef (thus also in + // the engine). We don't care if it has other output edges going into + // other engines or TF nodes. Since we add it only to the segment + // graphdef, not the segment itself, it won't be removed from the graph. + // If it doesn't have any edges, TF will prune it out. + // Note that the constant data input must be supported by the engine + // regardless of the datatype, since the segmenter already removed + // unsupported data input nodes. + if (!added_const_node_ids.insert(input_node->id()).second) { + // Already added before. 
+ continue; } - - // Non-const data/control edge - if (!edge->IsControlEdge()) { - string s(input_node->name()); - StrAppend(&s, ":", edge->src_output()); - VLOG(1) << "Input edge = " << s; - int port = input_port; - if (created_edges.count(s)) { - port = created_edges.at(s); - } else { - created_edges.insert({s, port}); - input_port++; - } - EngineConnection ec(input_node->name(), input_node->id(), - edge->src_output(), node_name, node_id, - edge->dst_input(), true, port); - ec.connection_type = input_node->output_type(edge->src_output()); - info->connections.emplace_back(std::move(ec)); + VLOG(1) << "Adding const node " << input_node->name(); + QCHECK(subgraph_node_names.insert(input_node->name()).second); +#if 1 + // Since we duplicate the const input node in both the segment graphdef + // and the engine, the segment node doesn't depend on it anymore, so we + // add a control dependency instead. + info->connections.emplace_back( + input_node->name(), input_node->id(), node_name, node_id, + /*input_edge=*/true); +#else + // Add control inputs to the const node as control input connections to + // the engine. + for (const auto const_in_edge : input_node->in_edges()) { + QCHECK(const_in_edge->IsControlEdge()); // Must be control edge. + auto const_in_node = const_in_edge->src(); + QCHECK(!segment_nodes.count(const_in_node->name())) + << "Loop found between segment and non-segment nodes, from " + "segment node " + << const_in_node->name() << " to non-segment node " + << input_node->name() << " to segment node " << node->name(); + if (const_in_node->IsSource()) continue; + VLOG(1) << "Control edge from node " << const_in_node->name() + << " to " << input_node->name(); + info->connections.emplace_back( + const_in_node->name(), const_in_node->id(), input_node->name(), + input_node->id(), /*input_edge=*/true); + } +#endif + } else { + // Non-const data input. + int port = Graph::kControlSlot - 1; + // Use the source non-segment node name/port as key. + const string s = StrCat(input_node->name(), ":", edge->src_output()); + VLOG(1) << "Input edge = " << s; + if (input_to_engine_port.count(s)) { + port = input_to_engine_port.at(s); } else { - EngineConnection ec(input_node->name(), input_node->id(), - edge->src_output(), node_name, node_id, - edge->dst_input(), true, -1, true); - ec.connection_type = input_node->output_type(edge->src_output()); - info->connections.emplace_back(std::move(ec)); + port = input_to_engine_port.size(); + input_to_engine_port.insert({s, port}); } + info->connections.emplace_back(input_node->name(), input_node->id(), + edge->src_output(), node_name, node_id, + edge->dst_input(), /*input_edge=*/true, + port); } } - + // Create output connections. 
for (const auto edge : node->out_edges()) { auto output_node = edge->dst(); - if (output_node->IsSink()) continue; - if (segment_nodes.count(output_node->name()) == 0) { - if (!edge->IsControlEdge()) { - string s(node_name); - StrAppend(&s, ":", edge->src_output()); - VLOG(1) << "Output edge = " << s; - int port = output_port; - if (created_edges.count(s)) { - port = created_edges.at(s); - } else { - created_edges.insert({s, port}); - output_port++; - } - info->connections.emplace_back(output_node->name(), output_node->id(), - edge->dst_input(), node_name, node_id, - edge->src_output(), false, port); - } else { - info->connections.emplace_back(output_node->name(), output_node->id(), - edge->dst_input(), node_name, node_id, - edge->src_output(), false, -1, true); - } + if (output_node->IsSink() || segment_nodes.count(output_node->name())) { + continue; } - } - } - - // Fix control edges - for (size_t t = 0; t < info->connections.size(); t++) { - auto& conn = info->connections.at(t); - if (conn.is_control_edge) { - for (size_t k = 0; k < info->connections.size(); k++) { - if (k == t) continue; - const auto& other = info->connections.at(k); - if (conn.outside_id == other.outside_id && other.port_number != -1) { - VLOG(0) << "Updating control edge " << conn.outside_node_name - << " -> " << conn.inside_node_name << " to input port " - << other.port_number; - conn.port_number = other.port_number; - break; + if (edge->IsControlEdge()) { + // Control output. + info->connections.emplace_back( + output_node->name(), output_node->id(), node_name, node_id, + /*input_edge=*/false); + } else { + // Data output. + int port = Graph::kControlSlot - 1; + // Use the source segment node name/port as key. + const string s = StrCat(node_name, ":", edge->src_output()); + VLOG(1) << "Output edge = " << s; + if (output_to_engine_port.count(s)) { + port = output_to_engine_port.at(s); + } else { + port = output_to_engine_port.size(); + output_to_engine_port.insert({s, port}); } + info->connections.emplace_back(output_node->name(), output_node->id(), + edge->dst_input(), node_name, node_id, + edge->src_output(), /*input_edge=*/false, + port); } } - } + } // For each segment node in topological order. - // Construct the const nodes first - subgraph_node_ids.insert(subgraph_node_ids.begin(), const_node_ids.begin(), - const_node_ids.end()); + // Construct the const nodes first. + subgraph_node_ids.insert(subgraph_node_ids.begin(), + added_const_node_ids.begin(), + added_const_node_ids.end()); TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef( - g, graph_properties, subgraph_node_ids, &info->connections, - &info->segment_graph_def, &info->engine_name)); - info->engine_type = EngineInfo::EngineType::TRTStatic; - + g, graph_properties, subgraph_node_names, subgraph_node_ids, + &info->connections, &info->segment_graph_def, &info->engine_name)); // TODO(sami): This should not happen once segmenter is updated. if (segment_devices.size() == 1) { info->device = *segment_devices.begin(); @@ -502,36 +425,34 @@ tensorflow::Status GetEngineInfo( // Helper function to update edge connection from the removed node to the // engine node. If an outside node is gone, it must have been absorbed into // an engine node. Find the engine node. 
-void UpdateToEngineNode(tensorflow::Node*& node, string& node_name, int& port, - const std::vector<EngineInfo>& infos, - size_t my_engine_id, +void UpdateToEngineNode(const std::vector<EngineInfo>& infos, + const size_t my_engine_id, const std::vector<Node*>& engine_nodes, - bool update_input_edge) { - bool found_engine = false; + const bool is_input_edge, + const string& node_name, + tensorflow::Node** node, int* port) { for (size_t t = 0; t < infos.size(); ++t) { if (t == my_engine_id) { continue; } - auto& connected_eng_info = infos.at(t); - for (const auto& eng_conn : connected_eng_info.connections) { - if (update_input_edge && eng_conn.is_input_edge) { - continue; - } else if (!update_input_edge && !eng_conn.is_input_edge) { - continue; - } + const auto& info = infos.at(t); + for (const auto& eng_conn : info.connections) { + // If the connection being updated is an input connection, the source of + // the connection must be an output connection of another engine. And vise + // versa. + if (is_input_edge == eng_conn.is_input_edge) continue; if (eng_conn.inside_node_name == node_name && - eng_conn.inside_port == port) { - node = engine_nodes[t]; - node_name = connected_eng_info.engine_name; - port = eng_conn.port_number; - found_engine = true; - break; + eng_conn.inside_port == *port) { + *node = CHECK_NOTNULL(engine_nodes[t]); + QCHECK_EQ(info.engine_name, (**node).name()) + << "Engine name mismatch: " << info.engine_name << " vs " + << (**node).name(); + *port = eng_conn.port_number; + return; } } - if (found_engine) break; } - CHECK(found_engine); - CHECK(node != nullptr); + LOG(FATAL) << "Node " << (**node).name() << " not found in any engine."; } // Function to insert a TRT engine node into the graph. @@ -539,114 +460,91 @@ void UpdateToEngineNode(tensorflow::Node*& node, string& node_name, int& port, // 1. Each invocation of CreateTRTNode creates an engine node for infos[pos] // 2. When an engine node is created, add it into the graph with necessary // re-wiring. -// 2.1. If the outside connected node is existing, connect the engine -// node to it. -// 2.2. If the outside connected node is gone, it must have been absorted -// into another engine node (which was processed before the processing -// one). Connect to the pre-existing engine node instead. +// 2.1. If the outside connected node is existing, connect the engine +// node to it. +// 2.2. If the outside connected node is gone, it must have been absorted +// into another engine node (which was processed before the processing +// one). Connect to the pre-existing engine node instead. // 3. In this way, we ensure the graph is topologically sort-able after each // invocation of CreateTRTNode(). 
- -tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, - const std::vector<EngineInfo>& infos, int pos, - tensorflow::Allocator* alloc, - int max_batch_size, - std::vector<Node*>& engine_nodes) { - auto& info = infos.at(pos); +tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos, + int max_batch_size, tensorflow::Graph* graph, + nvinfer1::IGpuAllocator* alloc, + std::vector<Node*>* engine_nodes) { + const auto& info = infos.at(pos); + TRT_RETURN_IF_TEST_VALUE(StrCat(info.engine_name, ":CreateTRTNode"), "fail"); std::vector<tensorflow::TensorShapeProto> output_shape_protos; std::vector<tensorflow::TensorShapeProto> input_shape_protos; - std::vector<tensorflow::PartialTensorShape> shapes; + std::vector<tensorflow::PartialTensorShape> input_shapes; std::vector<tensorflow::NodeDefBuilder::NodeOut> inputs; std::vector<tensorflow::Node*> input_nodes; std::vector<tensorflow::Node*> control_input_nodes; - std::vector<string> control_input_names; + std::unordered_set<string> control_input_names; std::vector<tensorflow::DataType> out_types; VLOG(1) << "Processing " << info.engine_name; - - // -- Preprocessing -- // - // collect needed info for creating the engine node in the graph - for (const auto conn : info.connections) { - // control edges - if (conn.is_control_edge) { - // skip control outputs for now. control output info are not needed for + // Collect needed info for creating the engine node in the graph + for (const auto& conn : info.connections) { + // Control edges + if (conn.is_control_edge()) { + // Skip control outputs for now. control output info are not needed for // node creation and will be processed later. - if (!conn.is_input_edge) { - continue; - } + if (!conn.is_input_edge) continue; - // control inputs + // Rewrire control input if it's not found in original graph. tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); - string input_node_name = conn.outside_node_name; int port = tensorflow::Graph::kControlSlot; if (!input_node) { - UpdateToEngineNode(input_node, input_node_name, port, infos, pos, - engine_nodes, true); - } - bool new_input = true; - for (const auto& name : control_input_names) { - if (name == input_node_name) { - new_input = false; - break; - } + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true, + conn.outside_node_name, &input_node, &port); + QCHECK_EQ(Graph::kControlSlot, port); } - if (new_input) { - control_input_nodes.push_back(input_node); - control_input_names.push_back(input_node_name); - - VLOG(1) << "Engine Control Input " << input_node_name << ":" << port - << " -> " << info.engine_name << ":" - << tensorflow::Graph::kControlSlot; + if (!control_input_names.insert(input_node->name()).second) { + continue; } - - // data edges + control_input_nodes.push_back(input_node); + VLOG(1) << "Engine Control Input " << input_node->name() + << " -> " << info.engine_name; } else { - // data outputs + // Data edges if (!conn.is_input_edge) { + // Set the shapes and data types of output edge. 
tensorflow::TensorShapeProto out_shape; - conn.inside_shape.AsProto( - &out_shape); // shape of the output node inside segment + // shape of the output node inside segment + conn.inside_shape.AsProto(&out_shape); if (output_shape_protos.size() <= conn.port_number) { output_shape_protos.resize(conn.port_number + 1); out_types.resize(conn.port_number + 1); } output_shape_protos.at(conn.port_number) = out_shape; out_types.at(conn.port_number) = conn.connection_type; - - // data input } else { + // Set the shapes and data types of input edge. tensorflow::TensorShapeProto in_shape; conn.outside_shape.AsProto(&in_shape); - if (input_shape_protos.size() <= conn.port_number) { input_shape_protos.resize(conn.port_number + 1); - shapes.resize(conn.port_number + 1); + input_shapes.resize(conn.port_number + 1); } input_shape_protos.at(conn.port_number) = in_shape; - shapes.at(conn.port_number) = conn.outside_shape; + input_shapes.at(conn.port_number) = conn.outside_shape; + // Rewrire data input if it's not found in original graph. tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); - string input_node_name = conn.outside_node_name; - int input_port = conn.outside_port; - auto dtype = conn.connection_type; - + int port = conn.outside_port; if (!input_node) { - UpdateToEngineNode(input_node, input_node_name, input_port, infos, - pos, engine_nodes, true); - } - bool new_input = true; - for (const auto& inp : inputs) { - if (inp.node == input_node_name && inp.index == input_port) { - new_input = false; - break; - } + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true, + conn.outside_node_name, &input_node, &port); } - if (new_input) { - inputs.emplace_back(input_node_name, input_port, dtype); - CHECK(input_node != nullptr); - input_nodes.push_back(input_node); - - VLOG(1) << "Engine Input " << input_node_name << ":" << input_port + if (std::find_if(std::begin(inputs), std::end(inputs), + [input_node, &port]( + const NodeDefBuilder::NodeOut& inp) { + return inp.node == input_node->name() && + inp.index == port; + }) == std::end(inputs)) { + inputs.emplace_back(input_node->name(), port, conn.connection_type); + input_nodes.push_back(CHECK_NOTNULL(input_node)); + VLOG(1) << "Engine Input " << input_node->name() << ":" << port << " -> " << info.engine_name << ":" << inputs.size() - 1; } } @@ -662,14 +560,12 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, // Otherwise we skip node creation for this engine. Logger trt_logger; TrtUniquePtrType<nvinfer1::ICudaEngine> engine; - std::unique_ptr<TRTDeviceAllocator> allocator( - new TRTDeviceAllocator(alloc)); // TODO(sami): What happens if 1st dim is not batch? TF_RETURN_IF_ERROR(ConvertGraphDefToEngine( info.segment_graph_def, info.precision_mode == INT8MODE ? FP32MODE : info.precision_mode, - max_batch_size, info.max_workspace_size_bytes, shapes, &trt_logger, - allocator.get(), /*calibrator=*/nullptr, &engine, + max_batch_size, info.max_workspace_size_bytes, input_shapes, + &trt_logger, alloc, /*calibrator=*/nullptr, &engine, /*convert_successfully=*/nullptr)); TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize()); segment_string = @@ -711,7 +607,7 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, VLOG(1) << ins; } node_builder.Input(inputs); - for (auto& c : control_input_names) { + for (const string& c : control_input_names) { node_builder.ControlInput(c); } @@ -744,54 +640,50 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, // Up until this point, graph is not modified. 
If we return !status.ok() from // here, this segment will be skipped tensorflow::Node* engine_node = graph->AddNode(trt_node, &status); - engine_nodes[pos] = engine_node; + (*engine_nodes)[pos] = engine_node; if (!status.ok()) { LOG(ERROR) << "Adding node failed " << status; return status; } - // input edges of the engine node - for (auto in : control_input_nodes) { + // Add control input and input edges to the engine node. + for (const auto in : control_input_nodes) { VLOG(1) << "Connecting control edge from " << in->name() << " to " << engine_node->name(); graph->AddControlEdge(in, engine_node); } - int idx = 0; VLOG(1) << "input_nodes size = " << input_nodes.size(); - for (auto in : inputs) { - Node* n = input_nodes[idx]; - CHECK(n != nullptr); + for (int i = 0; i < input_nodes.size(); ++i) { + Node* n = input_nodes[i]; + const auto& in = inputs[i]; + CHECK_NOTNULL(n); VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index - << " to " << engine_node->name() << ":" << idx; - graph->AddEdge(n, in.index, engine_node, idx++); + << " to " << engine_node->name() << ":" << i; + graph->AddEdge(n, in.index, engine_node, i); } + // Updates the inputs of output edges destination nodes, and point them to the // engine node. - for (auto& conn : info.connections) { if (conn.is_input_edge) { continue; } - - string out_name = conn.outside_node_name; - auto out_node = graph->FindNodeId(conn.outside_id); - int out_port = conn.outside_port; - - if (!out_node) { - UpdateToEngineNode(out_node, out_name, out_port, infos, pos, engine_nodes, - false); + tensorflow::Node* output_node = graph->FindNodeId(conn.outside_id); + int port = conn.outside_port; + if (!output_node) { + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false, + conn.outside_node_name, &output_node, &port); } - VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number - << " to " << out_node->name() << ":" << out_port; - - if (conn.is_control_edge) { - graph->AddControlEdge(engine_node, out_node); + << " to " << output_node->name() << ":" << port; + if (conn.is_control_edge()) { + QCHECK_EQ(Graph::kControlSlot, port); + graph->AddControlEdge(engine_node, output_node); } else { auto new_edge = - graph->AddEdge(engine_node, conn.port_number, out_node, out_port); - CHECK(new_edge) << "Adding a new edge failed " << engine_node->name() - << ":" << conn.port_number << " -> " << out_node->name() - << ":" << conn.outside_port; + graph->AddEdge(engine_node, conn.port_number, output_node, port); + QCHECK(new_edge) << "Adding a new edge failed " << engine_node->name() + << ":" << conn.port_number << " -> " + << output_node->name() << ":" << conn.outside_port; } } return status; @@ -1077,19 +969,21 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { LOG(WARNING) << "Can't identify the cuda device. Running on device 0 "; } cudaSetDevice(cuda_device_id); - auto status = CreateTRTNode(&graph, engine_segments, i, device_alloc.second, - params.max_batch_size, engine_nodes); + auto status = CreateTRTNode(engine_segments, i, params.max_batch_size, + &graph, alloc.get(), &engine_nodes); // If status is ok, we successfully added the node to the graph and can // remove segment ops. Otherwise graph is not modified. 
+ const string msg = StrCat( + "Engine ", engine.engine_name, " creation for segment ", i, + ", composed of ", converted_segments.at(i).first.size(), " nodes"); if (status.ok()) { + LOG(INFO) << msg << " succeeded."; for (auto node_name : converted_segments.at(i).first) { graph.RemoveNode(node_map.at(node_name)); } } else { // Graph is not modified. - LOG(WARNING) << "Engine creation for segment " << i << ", composed of " - << converted_segments.at(i).first.size() - << " nodes failed: " << status << ". Skipping..."; + LOG(WARNING) << msg << " failed: " << status << ". Skipping..."; } } cudaSetDevice(old_cuda_device); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 451d6fe698..3b0ac43061 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -22,6 +22,7 @@ limitations under the License. #include <memory> #include <set> #include <unordered_map> +#include <unordered_set> #include <utility> #include <vector> @@ -2788,6 +2789,7 @@ tensorflow::Status ConvertGraphDefToEngine( tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, + const std::set<string>& subgraph_node_names, const std::vector<int>& subgraph_node_ids, // In topological order std::vector<EngineConnection>* connections, tensorflow::GraphDef* segment_def, string* common_scope) { @@ -2796,6 +2798,7 @@ tensorflow::Status ConvertSegmentToGraphDef( // nodes in the segment graphdef. for (size_t i = 0; i < connections->size(); ++i) { auto& connection = connections->at(i); + if (connection.is_control_edge()) continue; auto outside_node = graph->FindNodeId(connection.outside_id); if (!outside_node) { // This should never happen, unless the original graph is problematic. @@ -2809,13 +2812,13 @@ tensorflow::Status ConvertSegmentToGraphDef( GetInputProperties(graph_properties, graph->FindNodeId(connection.outside_id), connection.outside_port, &partial_shape, &dtype); - + connection.outside_shape = partial_shape; } else { GetOutputProperties(graph_properties, graph->FindNodeId(connection.outside_id), connection.outside_port, &partial_shape, &dtype); + connection.inside_shape = partial_shape; } - connection.outside_shape = partial_shape; connection.connection_type = dtype; // Add dummy input/output nodes to the segment graphdef. @@ -2873,7 +2876,7 @@ tensorflow::Status ConvertSegmentToGraphDef( // Update the inputs of the new input nodes to point to placeholder nodes. for (int i = 0; i < connections->size(); ++i) { auto& connection = connections->at(i); - if (!connection.is_input_edge) continue; + if (connection.is_control_edge() || !connection.is_input_edge) continue; auto snode = segment_def->mutable_node(old_to_new_id_map[connection.inside_id]); const string placeholder_name = @@ -2883,6 +2886,38 @@ tensorflow::Status ConvertSegmentToGraphDef( << placeholder_name; snode->set_input(connection.inside_port, placeholder_name); } + // Remove control inputs that are not inside the segment. 
+ for (int i = 0; i < segment_def->node_size(); ++i) { + auto snode = segment_def->mutable_node(i); + const int input_size = snode->input_size(); + int input_idx = 0; + int actual_input_idx = 0; + while (input_idx < input_size) { + TensorId input = ParseTensorName(snode->input(input_idx)); + if (!subgraph_node_names.count( + string(input.first.data(), input.first.size())) && + !str_util::StartsWith(input.first, kInputPHName)) { + if (input.second == Graph::kControlSlot) { + VLOG(2) << "... removing control inputs " << input.first + << " from subgraph."; + ++input_idx; + continue; + } else { + return tensorflow::errors::InvalidArgument( + "Found non control input outside the segment that is not an " + "engine connection to ", snode->name(), ": ", input.first); + } + } + if (actual_input_idx != input_idx) { + snode->set_input(actual_input_idx, snode->input(input_idx)); + } + ++input_idx; + ++actual_input_idx; + } + for (int remove = input_size - actual_input_idx; remove > 0; --remove) { + snode->mutable_input()->RemoveLast(); + } + } *common_scope = local_scope; VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph"; return tensorflow::Status::OK(); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index d41a886b30..328efbf50c 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -36,8 +36,8 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -static const char* kInputPHName = "InputPH_"; -static const char* kOutputPHName = "OutputPH_"; +static const char* kInputPHName = "TensorRTInputPH_"; +static const char* kOutputPHName = "TensorRTOutputPH_"; namespace convert { // TODO(aaroey): use an enum instead. @@ -46,9 +46,10 @@ const int FP16MODE = 1; const int INT8MODE = 2; struct EngineConnection { + // Constructs a non-control edge. EngineConnection(const string& outside, int out_id, int out_port, const string& inside, int in_id, int in_port, - bool input_edge, int port, bool control_edge = false) + bool input_edge, int port) : outside_node_name(outside), outside_id(out_id), outside_port(out_port), @@ -56,24 +57,39 @@ struct EngineConnection { inside_id(in_id), inside_port(in_port), is_input_edge(input_edge), - is_control_edge(control_edge), port_number(port) {} + // Constructs a control edge. + EngineConnection(const string& outside, int out_id, const string& inside, + int in_id, bool input_edge) + : outside_node_name(outside), + outside_id(out_id), + outside_port(Graph::kControlSlot), + inside_node_name(inside), + inside_id(in_id), + inside_port(Graph::kControlSlot), + is_input_edge(input_edge), + port_number(Graph::kControlSlot) {} + + bool is_control_edge() const { + return port_number == Graph::kControlSlot; + } + const string outside_node_name; const int outside_id; const int outside_port; - tensorflow::PartialTensorShape outside_shape; + tensorflow::PartialTensorShape outside_shape; // Only set for input edge. const string inside_node_name; const int inside_id; const int inside_port; - tensorflow::PartialTensorShape inside_shape; + tensorflow::PartialTensorShape inside_shape; // Only set for output edge. tensorflow::DataType connection_type; - bool is_input_edge; - bool is_control_edge; - // The port number of the TRT node connecting to this edge. - int port_number; + const bool is_input_edge; + + // The port number of the TRT node connected with this edge. 
+ const int port_number; }; struct EngineInfo { @@ -86,7 +102,9 @@ struct EngineInfo { string device; tensorflow::GraphDef segment_graph_def; - // The segment nodes that are on one side of the edges are topological sorted. + // Non-control input connections inside this vector are sorted in a way such + // that, the segment nodes connecting to them are topological sorted. + // In addition, for non-control connections, there must be no duplicates. std::vector<EngineConnection> connections; enum class EngineType { TRTStatic = 0, TRTDynamic = 1 }; @@ -102,6 +120,7 @@ struct EngineInfo { // (OutputPH_*). This function needs to be called before TensorRT nodes // inserted in order to correctly get sizes from the original graph. // +// - subgraph_node_names: the node names of the subgraph. // - subgraph_node_ids: the node ids of the subgraph, must be sorted in // topological order. // - segment_def: the output GraphDef, whose non-input/output nodedefs will be @@ -111,6 +130,7 @@ struct EngineInfo { tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, + const std::set<string>& subgraph_node_names, const std::vector<int>& subgraph_node_ids, std::vector<EngineConnection>* connections, tensorflow::GraphDef* segment_def, string* common_scope); diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 6699b71d28..a19cd24c94 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -179,7 +180,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, helper->Ref(); // Increment count for calculating native graph VLOG(1) << "Executing native segment " << name(); lib->Run(opts, native_func_, inputs, outputs, - [ctx, outputs, helper](const tensorflow::Status& s) { + [this, ctx, outputs, helper](const tensorflow::Status& s) { tensorflow::core::ScopedUnref sc(helper); VLOG(1) << "Native Segment completed"; if (!s.ok()) { @@ -189,6 +190,8 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, for (size_t t = 0; t < outputs->size(); ++t) { ctx->set_output(t, outputs->at(t)); } + test::AddTestValue(StrCat(this->name(), ":ExecuteNativeSegment"), + "done"); delete outputs; }); } @@ -234,6 +237,7 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, ->implementation() ->GpuStreamMemberHack())); calib_res->calibrator_->setBatch(input_data, *stream); + test::AddTestValue(StrCat(name(), ":ExecuteCalibration"), "done"); VLOG(2) << "Passed calibration data"; ExecuteNativeSegment(ctx, helper); } @@ -258,7 +262,7 @@ int TRTEngineOp::GetEngineBatch(OpKernelContext* ctx) { StrCat("Engine buffer is full. 
buffer limit=", max_cached_engines_, ", current entries="); for (auto i : cached_engine_batches_) StrAppend(&msg, i, ","); - StrAppend(&msg, "Requested batch=", num_batch); + StrAppend(&msg, " requested batch=", num_batch); LOG(WARNING) << msg; return -1; } @@ -276,7 +280,8 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, } const int smallest_engine = GetEngineBatch(ctx); if (smallest_engine < 0) { - LOG(WARNING) << "Failed to get engine batch, running native segment"; + LOG(WARNING) << "Failed to get engine batch, running native segment for " + << name(); ExecuteNativeSegment(ctx, helper); return; } @@ -286,14 +291,15 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, auto& trt_engine_ptr = engine_ctx_pair.first; if (!trt_engine_ptr) { LOG(WARNING) << "Engine retrieval for batch size " << num_batch - << " failed. Running native segment"; + << " failed. Running native segment for " << name(); ExecuteNativeSegment(ctx, helper); return; } const bool retry = ExecuteTrtEngine(ctx, num_batch, trt_engine_ptr.get(), engine_ctx_pair.second.get()); if (retry) { - LOG(WARNING) << "Failed to execute engine, retrying with native segment"; + LOG(WARNING) << "Failed to execute engine, " + << "retrying with native segment for " << name(); ExecuteNativeSegment(ctx, helper); return; } @@ -412,6 +418,7 @@ bool TRTEngineOp::ExecuteTrtEngine( LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name(); return kRetry; } + test::AddTestValue(StrCat(name(), ":ExecuteTrtEngine"), "done"); // Synchronization will be done by TF. return !kRetry; } diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py index fe4fa166a1..7cdfe2b1a6 100644 --- a/tensorflow/contrib/tensorrt/python/__init__.py +++ b/tensorflow/contrib/tensorrt/python/__init__.py @@ -20,7 +20,11 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +from tensorflow.contrib.tensorrt.python.trt_convert import add_test_value from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph +from tensorflow.contrib.tensorrt.python.trt_convert import clear_test_values from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph +from tensorflow.contrib.tensorrt.python.trt_convert import enable_test_value +from tensorflow.contrib.tensorrt.python.trt_convert import get_test_value from tensorflow.contrib.tensorrt.python.trt_convert import is_tensorrt_enabled # pylint: enable=unused-import,line-too-long diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 2b67931661..5c1f4a466e 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -20,9 +20,13 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long import six as _six +from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert +from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values +from tensorflow.contrib.tensorrt.wrap_conversion import enable_test_value from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version +from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value from tensorflow.contrib.tensorrt.wrap_conversion 
import is_tensorrt_enabled from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert from tensorflow.core.framework import graph_pb2 diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py index edd30ad7a9..9d14e635f4 100644 --- a/tensorflow/contrib/tensorrt/test/base_test.py +++ b/tensorflow/contrib/tensorrt/test/base_test.py @@ -20,17 +20,19 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.tensorrt.python import trt_convert from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test -class SimpleSingleEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase): +class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase): def GetParams(self): """Create a graph containing single segment.""" @@ -65,13 +67,17 @@ class SimpleSingleEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=1, + # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which + # breaks the connection check, fix it. + # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add", + # "relu", "identity", "max_pool"] + expected_engines=["my_trt_op_0"], expected_output_dims=(100, 6, 6, 6), allclose_atol=1.e-03, allclose_rtol=1.e-03) -class SimpleMultiEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase): +class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): def GetParams(self): """Create a graph containing multiple segment.""" @@ -95,32 +101,138 @@ class SimpleMultiEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase): padding="SAME", name="conv") c1 = constant_op.constant( - np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype) - p = conv * c1 + np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c1") + p = math_ops.mul(conv, c1, name="mul") c2 = constant_op.constant( - np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype) - q = conv / c2 + np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c2") + q = math_ops.div(conv, c2, name="div") - edge = self.trt_incompatible_op(q) - edge /= edge - r = edge + edge + edge = self.trt_incompatible_op(q, name="incompatible") + edge = math_ops.div(edge, edge, name="div1") + r = math_ops.add(edge, edge, name="add") - p -= edge - q *= edge - s = p + q - s -= r + p = math_ops.sub(p, edge, name="sub") + q = math_ops.mul(q, edge, name="mul1") + s = math_ops.add(p, q, name="add1") + s = math_ops.sub(s, r, name="sub1") array_ops.squeeze(s, name=self.output_name) return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=2, + # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which + # breaks the connection check, fix it. + # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1", "add", + # "sub1"]; + # - my_trt_op_1 should have ["weights","conv", "div"] + expected_engines=["my_trt_op_0", "my_trt_op_1"], expected_output_dims=(100, 12, 12, 6), allclose_atol=1.e-03, allclose_rtol=1.e-03) -# TODO(aaroey): add a large complex graph to test. 
+class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): + + def setUp(self): + """Setup method.""" + super(PartiallyConvertedTestA, self).setUp() + # Let it fail to build the first engine. + trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail") + + def GetParams(self): + """Create a graph containing two segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + for i in range(2): + c = constant_op.constant(1.0, name="c%d" % i) + n = math_ops.add(n, c, name="add%d" % i) + n = math_ops.mul(n, n, name="mul%d" % i) + edge = self.trt_incompatible_op(n, name="incompatible") + with g.control_dependencies([edge]): + c = constant_op.constant(1.0, name="c2") + n = math_ops.add(n, c, name="add2") + n = math_ops.mul(n, n, name="mul2") + c = constant_op.constant(1.0, name="c3") + n = math_ops.add(n, c, name="add3") + n = math_ops.mul(n, n, name="mul3") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + # Only the second engine is built. + "my_trt_op_1": ["c0", "c1", "add0", "add1", "mul0", "mul1"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class PartiallyConvertedTestB(PartiallyConvertedTestA): + + def setUp(self): + """Setup method.""" + super(PartiallyConvertedTestB, self).setUp() + # Let it fail to build the second engine. + trt_convert.clear_test_values("") + trt_convert.add_test_value("my_trt_op_1:CreateTRTNode", "fail") + + def GetParams(self): + """Create a graph containing two segment.""" + return super(PartiallyConvertedTestB, self).GetParams()._replace( + expected_engines={ + # Only the first engine is built. + "my_trt_op_0": ["c2", "c3", "add2", "add3", "mul2", "mul3"] + }) + + +class ConstInputTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing multiple segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + c = constant_op.constant(1.0, name="c") + # Adds control dependency from the constant op to a trt incompatible op, + # and adds control dependency from the trt incompatible op to all other + # ops, to make sure the constant op cannot be contracted with any trt + # segment that depends on it. 
+ with g.control_dependencies([c]): + d = self.trt_incompatible_op(n, name="incompatible") + with g.control_dependencies([d]): + n = math_ops.add(n, c, name="add") + n = math_ops.mul(n, n, name="mul") + n = math_ops.add(n, n, name="add1") + n = self.trt_incompatible_op(n, name="incompatible1") + with g.control_dependencies([d]): + n = math_ops.add(n, c, name="add2") + n = math_ops.mul(n, n, name="mul1") + n = math_ops.add(n, n, name="add3") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + "my_trt_op_0": ["add2", "add3", "mul1"], + "my_trt_op_1": ["add", "add1", "mul"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py index 730b6843fb..2e1107e303 100644 --- a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py @@ -66,7 +66,7 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name, w1_name, w2_name], input_dims=[input_dims, w1_dims, w2_dims], - num_expected_engines=1, + expected_engines=["my_trt_op_0"], expected_output_dims=(12, 5, 8, 7), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py index 0c03a10b64..8be32f59b4 100644 --- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py @@ -102,7 +102,10 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=7, + expected_engines=[ + "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3", + "my_trt_op_4", "my_trt_op_5", "my_trt_op_6" + ], expected_output_dims=(48, 89), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py index dd673463a5..9316b14da0 100644 --- a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py +++ b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py @@ -109,7 +109,24 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=16, + expected_engines=[ + "my_trt_op_0", + "my_trt_op_1", + "my_trt_op_2", + "my_trt_op_3", + "my_trt_op_4", + "my_trt_op_5", + "my_trt_op_6", + "my_trt_op_7", + "my_trt_op_8", + "my_trt_op_9", + "my_trt_op_10", + "my_trt_op_11", + "my_trt_op_12", + "my_trt_op_13", + "my_trt_op_14", + "my_trt_op_15", + ], expected_output_dims=(5, 23040), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/contrib/tensorrt/test/concatenation_test.py index 8c51c45b0a..1874b9dd45 100644 --- a/tensorflow/contrib/tensorrt/test/concatenation_test.py +++ b/tensorflow/contrib/tensorrt/test/concatenation_test.py @@ -73,7 +73,7 @@ class ConcatenationTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=1, + 
expected_engines=["my_trt_op_0"], expected_output_dims=(2, 126), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py index 97b29bf05d..8c59000b70 100644 --- a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py +++ b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py @@ -58,7 +58,7 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=1, + expected_engines=['my_trt_op_0'], expected_output_dims=(5, 12, 12, 1), allclose_atol=1.e-02, allclose_rtol=1.e-02) diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py index 734ccf6345..fd55b8cd99 100644 --- a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py +++ b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py @@ -77,7 +77,7 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=2, + expected_engines=["my_trt_op_0", "my_trt_op_1"], expected_output_dims=(2, 4, 5, 4), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py index 50265c0845..97e0d23b18 100644 --- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py +++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py @@ -59,7 +59,7 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=2, + expected_engines=["my_trt_op_0", "my_trt_op_1"], expected_output_dims=(2, 4, 5, 4), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py index bb7f5a77f0..5968af28ae 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py @@ -30,6 +30,7 @@ from tensorflow.contrib.tensorrt.python.ops import trt_engine_op # pylint: enable=unused-import from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.framework import graph_io from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -37,10 +38,14 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging as logging TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [ - "gdef", "input_names", "input_dims", "num_expected_engines", + "gdef", "input_names", "input_dims", "expected_engines", "expected_output_dims", "allclose_atol", "allclose_rtol" ]) +RunParams = namedtuple( + "RunParams", + ["use_optimizer", "precision_mode", "dynamic_engine", "test_name"]) + PRECISION_MODES = ["FP32", "FP16", "INT8"] @@ -48,6 +53,12 @@ def _IsQuantizationMode(mode): return mode == "INT8" +class GraphState: + ORIGINAL = 0 + CALIBRATE = 1 + INFERENCE = 2 + + class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): """Class to test Tensorflow-TensorRT integration.""" @@ -63,34 +74,79 @@ class 
TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def precision_modes(self): return ["FP32", "FP16", "INT8"] + # str is bytes in py2, but unicode in py3. + def _ToUnicode(self, s): + if six.PY2: + if isinstance(s, unicode): + return s + return s.decode("utf-8") + else: + if isinstance(s, str): + return s + return s.decode("utf-8") + def _ToBytes(self, s): if six.PY2: + if isinstance(s, unicode): + return s.encode("utf-8") return s else: - return s.encode("utf-8") + if isinstance(s, str): + return s.encode("utf-8") + return s def _ToString(self, s): if six.PY2: + if isinstance(s, unicode): + return s.encode("utf-8") return s else: + if isinstance(s, str): + return s return s.decode("utf-8") + @classmethod + def setUpClass(cls): + """Setup method for the module.""" + super(TfTrtIntegrationTestBase, cls).setUpClass() + trt_convert.enable_test_value() + def setUp(self): """Setup method.""" super(TfTrtIntegrationTestBase, self).setUp() warnings.simplefilter("always") + trt_convert.clear_test_values("") def GetParams(self): """Return a TfTrtIntegrationTestParams for test, implemented by subclass.""" raise NotImplementedError() - def _GetConfigProto(self, - params, - use_optimizer, - precision_mode=None, - is_dynamic_op=None): + def _PrepareRun(self, params, graph_state): + """Set up necessary testing environment before calling sess.run().""" + # Clear test values added by TRTEngineOp. + trt_convert.clear_test_values("my_trt_op_.*:ExecuteTrtEngine") + trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration") + trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment") + + def _VerifyRun(self, params, graph_state): + """Verify the state after sess.run().""" + for engine_name in params.expected_engines: + if graph_state == GraphState.ORIGINAL: + self._ExpectCalibration(engine_name, "") + self._ExpectNativeSegment(engine_name, "") + self._ExpectTrtEngine(engine_name, "") + elif graph_state == GraphState.CALIBRATE: + self._ExpectCalibration(engine_name, "done") + self._ExpectNativeSegment(engine_name, "done") + self._ExpectTrtEngine(engine_name, "") + elif graph_state == GraphState.INFERENCE: + self._ExpectCalibration(engine_name, "") + self._ExpectNativeSegment(engine_name, "") + self._ExpectTrtEngine(engine_name, "done") + + def _GetConfigProto(self, params, run_params, graph_state): """Get config proto based on specific settings.""" - if use_optimizer: + if graph_state != GraphState.ORIGINAL and run_params.use_optimizer: rewriter_cfg = rewriter_config_pb2.RewriterConfig() rewriter_cfg.optimizers.extend(["constfold", "layout"]) custom_op = rewriter_cfg.custom_optimizers.add() @@ -98,14 +154,31 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): custom_op.parameter_map["minimum_segment_size"].i = 3 custom_op.parameter_map["max_batch_size"].i = max( [dims[0] for dims in params.input_dims]) - custom_op.parameter_map["is_dynamic_op"].b = is_dynamic_op + custom_op.parameter_map["is_dynamic_op"].b = run_params.dynamic_engine custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25 custom_op.parameter_map["precision_mode"].s = self._ToBytes( - precision_mode) + run_params.precision_mode) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) else: graph_options = config_pb2.GraphOptions() + # Disable all other optimizations which can affect the converted graph. 
+ off = rewriter_config_pb2.RewriterConfig.OFF + graph_options.optimizer_options.opt_level = config_pb2.OptimizerOptions.L0 + graph_options.rewrite_options.layout_optimizer = off + graph_options.rewrite_options.constant_folding = off + graph_options.rewrite_options.shape_optimization = off + graph_options.rewrite_options.remapping = off + graph_options.rewrite_options.arithmetic_optimization = off + graph_options.rewrite_options.dependency_optimization = off + graph_options.rewrite_options.loop_optimization = off + graph_options.rewrite_options.function_optimization = off + graph_options.rewrite_options.debug_stripper = off + graph_options.rewrite_options.disable_model_pruning = True + graph_options.rewrite_options.scoped_allocator_optimization = off + graph_options.rewrite_options.memory_optimization = ( + rewriter_config_pb2.RewriterConfig.NO_MEM_OPT) + gpu_options = config_pb2.GPUOptions() gpu_options.allow_growth = True if trt_convert.get_linked_tensorrt_version()[0] == 3: @@ -115,7 +188,21 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): gpu_options=gpu_options, graph_options=graph_options) return config - def _RunGraph(self, params, gdef, input_data, config, num_runs=2): + def _ExpectTestValue(self, engine_name, method, value): + self.assertEqual( + value, trt_convert.get_test_value("%s:%s" % (engine_name, method))) + + def _ExpectCalibration(self, engine_name, value): + self._ExpectTestValue(engine_name, "ExecuteCalibration", value) + + def _ExpectTrtEngine(self, engine_name, value): + self._ExpectTestValue(engine_name, "ExecuteTrtEngine", value) + + def _ExpectNativeSegment(self, engine_name, value): + self._ExpectTestValue(engine_name, "ExecuteNativeSegment", value) + + def _RunGraph(self, params, gdef, input_data, config, graph_state, + num_runs=2): """Run given graphdef multiple times.""" assert len(params.input_names) == len(input_data) g = ops.Graph() @@ -132,93 +219,166 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): val = None # Defaults to 2 runs to verify result across multiple runs is same. for _ in range(num_runs): + self._PrepareRun(params, graph_state) new_val = sess.run(out, {inp[i]: input_data[i] for i in range(len(inp))}) self.assertEqual(params.expected_output_dims, new_val.shape) if val is not None: self.assertAllEqual(val, new_val) val = new_val + self._VerifyRun(params, graph_state) return val # Use real data that is representative of the inference dataset # for calibration. For this test script it is random data. 
def _RunCalibration(self, params, gdef, input_data, config): """Run calibration on given graph.""" - return self._RunGraph(params, gdef, input_data, config, 30) + return self._RunGraph( + params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5) - def _GetTrtGraphDef(self, params, gdef, precision_mode, is_dynamic_op): + def _GetTrtGraphDef(self, params, run_params, gdef): """Return trt converted graphdef.""" return trt_convert.create_inference_graph( input_graph_def=gdef, outputs=[self.output_name], max_batch_size=max([dims[0] for dims in params.input_dims]), max_workspace_size_bytes=1 << 25, - precision_mode=precision_mode, + precision_mode=run_params.precision_mode, minimum_segment_size=2, - is_dynamic_op=is_dynamic_op) - - def _VerifyGraphDef(self, - params, - gdef, - precision_mode=None, - is_calibrated=None, - dynamic_engine=None): + is_dynamic_op=run_params.dynamic_engine) + + def _WriteGraph(self, params, run_params, gdef, graph_state): + if graph_state == GraphState.ORIGINAL: + label = "Original" + elif graph_state == GraphState.CALIBRATE: + label = "CalibEngine" + elif graph_state == GraphState.INFERENCE: + label = "InferEngine" + graph_name = ( + self.__class__.__name__ + "_" + run_params.test_name + "_" + label + + ".pbtxt") + logging.info("Writing graph to %s/%s", self.get_temp_dir(), graph_name) + graph_io.write_graph(gdef, self.get_temp_dir(), graph_name) + + def _VerifyConnections(self, params, converted_gdef): + old_to_new_node_map = { + self._ToString(n.name): self._ToString(n.name) for n in params.gdef.node + } + for engine_name, node_names in params.expected_engines.iteritems(): + for n in node_names: + old_to_new_node_map[n] = engine_name + name_to_node_map = {self._ToString(n.name): n for n in params.gdef.node} + + def input_name(inp): + inp = self._ToString(inp) + prefix = "" + if inp[0] == "^": + prefix = "^" + inp = inp[1:] + parts = inp.split(":") + if len(parts) > 1 and parts[-1].isdigit(): + inp = inp[:-len(parts[-1]) - 1] + return (prefix, inp) + + expected_input_map = {} + for n in params.gdef.node: + name_str = self._ToString(n.name) + target_node_name = old_to_new_node_map[name_str] + is_engine_op = (target_node_name != name_str) + if target_node_name not in expected_input_map: + expected_input_map[target_node_name] = set() + input_set = expected_input_map[target_node_name] + for inp in n.input: + (prefix, inp_name) = input_name(inp) + # Add the input only if it's outside the segment (note that it could be + # in a different engine). + if (not is_engine_op or + old_to_new_node_map[inp_name] != target_node_name): + if is_engine_op and name_to_node_map[inp_name].op == "Const": + # Const data input nodes to the segment has been copied to the + # segment graphdef and the engine, and the dependency has been + # converted to control dependendy. 
+ input_set.add("^" + old_to_new_node_map[inp_name]) + else: + input_set.add(prefix + old_to_new_node_map[inp_name]) + + actual_input_map = {} + for n in converted_gdef.node: + name_str = self._ToString(n.name) + actual_input_map[name_str] = set() + input_set = actual_input_map[name_str] + for inp in n.input: + (prefix, node_name) = input_name(inp) + input_set.add(prefix + node_name) + + self.assertEqual( + expected_input_map, + actual_input_map, + msg="expected:\n%s\nvs actual:\n%s" % (expected_input_map, + actual_input_map)) + + def _VerifyGraphDef(self, params, run_params, gdef, graph_state): + self._WriteGraph(params, run_params, gdef, graph_state) + num_engines = 0 for n in gdef.node: - # TODO(jie): we should have coverage for failed conversion (TF fallback). - # where the conversion will fail and we shouldn't count this engine as the - # converted engines. if n.op == "TRTEngineOp": num_engines += 1 - self.assertNotEqual(self._ToBytes(""), n.attr["serialized_segment"].s) - self.assertNotEqual(self._ToBytes(""), n.attr["segment_funcdef_name"].s) + self.assertTrue(n.name in params.expected_engines) + self.assertTrue(len(n.attr["serialized_segment"].s)) + self.assertTrue(len(n.attr["segment_funcdef_name"].s)) self.assertEqual( - self._ToBytes(precision_mode), n.attr["precision_mode"].s) - self.assertEqual(not dynamic_engine, n.attr["static_engine"].b) - if _IsQuantizationMode(precision_mode) and is_calibrated: - self.assertNotEqual(self._ToBytes(""), n.attr["calibration_data"].s) + self._ToBytes(run_params.precision_mode), + n.attr["precision_mode"].s) + + is_dynamic_engine = not n.attr["static_engine"].b + self.assertEqual(run_params.dynamic_engine, is_dynamic_engine) + + has_calibration_data = len(n.attr["calibration_data"].s) + if (_IsQuantizationMode(run_params.precision_mode) and + graph_state == GraphState.INFERENCE): + self.assertTrue(has_calibration_data) else: - self.assertEqual(self._ToBytes(""), n.attr["calibration_data"].s) - if precision_mode is None: # This means gdef is the original GraphDef. + self.assertFalse(has_calibration_data) + if graph_state == GraphState.ORIGINAL: self.assertEqual(0, num_engines) else: - self.assertEqual(num_engines, params.num_expected_engines) + self.assertEqual(num_engines, len(params.expected_engines)) + if isinstance(params.expected_engines, dict): + self._VerifyConnections(params, gdef) + # TODO(aaroey): consider verifying the corresponding TF function. - def RunTest(self, params, use_optimizer, precision_mode, - dynamic_infer_engine, dynamic_calib_engine): - assert precision_mode in PRECISION_MODES + def RunTest(self, params, run_params): + assert run_params.precision_mode in PRECISION_MODES input_data = [np.random.random_sample(dims) for dims in params.input_dims] input_gdef = params.gdef - self._VerifyGraphDef(params, input_gdef) + self._VerifyGraphDef(params, run_params, input_gdef, GraphState.ORIGINAL) # Get reference result without running trt. - config_no_trt = self._GetConfigProto(params, False) + config_no_trt = self._GetConfigProto(params, run_params, + GraphState.ORIGINAL) logging.info("Running original graph w/o trt, config:\n%s", str(config_no_trt)) - ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt) + ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt, + GraphState.ORIGINAL) # Run calibration if necessary. 
- if _IsQuantizationMode(precision_mode): + if _IsQuantizationMode(run_params.precision_mode): - calib_config = self._GetConfigProto(params, use_optimizer, precision_mode, - dynamic_calib_engine) + calib_config = self._GetConfigProto(params, run_params, + GraphState.CALIBRATE) logging.info("Running calibration graph, config:\n%s", str(calib_config)) - if use_optimizer: - self.assertTrue(False) - # TODO(aaroey): uncomment this and get infer_gdef when this mode is - # supported. - # result = self._RunCalibration(params, input_gdef, input_data, - # calib_config) + if run_params.use_optimizer: + result = self._RunCalibration(params, input_gdef, input_data, + calib_config) else: - calib_gdef = self._GetTrtGraphDef(params, input_gdef, precision_mode, - dynamic_calib_engine) - self._VerifyGraphDef(params, calib_gdef, precision_mode, False, - dynamic_calib_engine) + calib_gdef = self._GetTrtGraphDef(params, run_params, input_gdef) + self._VerifyGraphDef(params, run_params, calib_gdef, + GraphState.CALIBRATE) result = self._RunCalibration(params, calib_gdef, input_data, calib_config) - infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef) - self._VerifyGraphDef(params, infer_gdef, precision_mode, True, - dynamic_calib_engine) + infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef) + self._VerifyGraphDef(params, run_params, infer_gdef, GraphState.INFERENCE) self.assertAllClose( ref_result, @@ -229,18 +389,19 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): infer_gdef = input_gdef # Run inference. - infer_config = self._GetConfigProto(params, use_optimizer, precision_mode, - dynamic_infer_engine) + infer_config = self._GetConfigProto(params, run_params, + GraphState.INFERENCE) logging.info("Running final inference graph, config:\n%s", str(infer_config)) - if use_optimizer: - result = self._RunGraph(params, infer_gdef, input_data, infer_config) + if run_params.use_optimizer: + result = self._RunGraph(params, infer_gdef, input_data, infer_config, + GraphState.INFERENCE) else: - trt_infer_gdef = self._GetTrtGraphDef(params, infer_gdef, precision_mode, - dynamic_infer_engine) - self._VerifyGraphDef(params, trt_infer_gdef, precision_mode, True, - dynamic_infer_engine) - result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config) + trt_infer_gdef = self._GetTrtGraphDef(params, run_params, infer_gdef) + self._VerifyGraphDef(params, run_params, trt_infer_gdef, + GraphState.INFERENCE) + result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config, + GraphState.INFERENCE) self.assertAllClose( ref_result, @@ -263,66 +424,44 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def _AddTests(test_class): """Adds test methods to TfTrtIntegrationTestBase.""" - def _GetTest(use_optimizer, precision_mode, dynamic_infer_engine, - dynamic_calib_engine): + def _GetTest(run_params): """Gets a single test method based on the parameters.""" def _Test(self): params = self.GetParams() logging.info( - "Running test with parameters: use_optimizer=%s, precision_mode=%s, " - "dynamic_infer_engine=%s, dynamic_calib_engine=%s", use_optimizer, - precision_mode, dynamic_infer_engine, dynamic_calib_engine) - self.RunTest(params, use_optimizer, precision_mode, dynamic_infer_engine, - dynamic_calib_engine) + "Running test %s with parameters: use_optimizer=%s, " + "precision_mode=%s, dynamic_engine=%s", + "testTfTRT_" + run_params.test_name, run_params.use_optimizer, + run_params.precision_mode, run_params.dynamic_engine) + self.RunTest(params, run_params) return 
_Test use_optimizer_options = [False, True] - dynamic_infer_engine_options = [False, True] - dynamic_calib_engine_options = [False, True] - for (use_optimizer, precision_mode, - dynamic_infer_engine, dynamic_calib_engine) in itertools.product( - use_optimizer_options, PRECISION_MODES, dynamic_infer_engine_options, - dynamic_calib_engine_options): + dynamic_engine_options = [False, True] + for (use_optimizer, precision_mode, dynamic_engine) in itertools.product( + use_optimizer_options, PRECISION_MODES, dynamic_engine_options): if _IsQuantizationMode(precision_mode): - if not dynamic_calib_engine and dynamic_infer_engine: - # TODO(aaroey): test this case, the conversion from static calibration - # engine to dynamic inference engine should be a noop. - continue if use_optimizer: # TODO(aaroey): if use_optimizer is True we need to get the inference # graphdef using custom python wrapper class, which is not currently # supported yet. continue - if not dynamic_calib_engine: + if not dynamic_engine: # TODO(aaroey): construction of static calibration engine is not # supported yet. continue - if dynamic_calib_engine and not dynamic_infer_engine: - # TODO(aaroey): construction of static inference engine using dynamic - # calibration engine is not supported yet. - continue - else: # In non int8 mode. - if dynamic_calib_engine: - # dynamic_calib_engine doesn't affect non-int8 modes, so just let - # related tests run once on dynamic_calib_engine=False. - continue conversion = "OptimizerConversion" if use_optimizer else "ToolConversion" - infer_engine_type = ("DynamicInferEngine" - if dynamic_infer_engine else "StaticInferEngine") - calib_engine_type = "" - if precision_mode == "INT8": - calib_engine_type = ("DynamicCalibEngine" - if dynamic_calib_engine else "StaticCalibEngine") - test_name = "%s_%s_%s%s" % (conversion, precision_mode, infer_engine_type, - ("_" + calib_engine_type) - if len(calib_engine_type) else "") - setattr( - test_class, "testTfTRT_" + test_name, - _GetTest(use_optimizer, precision_mode, dynamic_infer_engine, - dynamic_calib_engine)) + engine_type = ("DynamicEngine" if dynamic_engine else "StaticEngine") + test_name = "%s_%s_%s" % (conversion, precision_mode, engine_type) + run_params = RunParams( + use_optimizer=use_optimizer, + precision_mode=precision_mode, + dynamic_engine=dynamic_engine, + test_name=test_name) + setattr(test_class, "testTfTRT_" + test_name, _GetTest(run_params)) if trt_convert.is_tensorrt_enabled(): diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py index b9e977cf67..500057a36d 100644 --- a/tensorflow/contrib/tensorrt/test/unary_test.py +++ b/tensorflow/contrib/tensorrt/test/unary_test.py @@ -100,7 +100,10 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name, input2_name], input_dims=[input_dims, input2_dims], - num_expected_engines=5, + expected_engines=[ + "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3", + "my_trt_op_4" + ], expected_output_dims=(12, 5, 8, 12), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/utils.cc b/tensorflow/contrib/tensorrt/test/utils.cc new file mode 100644 index 0000000000..319ddea1b7 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/utils.cc @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/test/utils.h" + +#include <unordered_map> +#include <vector> + +#include "re2/re2.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace tensorrt { +namespace test { + +// TODO(aaroey): make this class thread-safe. +class TestValueManager { + public: + static TestValueManager* singleton() { + static TestValueManager* manager = new TestValueManager(); + return manager; + } + + void Enable() { + VLOG(1) << "Enabling test value"; + enabled_ = true; + } + + void Add(const string& label, const string& value) { + if (TF_PREDICT_FALSE(enabled_)) { + QCHECK_NE("", value); + VLOG(1) << "Adding test value: " << label << " -> " << value; + values_.insert({label, value}); + } + } + + string Get(const string& label) { + if (TF_PREDICT_FALSE(enabled_)) { + VLOG(1) << "Getting test value by " << label; + auto itr = values_.find(label); + if (itr == values_.end()) return ""; + return itr->second; + } + return ""; + } + + void Clear(const string& pattern) { + if (TF_PREDICT_FALSE(enabled_)) { + VLOG(1) << "Clearing test values"; + if (pattern == "") { + values_.clear(); + return; + } + std::vector<string> keys_to_clear; + for (const auto& kv : values_) { + if (RE2::FullMatch(kv.first, pattern)) { + keys_to_clear.push_back(kv.first); + } + } + for (const string& key : keys_to_clear) { + values_.erase(key); + } + } + } + + private: + TestValueManager() : enabled_(false) {} + + bool enabled_; + std::unordered_map<string, string> values_; +}; + +void EnableTestValue() { TestValueManager::singleton()->Enable(); } + +void ClearTestValues(const string& pattern) { + TestValueManager::singleton()->Clear(pattern); +} + +void AddTestValue(const string& label, const string& value) { + TestValueManager::singleton()->Add(label, value); +} + +string GetTestValue(const string& label) { + return TestValueManager::singleton()->Get(label); +} + +} // namespace test +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/test/utils.h b/tensorflow/contrib/tensorrt/test/utils.h new file mode 100644 index 0000000000..625cd3d799 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/utils.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace tensorrt { +namespace test { + +// Helper methods to inject values used by testing tools. +void EnableTestValue(); +void ClearTestValues(const string& pattern); +void AddTestValue(const string& label, const string& value); +string GetTestValue(const string& label); + +#define TRT_RETURN_IF_TEST_VALUE(label, value_to_return) \ + do { \ + if (::tensorflow::tensorrt::test::GetTestValue(label) == value_to_return) {\ + return errors::Internal("Injected manually"); \ + } \ + } while(0) + +} // namespace test +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py index 2b134c3bce..ab4d224db4 100644 --- a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py +++ b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py @@ -72,7 +72,7 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=1, + expected_engines=["my_trt_op_0"], expected_output_dims=(5, 6, 2, 2), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py index bec2f23eff..56bdf848ea 100644 --- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py +++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py @@ -63,7 +63,7 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=1, + expected_engines=["my_trt_op_0"], expected_output_dims=(5, 2, 2, 6), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index 422740fdf6..921c263dfe 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -101,6 +101,7 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); #include "tensorflow/core/util/stat_summarizer.h" #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" #include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" %} %ignoreall @@ -110,6 +111,10 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); %unignore get_linked_tensorrt_version; %unignore get_loaded_tensorrt_version; %unignore is_tensorrt_enabled; +%unignore enable_test_value; +%unignore clear_test_values; +%unignore add_test_value; +%unignore get_test_value; %{ @@ -251,6 +256,22 @@ bool is_tensorrt_enabled() { return tensorflow::tensorrt::IsGoogleTensorRTEnabled(); } +void enable_test_value() { + tensorflow::tensorrt::test::EnableTestValue(); +} + +void clear_test_values(string pattern) { + tensorflow::tensorrt::test::ClearTestValues(pattern); +} + +void add_test_value(string label, string value) { + tensorflow::tensorrt::test::AddTestValue(label, value); +} + +string get_test_value(string label) { + return tensorflow::tensorrt::test::GetTestValue(label); +} + %} std::pair<string, string> calib_convert(string graph_def_string, bool is_dyn_op); @@ -266,5 +287,9 @@ std::pair<string, string> 
trt_convert(string graph_def_string, version_struct get_linked_tensorrt_version(); version_struct get_loaded_tensorrt_version(); bool is_tensorrt_enabled(); +void enable_test_value(); +void clear_test_values(string pattern); +void add_test_value(string label, string value); +string get_test_value(string label); %unignoreall
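
Note on the test-value hooks added in test/utils.{h,cc} and exposed through trt_conversion.i: TfTrtIntegrationTestBase drives them with labels of the form "<engine name>:<method>", and TRTEngineOp records a "done" marker for whichever execution path it actually took. The snippet below is a minimal sketch of that lifecycle, not part of the patch; the import path and the engine name my_trt_op_0 are assumptions for illustration, and trt_convert stands for whatever module object the test base imports under that name.

# Minimal sketch (assumed import path) of the test-value lifecycle used by
# TfTrtIntegrationTestBase.
from tensorflow.contrib.tensorrt.python import trt_convert

# Enabled once per test module (setUpClass); without this the C++
# TestValueManager silently ignores Add/Get/Clear.
trt_convert.enable_test_value()

# setUp() clears everything; an empty pattern wipes all recorded values,
# otherwise the pattern is matched against labels with RE2::FullMatch.
trt_convert.clear_test_values("")

# _PrepareRun() clears the per-engine markers before every sess.run():
trt_convert.clear_test_values("my_trt_op_.*:ExecuteTrtEngine")
trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration")
trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment")

# After the run, _VerifyRun() reads the markers back; for example, in
# GraphState.INFERENCE the TRT engine path must have run and the
# calibration path must not have:
assert trt_convert.get_test_value("my_trt_op_0:ExecuteTrtEngine") == "done"
assert trt_convert.get_test_value("my_trt_op_0:ExecuteCalibration") == ""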
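
Similarly, the rewritten _AddTests() collapses the old dynamic_infer_engine/dynamic_calib_engine pair into a single dynamic_engine flag. The sketch below is a hypothetical stand-alone enumeration mirroring that loop (not code from the patch); it shows which testTfTRT_* methods end up being generated and which INT8 combinations remain skipped.

import itertools

PRECISION_MODES = ["FP32", "FP16", "INT8"]

def expected_test_names():
  """Mirrors the filtering in _AddTests() to list the generated methods."""
  names = []
  for use_optimizer, precision_mode, dynamic_engine in itertools.product(
      [False, True], PRECISION_MODES, [False, True]):
    if precision_mode == "INT8" and (use_optimizer or not dynamic_engine):
      # INT8 via the grappler optimizer and static INT8 calibration engines
      # are not supported yet, so those combinations are skipped.
      continue
    conversion = "OptimizerConversion" if use_optimizer else "ToolConversion"
    engine_type = "DynamicEngine" if dynamic_engine else "StaticEngine"
    names.append("testTfTRT_%s_%s_%s" %
                 (conversion, precision_mode, engine_type))
  return names

# Nine methods in total, e.g. testTfTRT_ToolConversion_FP32_StaticEngine,
# testTfTRT_ToolConversion_INT8_DynamicEngine,
# testTfTRT_OptimizerConversion_FP16_DynamicEngine, ...
print(expected_test_names())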