diff options
author | 2018-09-21 09:44:20 -0700 | |
---|---|---|
committer | 2018-09-21 09:44:20 -0700 | |
commit | 0e3ab3ab0d511f681b954322afc2ae89c8ea7d8f (patch) | |
tree | cfd0a10f35113c42001b00d3d855e69db30b4760 /tensorflow/contrib/tensorrt/convert | |
parent | 5877baddc72e3f234f6e0a174447becd4cabc493 (diff) | |
parent | 72e085ca1701e275acec381885b519fa6b06522c (diff) |
Merge pull request #22371 from aaroey:fix_zero_size_allocation
PiperOrigin-RevId: 213998222
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert')
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_graph.cc | 6 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 19 |
2 files changed, 15 insertions, 10 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index f29f4d6deb..7ad9bf22d3 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -678,7 +678,7 @@ tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos, // Function to construct a funcdef from the segment and add it to the graph. tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( tensorflow::Graph* graph, const tensorflow::GraphDef& segment, - const string& name) { + const string& engine_name) { tensorflow::Graph sgraph(graph->flib_def()); tensorflow::GraphConstructorOptions gcopts; TF_RETURN_IF_ERROR( @@ -761,9 +761,9 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( tensorflow::FunctionDefLibrary fdeflib; auto native_segment = fdeflib.add_function(); TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef( - sgraph, StrCat(name, "_native_segment"), native_segment)); + sgraph, StrCat(engine_name, "_native_segment"), native_segment)); if (VLOG_IS_ON(7)) { - VLOG(7) << name << " Function_Def "; + VLOG(7) << engine_name << " Function_Def "; VLOG(7) << native_segment->DebugString(); } VLOG(1) << "Adding funcdef to graphlib"; diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index c98b07ad8b..0ce891782e 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -693,8 +693,15 @@ class Converter { // TODO(jie): tf protobuf seems to be omitting the :0 suffix string output_name = node_def.name(); if (i != 0) output_name = StrCat(output_name, ":", i); + // We need to check the name before setting it. For Identity op where the + // output is the input, if its input is one of the engine input, setting + // the name here will overwrite engine input bindings which will cause + // runtime error. if (output.is_tensor()) { - output.tensor()->setName(output_name.c_str()); + const char* tensor_name = output.tensor()->getName(); + if (tensor_name == nullptr || std::strlen(tensor_name) == 0) { + output.tensor()->setName(output_name.c_str()); + } } VLOG(2) << "Adding out tensor " << output_name << ": " << output.DebugString(); @@ -779,12 +786,11 @@ class Converter { // skip control nodes if (input_name[0] == '^') continue; string name = input_name; - auto first = name.find_first_of(':'); - // TODO(aaroey): why removing the colon but not the zero? A bug? + auto last = name.find_last_of(':'); // TODO(aaroey): use TensorId - if (first != string::npos && first + 2 == name.size() && - name[first + 1] == '0') { - name.erase(first); + if (last != string::npos && last + 2 == name.size() && + name[last + 1] == '0') { + name.erase(last); } if (trt_tensors_.count(name)) { @@ -2697,7 +2703,6 @@ tensorflow::Status ConvertGraphDefToEngine( TrtUniquePtrType<nvinfer1::IBuilder> builder( nvinfer1::createInferBuilder(*logger)); builder->setMaxBatchSize(max_batch_size); - // TODO(aaroey): use the allocator to allocate the TRT workspace. builder->setMaxWorkspaceSize(max_workspace_size_bytes); #if NV_TENSORRT_MAJOR > 3 builder->setGpuAllocator(allocator); |