aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt
diff options
context:
space:
mode:
authorGravatar Guangda Lai <31743510+aaroey@users.noreply.github.com>2018-09-20 09:56:07 -0700
committerGravatar Guangda Lai <31743510+aaroey@users.noreply.github.com>2018-09-20 09:56:07 -0700
commit13fac9da3820d0dda504eac43a0bd59876742262 (patch)
treee66df7222b925367b4df5c2aabb7292de59a2b82 /tensorflow/contrib/tensorrt
parent5f05a18c576ba89a7bce5f2ed5c7104bc158d8f1 (diff)
Set back the ITensor name, but conditionally.
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc16
1 files changed, 12 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 6283bd2300..0ce891782e 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -693,10 +693,16 @@ class Converter {
// TODO(jie): tf protobuf seems to be omitting the :0 suffix
string output_name = node_def.name();
if (i != 0) output_name = StrCat(output_name, ":", i);
- // We should not call output.tensor()->setName(), since the name may have
- // already been set before (e.g. for Identity op where the output is the
- // input, if its input is one of the engine input, setting the name here
- // will overwrite engine input bindings which will cause runtime error).
+ // We need to check the name before setting it. For Identity op where the
+ // output is the input, if its input is one of the engine input, setting
+ // the name here will overwrite engine input bindings which will cause
+ // runtime error.
+ if (output.is_tensor()) {
+ const char* tensor_name = output.tensor()->getName();
+ if (tensor_name == nullptr || std::strlen(tensor_name) == 0) {
+ output.tensor()->setName(output_name.c_str());
+ }
+ }
VLOG(2) << "Adding out tensor " << output_name << ": "
<< output.DebugString();
if (!trt_tensors_.insert({output_name, output}).second) {
@@ -1301,6 +1307,7 @@ tensorflow::Status ConvertConv2DHelper(
layer->setStride(stride);
layer->setPadding({padding[0].first, padding[1].first});
+ layer->setName(node_def.name().c_str());
layer->setNbGroups(num_groups);
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
VLOG(2) << "TENSOR out: " << DebugString(output_tensor->getDimensions());
@@ -1546,6 +1553,7 @@ tensorflow::Status ConvertPool(Converter& ctx,
layer->setStride(stride);
layer->setPadding({padding[0].first, padding[1].first});
+ layer->setName(node_def.name().c_str());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
if (data_format == "NHWC") {