aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/convert
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 09:44:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 09:44:20 -0700
commit0e3ab3ab0d511f681b954322afc2ae89c8ea7d8f (patch)
treecfd0a10f35113c42001b00d3d855e69db30b4760 /tensorflow/contrib/tensorrt/convert
parent5877baddc72e3f234f6e0a174447becd4cabc493 (diff)
parent72e085ca1701e275acec381885b519fa6b06522c (diff)
Merge pull request #22371 from aaroey:fix_zero_size_allocation
PiperOrigin-RevId: 213998222
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert')
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.cc6
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc19
2 files changed, 15 insertions, 10 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index f29f4d6deb..7ad9bf22d3 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -678,7 +678,7 @@ tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos,
// Function to construct a funcdef from the segment and add it to the graph.
tensorflow::Status RegisterSegmentFunctionToFunctionLibrary(
tensorflow::Graph* graph, const tensorflow::GraphDef& segment,
- const string& name) {
+ const string& engine_name) {
tensorflow::Graph sgraph(graph->flib_def());
tensorflow::GraphConstructorOptions gcopts;
TF_RETURN_IF_ERROR(
@@ -761,9 +761,9 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary(
tensorflow::FunctionDefLibrary fdeflib;
auto native_segment = fdeflib.add_function();
TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef(
- sgraph, StrCat(name, "_native_segment"), native_segment));
+ sgraph, StrCat(engine_name, "_native_segment"), native_segment));
if (VLOG_IS_ON(7)) {
- VLOG(7) << name << " Function_Def ";
+ VLOG(7) << engine_name << " Function_Def ";
VLOG(7) << native_segment->DebugString();
}
VLOG(1) << "Adding funcdef to graphlib";
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index c98b07ad8b..0ce891782e 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -693,8 +693,15 @@ class Converter {
// TODO(jie): tf protobuf seems to be omitting the :0 suffix
string output_name = node_def.name();
if (i != 0) output_name = StrCat(output_name, ":", i);
+ // We need to check the name before setting it. For Identity op where the
+ // output is the input, if its input is one of the engine input, setting
+ // the name here will overwrite engine input bindings which will cause
+ // runtime error.
if (output.is_tensor()) {
- output.tensor()->setName(output_name.c_str());
+ const char* tensor_name = output.tensor()->getName();
+ if (tensor_name == nullptr || std::strlen(tensor_name) == 0) {
+ output.tensor()->setName(output_name.c_str());
+ }
}
VLOG(2) << "Adding out tensor " << output_name << ": "
<< output.DebugString();
@@ -779,12 +786,11 @@ class Converter {
// skip control nodes
if (input_name[0] == '^') continue;
string name = input_name;
- auto first = name.find_first_of(':');
- // TODO(aaroey): why removing the colon but not the zero? A bug?
+ auto last = name.find_last_of(':');
// TODO(aaroey): use TensorId
- if (first != string::npos && first + 2 == name.size() &&
- name[first + 1] == '0') {
- name.erase(first);
+ if (last != string::npos && last + 2 == name.size() &&
+ name[last + 1] == '0') {
+ name.erase(last);
}
if (trt_tensors_.count(name)) {
@@ -2697,7 +2703,6 @@ tensorflow::Status ConvertGraphDefToEngine(
TrtUniquePtrType<nvinfer1::IBuilder> builder(
nvinfer1::createInferBuilder(*logger));
builder->setMaxBatchSize(max_batch_size);
- // TODO(aaroey): use the allocator to allocate the TRT workspace.
builder->setMaxWorkspaceSize(max_workspace_size_bytes);
#if NV_TENSORRT_MAJOR > 3
builder->setGpuAllocator(allocator);