diff options
author | 2018-07-16 09:47:20 -0700 | |
---|---|---|
committer | 2018-07-16 09:47:26 -0700 | |
commit | 97ae13e08d5fffa21ea52b016249ef0809005d49 (patch) | |
tree | 11bdea347a9e2a4751e477eb41c1941e7aec2602 | |
parent | 74fce066580ca286b2c776a64ab624f12a473b28 (diff) | |
parent | 511ce2e7eb8f220c443f09382d09d14f0758e8ba (diff) |
Merge pull request #20794 from samikama:KeepInputs
PiperOrigin-RevId: 204756546
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index ec9dbfa13b..5bb0ffc797 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -232,8 +232,14 @@ tensorflow::Status TRTOptimizationPass::Optimize( tensorflow::grappler::GraphProperties static_graph_properties(item); TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); tensorflow::tensorrt::convert::ConversionParams cp; + + std::vector<string> nodes_to_preserve; + for (const auto& n : item.NodesToPreserve()) { + auto tokens = str_util::Split(n, ":"); + nodes_to_preserve.push_back(tokens.at(0)); + } cp.input_graph_def = &item.graph; - cp.output_names = &item.fetch; + cp.output_names = &nodes_to_preserve; cp.max_batch_size = maximum_batch_size_; cp.max_workspace_size_bytes = maximum_workspace_size_; cp.output_graph_def = optimized_graph; |