diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert/convert_nodes.cc')
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 97 |
1 files changed, 72 insertions, 25 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 96e0700862..4e4d295538 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -362,10 +362,11 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights, break; } case tensorflow::DataType::DT_HALF: { - Reorder2({k, c}, static_cast<Eigen::half const*>(iweights.GetValues()), - istrides, static_cast<Eigen::half*>( - const_cast<void*>(oweights->GetValues())), - ostrides); + Reorder2( + {k, c}, static_cast<Eigen::half const*>(iweights.GetValues()), + istrides, + static_cast<Eigen::half*>(const_cast<void*>(oweights->GetValues())), + ostrides); break; } default: @@ -1179,9 +1180,9 @@ tensorflow::Status BinaryTensorOpTensor( CHECK_EQ_TYPE(tensor_r->getType(), dtype); auto op_pair = ops.find(node_def.op()); if (op_pair == ops.end()) - return tensorflow::errors::Unimplemented("binary op: " + node_def.op() + - " not supported at: " + - node_def.name()); + return tensorflow::errors::Unimplemented( + "binary op: " + node_def.op() + + " not supported at: " + node_def.name()); nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise( *const_cast<nvinfer1::ITensor*>(tensor_l), @@ -2138,9 +2139,7 @@ void Converter::register_op_converters() { } } // namespace -tensorflow::Status GetTensorRTGraph(tensorrt::convert::SubGraphParams& s) { - return tensorflow::errors::Unimplemented("Not implemented yet"); -} + tensorflow::Status ConvertCalibrationNodeToEngineNode( tensorflow::Graph& graph, tensorflow::Node* c_node) { const auto ndef = c_node->def(); @@ -2164,9 +2163,23 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( for (auto n : graph.op_nodes()) { node_maps.insert({n->name(), n}); } + std::set<int> subgraph_ids; + for (const auto internal_node : segment_nodes) { + subgraph_ids.insert(node_maps.at(internal_node)->id()); + } + if (VLOG_IS_ON(2)) { + string node_names = StrCat(c_node->name(), " segment nodes= "); + + for (const auto& node_name : segment_nodes) { + StrAppend(&node_names, node_name, ", "); + } + VLOG(2) << node_names; + } + VLOG(1) << "Output Nodes:"; std::vector<tensorflow::DataType> out_types; std::vector<const tensorflow::Edge*> out_edges; + for (auto& i : output_nodes) { auto node_port = tensorflow::str_util::Split(i, ":"); VLOG(1) << " " << i << " in graph " << node_maps.count(i); @@ -2186,18 +2199,24 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( out_types.push_back(out_node->output_type(0)); } for (auto out_edge : out_node->out_edges()) { + if (subgraph_ids.count(out_edge->dst()->id())) + continue; // skip internal edges; if (out_edge->src_output() == port) { out_edges.push_back(out_edge); - break; + VLOG(1) << "OUTPUT EDGE " << out_edge->src()->name() << ":" + << out_edge->src_output() << " -> " << out_edge->dst()->name() + << ":" << out_edge->dst_input(); } } } else { LOG(WARNING) << " couldn't find output node " << out_node_name; } } - VLOG(1) << "Input Nodes:"; - for (auto& i : input_names) { - VLOG(1) << " " << i << " in graph " << node_maps.count(i); + if (VLOG_IS_ON(1)) { + VLOG(1) << c_node->name() << " Input Nodes:"; + for (auto& i : input_names) { + VLOG(1) << " Input " << i << " in graph " << node_maps.count(i); + } } auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance(); auto resmgr = trt_rm->getManager("TRTCalibOps"); @@ -2231,14 +2250,24 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( calib_res->builder_ = nullptr; tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges; + income_edges.resize(c_node->num_inputs()); for (const auto in_edge : c_node->in_edges()) { auto src = in_edge->src(); int dest_port = in_edge->dst_input(); - income_edges.emplace_back(src->name(), in_edge->src_output(), - c_node->input_type(dest_port)); + VLOG(1) << "Incoming connection " << src->name() << ":" + << in_edge->src_output() << " -> " << c_node->name() << ":" + << dest_port; + income_edges.at(dest_port) = {src->name(), in_edge->src_output(), + c_node->input_type(dest_port)}; } tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list( income_edges); + if (VLOG_IS_ON(2)) { + for (const auto& inp : input_list) { + VLOG(2) << " Input from inputlist " << inp.node << ":" << inp.index << " " + << tensorflow::DataTypeString(inp.data_type); + } + } op_builder.Input(input_list); tensorflow::NodeDef engine_node; const char* engine_plan_data = static_cast<const char*>(engine_plan->data()); @@ -2255,13 +2284,26 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( } auto trt_engine_node = graph.AddNode(engine_node, &status); TF_RETURN_IF_ERROR(status); - for (size_t i = 0; i < out_edges.size(); i++) { - VLOG(1) << "Connecting trt_engine_node output " << i << " with " - << out_edges.at(i)->dst()->name() << " port " - << out_edges.at(i)->dst_input(); - TF_RETURN_IF_ERROR(graph.UpdateEdge(trt_engine_node, i, - out_edges.at(i)->dst(), - out_edges.at(i)->dst_input())); + std::map<string, int> port_map; + for (size_t t = 0; t < output_nodes.size(); t++) { + port_map.insert({output_nodes.at(t), t}); + } + for (auto& i : out_edges) { + string s(i->src()->name()); + if (i->src_output()) StrAppend(&s, ":", i->src_output()); + int out_port = port_map.at(s); + VLOG(1) << "Connecting " << trt_engine_node->name() << ":" << out_port + << " -> " << i->dst()->name() << ":" << i->dst_input(); + TF_RETURN_IF_ERROR( + graph.UpdateEdge(trt_engine_node, out_port, i->dst(), i->dst_input())); + } + for (const auto ed : trt_engine_node->in_edges()) { + VLOG(1) << "In Edge " << ed->src()->name() << ":" << ed->src_output() + << " -> " << ed->dst()->name() << ":" << ed->dst_input(); + } + for (const auto ed : trt_engine_node->out_edges()) { + VLOG(1) << "Out Edge " << ed->src()->name() << ":" << ed->src_output() + << " -> " << ed->dst()->name() << ":" << ed->dst_input(); } VLOG(1) << "Segment nodes:"; for (auto& i : segment_nodes) { @@ -2332,6 +2374,7 @@ tensorflow::Status ConvertSubgraph( std::vector<string>* output_names, std::vector<tensorflow::DataType>* output_dtypes, const string& engine_name) { + std::set<string> added_tensors; for (const std::pair<int, int>& input : s.input_inds) { VLOG(2) << "parsing input. Node id= " << input.first; int node_id = input.first; @@ -2374,7 +2417,6 @@ tensorflow::Status ConvertSubgraph( auto op_info = op_info_vec.at(shape_inference_output_idx); tensorflow::DataType tf_dtype = op_info.dtype(); - input_dtypes->push_back(tf_dtype); nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); auto type_status = ConvertDType(tf_dtype, &dtype); @@ -2410,8 +2452,10 @@ tensorflow::Status ConvertSubgraph( if (output_idx != 0) { input_tensor_name = StrCat(node_name, ":", output_idx); } - + if (added_tensors.count(input_tensor_name)) continue; + added_tensors.insert(input_tensor_name); input_names->push_back(input_tensor_name); + input_dtypes->push_back(tf_dtype); nvinfer1::ITensor* input_tensor = converter.network()->addInput( input_tensor_name.c_str(), dtype, input_dim_pseudo_chw); @@ -2435,6 +2479,7 @@ tensorflow::Status ConvertSubgraph( // Gather output metadata int trt_engine_op_output_idx = 0; + added_tensors.clear(); for (const std::pair<int, int>& output : s.output_inds) { int node_id = output.first; int output_idx = output.second; @@ -2451,6 +2496,8 @@ tensorflow::Status ConvertSubgraph( if (output_idx != 0) tensorflow::strings::StrAppend(&tensor_name, ":", output_idx); VLOG(2) << "Output tensor name: " << tensor_name; + if (added_tensors.count(tensor_name)) continue; + added_tensors.insert(tensor_name); output_names->push_back(tensor_name); auto tensor_or_weights = converter.get_tensor(tensor_name); if (!tensor_or_weights.is_tensor()) { |