aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-16 09:47:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-16 09:47:26 -0700
commit97ae13e08d5fffa21ea52b016249ef0809005d49 (patch)
tree11bdea347a9e2a4751e477eb41c1941e7aec2602 /tensorflow
parent74fce066580ca286b2c776a64ab624f12a473b28 (diff)
parent511ce2e7eb8f220c443f09382d09d14f0758e8ba (diff)
Merge pull request #20794 from samikama:KeepInputs
PiperOrigin-RevId: 204756546
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc8
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
index ec9dbfa13b..5bb0ffc797 100644
--- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
@@ -232,8 +232,14 @@ tensorflow::Status TRTOptimizationPass::Optimize(
tensorflow::grappler::GraphProperties static_graph_properties(item);
TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));
tensorflow::tensorrt::convert::ConversionParams cp;
+
+ std::vector<string> nodes_to_preserve;
+ for (const auto& n : item.NodesToPreserve()) {
+ auto tokens = str_util::Split(n, ":");
+ nodes_to_preserve.push_back(tokens.at(0));
+ }
cp.input_graph_def = &item.graph;
- cp.output_names = &item.fetch;
+ cp.output_names = &nodes_to_preserve;
cp.max_batch_size = maximum_batch_size_;
cp.max_workspace_size_bytes = maximum_workspace_size_;
cp.output_graph_def = optimized_graph;