diff options
author | 2018-07-12 15:22:03 -0700 | |
---|---|---|
committer | 2018-07-12 15:22:03 -0700 | |
commit | 86f632e29810fa93db559f882567b9569dabfad5 (patch) | |
tree | 542d32c630c461ebcdbc483644c8585b92743aec | |
parent | 571d3dc5747e04fe0a80be185e64532cf74e1fb0 (diff) |
Implement the input/output edge validators
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_graph.cc | 46 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 165 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_nodes.h | 26 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/segment/segment.cc | 8 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/segment/segment.h | 4 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/segment/segment_test.cc | 21 |
6 files changed, 182 insertions, 88 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 359fac36f5..ba01eaabc2 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -107,10 +107,8 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(ben,jie): ... }; // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h) - if (!candidate_ops.count(node->type_string()) && - !PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())) { - return false; - } + return (candidate_ops.count(node->type_string()) || + PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); } tensorflow::Status BuildNodeMap( @@ -280,7 +278,8 @@ tensorflow::Status GetEngineInfo( subgraph_node_ids.push_back(node_id); for (const auto edge : node->in_edges()) { auto input_node = edge->src(); - if (segment_nodes.count(input_node->name()) == 0) { + if (segment_nodes.count(input_node->name()) == 0 && + !edge->IsControlEdge() && !input_node->IsSource()) { // Add constant input node into the segment. We don't care if it has // other output edges going into other engines or TF nodes. Since we add // it only to the subsegment node list, not the subsegment itself, it @@ -288,7 +287,7 @@ tensorflow::Status GetEngineInfo( // will prune it out. 
if (input_node->type_string() == "Const") { subgraph_node_ids.push_back(input_node->id()); - } else if (!edge->IsControlEdge() && !input_node->IsSource()) { + } else { string s(input_node->name()); StrAppend(&s, ":", edge->src_output()); VLOG(1) << "Input edge = " << s; @@ -351,9 +350,9 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, nvinfer1::IGpuAllocator* alloc, int max_batch_size) { const auto& info = infos.at(pos); - std::vector<tensorflow::TensorShapeProto> out_shapes; - std::vector<tensorflow::TensorShapeProto> input_shapes; - std::vector<tensorflow::PartialTensorShape> shapes; + std::vector<tensorflow::TensorShapeProto> output_shape_protos; + std::vector<tensorflow::TensorShapeProto> input_shape_protos; + std::vector<tensorflow::PartialTensorShape> input_shapes; std::vector<tensorflow::NodeDefBuilder::NodeOut> inputs; std::vector<tensorflow::DataType> out_types; VLOG(1) << "Processing " << info.engine_name; @@ -366,11 +365,11 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, tensorflow::TensorShapeProto out_shape; // shape of the output node inside segment conn.inside_shape.AsProto(&out_shape); - if (out_shapes.size() <= conn.port_number) { - out_shapes.resize(conn.port_number + 1); + if (output_shape_protos.size() <= conn.port_number) { + output_shape_protos.resize(conn.port_number + 1); out_types.resize(conn.port_number + 1); } - out_shapes.at(conn.port_number) = out_shape; + output_shape_protos.at(conn.port_number) = out_shape; out_types.at(conn.port_number) = conn.connection_type; continue; } @@ -378,12 +377,12 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, // Set the shapes and data types of input edge. 
tensorflow::TensorShapeProto in_shape; conn.outside_shape.AsProto(&in_shape); - if (input_shapes.size() <= conn.port_number) { + if (input_shape_protos.size() <= conn.port_number) { + input_shape_protos.resize(conn.port_number + 1); input_shapes.resize(conn.port_number + 1); - shapes.resize(conn.port_number + 1); } - input_shapes.at(conn.port_number) = in_shape; - shapes.at(conn.port_number) = conn.outside_shape; + input_shape_protos.at(conn.port_number) = in_shape; + input_shapes.at(conn.port_number) = conn.outside_shape; string input_node = conn.outside_node_name; int input_port = conn.outside_port; @@ -411,6 +410,8 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, VLOG(1) << "Engine Input " << input_node << ":" << input_port << " -> " << info.engine_name << ":" << inputs.size(); // Skip duplicate inputs. + // TODO(aaroey): use std::find instead. GetEngineInfo already remove + // duplicate connections, so here we should never find any duplicate? bool new_input = true; for (const auto& inp : inputs) { if (inp.node == input_node && inp.index == input_port) { @@ -438,8 +439,8 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, TF_RETURN_IF_ERROR(ConvertGraphDefToEngine( info.segment_graph_def, info.precision_mode == INT8MODE ? 
FP32MODE : info.precision_mode, - max_batch_size, info.max_workspace_size_bytes, shapes, &trt_logger, - alloc, /*calibrator=*/nullptr, &engine, + max_batch_size, info.max_workspace_size_bytes, input_shapes, + &trt_logger, alloc, /*calibrator=*/nullptr, &engine, /*convert_successfully=*/nullptr)); TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize()); segment_string = @@ -487,8 +488,8 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, } tensorflow::NodeDef trt_node; tensorflow::Status status = - node_builder.Attr("input_shapes", input_shapes) - .Attr("output_shapes", out_shapes) + node_builder.Attr("input_shapes", input_shape_protos) + .Attr("output_shapes", output_shape_protos) .Attr("static_engine", info.engine_type == EngineInfo::EngineType::TRTStatic) .Attr("segment_funcdef_name", @@ -705,6 +706,7 @@ std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator( } // Entry function from optimization pass. +// TODO(aaeory): parameter should use pointer type. tensorflow::Status ConvertAfterShapes(ConversionParams& params) { // Convert graphdef to graph. 
tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), @@ -722,8 +724,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { segment_options.minimum_segment_size = params.minimum_segment_size; tensorflow::tensorrt::segment::SegmentNodesVector initial_segments; TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( - &graph, IsTensorRTCandidate, IsTensorRTInputCandidate, - IsTensorRTOutputCandidate, segment_options, &initial_segments)); + &graph, IsTensorRTCandidate, InputEdgeValidator(*params.graph_properties), + OutputEdgeValidator(), segment_options, &initial_segments)); if (initial_segments.size() > 1) { VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << initial_segments.size(); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 8f6656e4ad..c49e26ea4e 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" #include <algorithm> +#include <cstring> #include <list> #include <map> #include <memory> @@ -57,7 +58,6 @@ namespace tensorflow { namespace tensorrt { namespace convert { using ::tensorflow::str_util::Split; - using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; @@ -77,11 +77,63 @@ inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, break; default: return tensorflow::errors::InvalidArgument( - "Unsupported data type " + tensorflow::DataTypeString(tf_dtype)); + "Unsupported data type ", tensorflow::DataTypeString(tf_dtype)); } return tensorflow::Status::OK(); } +void GetInputProperties(const grappler::GraphProperties& graph_properties, + const Node* outside_node, const int out_port, + PartialTensorShape* shape, + tensorflow::DataType* dtype) { + if (graph_properties.HasOutputProperties(outside_node->name())) { + auto output_params = + graph_properties.GetOutputProperties(outside_node->name()); + auto out_shape = output_params.at(out_port); + *dtype = out_shape.dtype(); + *shape = out_shape.shape(); + } else { + VLOG(0) << "Unknown output shape" << outside_node->name(); + *dtype = outside_node->output_type(out_port); + } +} + +void GetOutputProperties(const grappler::GraphProperties& graph_properties, + const Node* outside_node, const int in_port, + PartialTensorShape* shape, + tensorflow::DataType* dtype) { + if (graph_properties.HasInputProperties(outside_node->name())) { + auto input_params = + graph_properties.GetInputProperties(outside_node->name()); + auto in_shape = input_params.at(in_port); + *dtype = in_shape.dtype(); + *shape = in_shape.shape(); + } else { + *dtype = outside_node->input_type(in_port); + } +} + +tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape, + const tensorflow::DataType dtype, + nvinfer1::DataType* trt_dtype) { + TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype)); + if (shape.dims() < 0) { + return 
tensorflow::errors::InvalidArgument( + "Input tensor rank is unknown."); + } + if (shape.dims() > 8) { + return tensorflow::errors::OutOfRange( + "Input tensor rank is greater than 8."); + } + for (int d = 1; d < shape.dims(); ++d) { + if (shape.dim_size(d) < 0) { + return tensorflow::errors::InvalidArgument( + "Input tensor has a unknow non-batch dimemension at dim ", d); + } + } + return Status::OK(); +} + inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) { nvinfer1::Dims dims; dims.nbDims = tensor.dims(); @@ -2177,25 +2229,22 @@ tensorflow::Status ConvertGraphDefToEngine( (node_def.op() == "Placeholder")) { nvinfer1::DimsCHW input_dim_pseudo_chw; for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0; - nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); - auto type_status = - ConvertDType(node_def.attr().at("dtype").type(), &dtype); - if (type_status != tensorflow::Status::OK()) { - LOG(WARNING) << "Type conversion failed for " << node_name; - return type_status; - } int32 slot_number = -1; - if (!tensorflow::strings::safe_strto32(node_name.c_str() + 8, - &slot_number)) { - LOG(ERROR) << "Failed to parse slot number from " << node_name - << " +8= " << node_name.c_str() + 8; + if (!tensorflow::strings::safe_strto32( + node_name.c_str() + strlen(kInputPHName), &slot_number)) { + return tensorflow::errors::InvalidArgument( + "Failed to parse slot number from ", node_name); } + nvinfer1::DataType dtype; auto shape = input_shapes.at(slot_number); - if (shape.dims() > 8) { - LOG(ERROR) << "Tensor rank is greater than 8 for " << node_name - << " at input slot " << slot_number; - return tensorflow::errors::OutOfRange( - "Input tensor rank is greater than 8"); + auto status = ValidateInputProperties( + shape, node_def.attr().at("dtype").type(), &dtype); + if (!status.ok()) { + const string error_message = StrCat( + "Validation failed for ", node_name, " and input slot ", + slot_number, ": ", status.error_message()); + LOG(WARNING) << 
error_message; + return Status(status.code(), error_message); } if (VLOG_IS_ON(1)) { string dim_str("dims="); @@ -2226,10 +2275,10 @@ tensorflow::Status ConvertGraphDefToEngine( } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) && (node_def.op() == "Identity")) { int32 slot_number = -1; - if (!tensorflow::strings::safe_strto32(node_name.c_str() + 9, - &slot_number)) { - LOG(ERROR) << "Failed to parse slot number from " << node_name - << " +9=" << node_name.c_str() + 9; + if (!tensorflow::strings::safe_strto32( + node_name.c_str() + strlen(kOutputPHName), &slot_number)) { + return tensorflow::errors::InvalidArgument( + "Failed to parse slot number from ", node_name); } if (output_tensors.size() <= slot_number) { output_tensors.resize(slot_number + 1); @@ -2288,38 +2337,20 @@ tensorflow::Status ConvertSegmentToGraphDef( "Cannot find node with id ", connection.outside_id, " in the graph."); } // Updates the shape and data types of input/output connections. - tensorflow::DataType input_type = tensorflow::DT_FLOAT; + tensorflow::DataType dtype; tensorflow::PartialTensorShape partial_shape; if (connection.is_input_edge) { - if (graph_properties.HasOutputProperties(connection.outside_node_name)) { - auto output_params = - graph_properties.GetOutputProperties(connection.outside_node_name); - auto out_shape = output_params.at(connection.outside_port); - input_type = out_shape.dtype(); - std::vector<tensorflow::int64> dims; - partial_shape = out_shape.shape(); - connection.outside_shape = partial_shape; - } else { - VLOG(0) << "Unknown output shape" << outside_node->name(); - input_type = graph->FindNodeId(connection.outside_id) - ->output_type(connection.outside_port); - } - connection.connection_type = input_type; - - } else { // output edge - if (graph_properties.HasInputProperties(connection.outside_node_name)) { - auto input_params = - graph_properties.GetInputProperties(connection.outside_node_name); - auto in_shape = 
input_params.at(connection.outside_port); - input_type = in_shape.dtype(); - partial_shape = in_shape.shape(); - connection.inside_shape = partial_shape; - } else { - input_type = graph->FindNodeId(connection.inside_id) - ->output_type(connection.outside_port); - } - connection.connection_type = input_type; + GetInputProperties(graph_properties, + graph->FindNodeId(connection.outside_id), + connection.outside_port, &partial_shape, &dtype); + + } else { + GetOutputProperties(graph_properties, + graph->FindNodeId(connection.outside_id), + connection.outside_port, &partial_shape, &dtype); } + connection.outside_shape = partial_shape; + connection.connection_type = dtype; // Add dummy input/output nodes to the segment graphdef. if (connection.is_input_edge) { @@ -2335,7 +2366,7 @@ tensorflow::Status ConvertSegmentToGraphDef( auto seg_node = segment_def->add_node(); tensorflow::NodeDefBuilder builder(node_name, "Placeholder"); auto status = builder.Attr("shape", partial_shape) - .Attr("dtype", input_type) + .Attr("dtype", dtype) .Finalize(seg_node); VLOG(1) << "Constructing input " << node_name << " for the edge " << connection.outside_node_name << ":" << connection.outside_port @@ -2353,7 +2384,7 @@ tensorflow::Status ConvertSegmentToGraphDef( marker_nodes.insert(node_name); auto seg_node = segment_def->add_node(); tensorflow::NodeDefBuilder builder(node_name, "Identity"); - auto status = builder.Input(connection.inside_node_name, 0, input_type) + auto status = builder.Input(connection.inside_node_name, 0, dtype) .Finalize(seg_node); VLOG(1) << "Constructing output " << node_name << " for the edge " << connection.inside_node_name << ":" << connection.inside_port @@ -2391,11 +2422,35 @@ tensorflow::Status ConvertSegmentToGraphDef( return tensorflow::Status::OK(); } -bool IsTensorRTInputCandidate(const tensorflow::Node* node) { +bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const { + if (in_edge->IsControlEdge()) return true; + PartialTensorShape 
shape; + tensorflow::DataType dtype; + GetInputProperties(graph_properties_, in_edge->src(), in_edge->src_output(), + &shape, &dtype); + nvinfer1::DataType trt_dtype; + Status status = ValidateInputProperties(shape, dtype, &trt_dtype); + if (!status.ok()) { + VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name() + << ": " << status; + return false; + } + if (shape.dims() < 3 && in_edge->src()->type_string() != "Const") { + VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name() + << " which has an input at port " << in_edge->dst_input() + << " with #dim<3 and is not a const: " << shape; + return false; + } return true; } -bool IsTensorRTOutputCandidate(const tensorflow::Node* node) { +bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const { + if (out_edge->IsControlEdge()) return true; + if (out_edge->src()->type_string() == "Const") { + VLOG(2) << "--> Need to remove output node " << out_edge->src()->name() + << " which is a Const."; + return false; + } return true; } diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 872ba6a080..64337eee84 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -104,6 +104,8 @@ struct EngineInfo { // topological order. // - segment_def: the output GraphDef, whose non-input/output nodedefs will be // sorted in topological order. +// +// TODO(aaroey): add tests to validate these properties. tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, @@ -128,9 +130,29 @@ tensorflow::Status ConvertGraphDefToEngine( TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool* convert_successfully); -bool IsTensorRTInputCandidate(const tensorflow::Node* node); +// Helper class for the segmenter to determine whether an input edge to the TRT +// segment is valid. 
+class InputEdgeValidator { + public: + InputEdgeValidator(const grappler::GraphProperties& graph_properties) + : graph_properties_(graph_properties) {} + + // Return true if the specified edge is eligible to be an input edge of the + // TRT segment. + bool operator()(const tensorflow::Edge* in_edge) const; -bool IsTensorRTOutputCandidate(const tensorflow::Node* node); + private: + const grappler::GraphProperties& graph_properties_; +}; + +// Helper class for the segmenter to determine whether an output edge from the +// TRT segment is valid. +class OutputEdgeValidator { + public: + // Return true if the specified edge is eligible to be an output edge of the + // TRT segment. + bool operator()(const tensorflow::Edge* out_edge) const; +}; } // namespace convert } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index 5c0898b29a..92807bed14 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -364,8 +364,8 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph, tensorflow::Status SegmentGraph( const tensorflow::Graph* tf_graph, const std::function<bool(const tensorflow::Node*)>& candidate_fn, - const std::function<bool(const tensorflow::Node*)>& input_candidate_fn, - const std::function<bool(const tensorflow::Node*)>& output_candidate_fn, + const std::function<bool(const tensorflow::Edge*)>& input_candidate_fn, + const std::function<bool(const tensorflow::Edge*)>& output_candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments) { // Steps: // 1. run the segmentation algorithm to find all the segments, which uses @@ -526,7 +526,7 @@ tensorflow::Status SegmentGraph( for (const tensorflow::Edge* edge : node->in_edges()) { if (!edge->IsControlEdge() && !edge->src()->IsSource() && !segment_nodes.count(edge->src())) { // 'node' is an input node. 
- if (!input_candidate_fn(node)) { + if (!input_candidate_fn(edge)) { in_nodes_que.push_back(node); added = true; break; @@ -537,7 +537,7 @@ tensorflow::Status SegmentGraph( for (const tensorflow::Edge* edge : node->out_edges()) { if (!edge->dst()->IsSink() && !edge->IsControlEdge() && !segment_nodes.count(edge->dst())) { // 'node' is an output node. - if (!output_candidate_fn(node)) { + if (!output_candidate_fn(edge)) { out_nodes_que.push_back(node); break; } diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h index ab75135054..8c44eb782a 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/contrib/tensorrt/segment/segment.h @@ -52,8 +52,8 @@ struct SegmentOptions { tensorflow::Status SegmentGraph( const tensorflow::Graph* tf_graph, const std::function<bool(const tensorflow::Node*)>& candidate_fn, - const std::function<bool(const tensorflow::Node*)>& input_candidate_fn, - const std::function<bool(const tensorflow::Node*)>& output_candidate_fn, + const std::function<bool(const tensorflow::Edge*)>& input_candidate_fn, + const std::function<bool(const tensorflow::Edge*)>& output_candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments); } // namespace segment diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index a43cf4f416..432e7b1c04 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -41,15 +41,30 @@ class SegmentTest : public ::testing::Test { }; } + std::function<bool(const tensorflow::Edge*)> MakeInputEdgeCandidateFn( + const std::set<string>& node_names) { + return [node_names](const tensorflow::Edge* in_edge) -> bool { + return node_names.find(in_edge->dst()->name()) != node_names.end(); + }; + } + + std::function<bool(const tensorflow::Edge*)> MakeOutputEdgeCandidateFn( + const std::set<string>& node_names) { + return 
[node_names](const tensorflow::Edge* out_edge) -> bool { + return node_names.find(out_edge->src()->name()) != node_names.end(); + }; + } + void RunTest(const tensorflow::Graph* graph, const std::set<string>& candidates, const std::set<string>& input_candidates, const std::set<string>& output_candidates, const std::vector<std::set<string>>& expected_segments) { SegmentNodesVector segments; - TF_EXPECT_OK(SegmentGraph( - graph, MakeCandidateFn(candidates), MakeCandidateFn(input_candidates), - MakeCandidateFn(output_candidates), default_options_, &segments)); + TF_EXPECT_OK(SegmentGraph(graph, MakeCandidateFn(candidates), + MakeInputEdgeCandidateFn(input_candidates), + MakeOutputEdgeCandidateFn(output_candidates), + default_options_, &segments)); ValidateSegment(segments, expected_segments); } |