aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert/convert_nodes.cc')
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc97
1 files changed, 72 insertions, 25 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 96e0700862..4e4d295538 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -362,10 +362,11 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
break;
}
case tensorflow::DataType::DT_HALF: {
- Reorder2({k, c}, static_cast<Eigen::half const*>(iweights.GetValues()),
- istrides, static_cast<Eigen::half*>(
- const_cast<void*>(oweights->GetValues())),
- ostrides);
+ Reorder2(
+ {k, c}, static_cast<Eigen::half const*>(iweights.GetValues()),
+ istrides,
+ static_cast<Eigen::half*>(const_cast<void*>(oweights->GetValues())),
+ ostrides);
break;
}
default:
@@ -1179,9 +1180,9 @@ tensorflow::Status BinaryTensorOpTensor(
CHECK_EQ_TYPE(tensor_r->getType(), dtype);
auto op_pair = ops.find(node_def.op());
if (op_pair == ops.end())
- return tensorflow::errors::Unimplemented("binary op: " + node_def.op() +
- " not supported at: " +
- node_def.name());
+ return tensorflow::errors::Unimplemented(
+ "binary op: " + node_def.op() +
+ " not supported at: " + node_def.name());
nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
*const_cast<nvinfer1::ITensor*>(tensor_l),
@@ -2138,9 +2139,7 @@ void Converter::register_op_converters() {
}
} // namespace
-tensorflow::Status GetTensorRTGraph(tensorrt::convert::SubGraphParams& s) {
- return tensorflow::errors::Unimplemented("Not implemented yet");
-}
+
tensorflow::Status ConvertCalibrationNodeToEngineNode(
tensorflow::Graph& graph, tensorflow::Node* c_node) {
const auto ndef = c_node->def();
@@ -2164,9 +2163,23 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode(
for (auto n : graph.op_nodes()) {
node_maps.insert({n->name(), n});
}
+ std::set<int> subgraph_ids;
+ for (const auto internal_node : segment_nodes) {
+ subgraph_ids.insert(node_maps.at(internal_node)->id());
+ }
+ if (VLOG_IS_ON(2)) {
+ string node_names = StrCat(c_node->name(), " segment nodes= ");
+
+ for (const auto& node_name : segment_nodes) {
+ StrAppend(&node_names, node_name, ", ");
+ }
+ VLOG(2) << node_names;
+ }
+
VLOG(1) << "Output Nodes:";
std::vector<tensorflow::DataType> out_types;
std::vector<const tensorflow::Edge*> out_edges;
+
for (auto& i : output_nodes) {
auto node_port = tensorflow::str_util::Split(i, ":");
VLOG(1) << " " << i << " in graph " << node_maps.count(i);
@@ -2186,18 +2199,24 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode(
out_types.push_back(out_node->output_type(0));
}
for (auto out_edge : out_node->out_edges()) {
+ if (subgraph_ids.count(out_edge->dst()->id()))
+ continue; // skip internal edges;
if (out_edge->src_output() == port) {
out_edges.push_back(out_edge);
- break;
+ VLOG(1) << "OUTPUT EDGE " << out_edge->src()->name() << ":"
+ << out_edge->src_output() << " -> " << out_edge->dst()->name()
+ << ":" << out_edge->dst_input();
}
}
} else {
LOG(WARNING) << " couldn't find output node " << out_node_name;
}
}
- VLOG(1) << "Input Nodes:";
- for (auto& i : input_names) {
- VLOG(1) << " " << i << " in graph " << node_maps.count(i);
+ if (VLOG_IS_ON(1)) {
+ VLOG(1) << c_node->name() << " Input Nodes:";
+ for (auto& i : input_names) {
+ VLOG(1) << " Input " << i << " in graph " << node_maps.count(i);
+ }
}
auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance();
auto resmgr = trt_rm->getManager("TRTCalibOps");
@@ -2231,14 +2250,24 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode(
calib_res->builder_ = nullptr;
tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp");
std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
+ income_edges.resize(c_node->num_inputs());
for (const auto in_edge : c_node->in_edges()) {
auto src = in_edge->src();
int dest_port = in_edge->dst_input();
- income_edges.emplace_back(src->name(), in_edge->src_output(),
- c_node->input_type(dest_port));
+ VLOG(1) << "Incoming connection " << src->name() << ":"
+ << in_edge->src_output() << " -> " << c_node->name() << ":"
+ << dest_port;
+ income_edges.at(dest_port) = {src->name(), in_edge->src_output(),
+ c_node->input_type(dest_port)};
}
tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
income_edges);
+ if (VLOG_IS_ON(2)) {
+ for (const auto& inp : input_list) {
+ VLOG(2) << " Input from inputlist " << inp.node << ":" << inp.index << " "
+ << tensorflow::DataTypeString(inp.data_type);
+ }
+ }
op_builder.Input(input_list);
tensorflow::NodeDef engine_node;
const char* engine_plan_data = static_cast<const char*>(engine_plan->data());
@@ -2255,13 +2284,26 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode(
}
auto trt_engine_node = graph.AddNode(engine_node, &status);
TF_RETURN_IF_ERROR(status);
- for (size_t i = 0; i < out_edges.size(); i++) {
- VLOG(1) << "Connecting trt_engine_node output " << i << " with "
- << out_edges.at(i)->dst()->name() << " port "
- << out_edges.at(i)->dst_input();
- TF_RETURN_IF_ERROR(graph.UpdateEdge(trt_engine_node, i,
- out_edges.at(i)->dst(),
- out_edges.at(i)->dst_input()));
+ std::map<string, int> port_map;
+ for (size_t t = 0; t < output_nodes.size(); t++) {
+ port_map.insert({output_nodes.at(t), t});
+ }
+ for (auto& i : out_edges) {
+ string s(i->src()->name());
+ if (i->src_output()) StrAppend(&s, ":", i->src_output());
+ int out_port = port_map.at(s);
+ VLOG(1) << "Connecting " << trt_engine_node->name() << ":" << out_port
+ << " -> " << i->dst()->name() << ":" << i->dst_input();
+ TF_RETURN_IF_ERROR(
+ graph.UpdateEdge(trt_engine_node, out_port, i->dst(), i->dst_input()));
+ }
+ for (const auto ed : trt_engine_node->in_edges()) {
+ VLOG(1) << "In Edge " << ed->src()->name() << ":" << ed->src_output()
+ << " -> " << ed->dst()->name() << ":" << ed->dst_input();
+ }
+ for (const auto ed : trt_engine_node->out_edges()) {
+ VLOG(1) << "Out Edge " << ed->src()->name() << ":" << ed->src_output()
+ << " -> " << ed->dst()->name() << ":" << ed->dst_input();
}
VLOG(1) << "Segment nodes:";
for (auto& i : segment_nodes) {
@@ -2332,6 +2374,7 @@ tensorflow::Status ConvertSubgraph(
std::vector<string>* output_names,
std::vector<tensorflow::DataType>* output_dtypes,
const string& engine_name) {
+ std::set<string> added_tensors;
for (const std::pair<int, int>& input : s.input_inds) {
VLOG(2) << "parsing input. Node id= " << input.first;
int node_id = input.first;
@@ -2374,7 +2417,6 @@ tensorflow::Status ConvertSubgraph(
auto op_info = op_info_vec.at(shape_inference_output_idx);
tensorflow::DataType tf_dtype = op_info.dtype();
- input_dtypes->push_back(tf_dtype);
nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
auto type_status = ConvertDType(tf_dtype, &dtype);
@@ -2410,8 +2452,10 @@ tensorflow::Status ConvertSubgraph(
if (output_idx != 0) {
input_tensor_name = StrCat(node_name, ":", output_idx);
}
-
+ if (added_tensors.count(input_tensor_name)) continue;
+ added_tensors.insert(input_tensor_name);
input_names->push_back(input_tensor_name);
+ input_dtypes->push_back(tf_dtype);
nvinfer1::ITensor* input_tensor = converter.network()->addInput(
input_tensor_name.c_str(), dtype, input_dim_pseudo_chw);
@@ -2435,6 +2479,7 @@ tensorflow::Status ConvertSubgraph(
// Gather output metadata
int trt_engine_op_output_idx = 0;
+ added_tensors.clear();
for (const std::pair<int, int>& output : s.output_inds) {
int node_id = output.first;
int output_idx = output.second;
@@ -2451,6 +2496,8 @@ tensorflow::Status ConvertSubgraph(
if (output_idx != 0)
tensorflow::strings::StrAppend(&tensor_name, ":", output_idx);
VLOG(2) << "Output tensor name: " << tensor_name;
+ if (added_tensors.count(tensor_name)) continue;
+ added_tensors.insert(tensor_name);
output_names->push_back(tensor_name);
auto tensor_or_weights = converter.get_tensor(tensor_name);
if (!tensor_or_weights.is_tensor()) {