diff options
author | Guangda Lai <31743510+aaroey@users.noreply.github.com> | 2018-09-20 09:56:07 -0700 |
---|---|---|
committer | Guangda Lai <31743510+aaroey@users.noreply.github.com> | 2018-09-20 09:56:07 -0700 |
commit | 13fac9da3820d0dda504eac43a0bd59876742262 (patch) | |
tree | e66df7222b925367b4df5c2aabb7292de59a2b82 /tensorflow/contrib/tensorrt | |
parent | 5f05a18c576ba89a7bce5f2ed5c7104bc158d8f1 (diff) |
Set back the ITensor name, but conditionally.
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 6283bd2300..0ce891782e 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -693,10 +693,16 @@ class Converter { // TODO(jie): tf protobuf seems to be omitting the :0 suffix string output_name = node_def.name(); if (i != 0) output_name = StrCat(output_name, ":", i); - // We should not call output.tensor()->setName(), since the name may have - // already been set before (e.g. for Identity op where the output is the - // input, if its input is one of the engine input, setting the name here - // will overwrite engine input bindings which will cause runtime error). + // We need to check the name before setting it. For Identity op where the + // output is the input, if its input is one of the engine input, setting + // the name here will overwrite engine input bindings which will cause + // runtime error. + if (output.is_tensor()) { + const char* tensor_name = output.tensor()->getName(); + if (tensor_name == nullptr || std::strlen(tensor_name) == 0) { + output.tensor()->setName(output_name.c_str()); + } + } VLOG(2) << "Adding out tensor " << output_name << ": " << output.DebugString(); if (!trt_tensors_.insert({output_name, output}).second) { @@ -1301,6 +1307,7 @@ tensorflow::Status ConvertConv2DHelper( layer->setStride(stride); layer->setPadding({padding[0].first, padding[1].first}); + layer->setName(node_def.name().c_str()); layer->setNbGroups(num_groups); nvinfer1::ITensor* output_tensor = layer->getOutput(0); VLOG(2) << "TENSOR out: " << DebugString(output_tensor->getDimensions()); @@ -1546,6 +1553,7 @@ tensorflow::Status ConvertPool(Converter& ctx, layer->setStride(stride); layer->setPadding({padding[0].first, padding[1].first}); + layer->setName(node_def.name().c_str()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); if (data_format == "NHWC") { |