diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc')
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index ec9dbfa13b..044c736c03 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -232,8 +233,25 @@ tensorflow::Status TRTOptimizationPass::Optimize( tensorflow::grappler::GraphProperties static_graph_properties(item); TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); tensorflow::tensorrt::convert::ConversionParams cp; + + std::vector<string> nodes_to_preserve; + for (const auto& n : item.NodesToPreserve()) { + auto tokens = str_util::Split(n, ":"); + string s = tokens.at(0); + for (int i = 1; i < tokens.size() - 1; ++i) { + StrAppend(&s, ":", tokens.at(i)); + } + int dumm_port = -1; + // If the last token is not an integer, it must be part of the name. + // Otherwise it is port number. + if (tokens.size() > 1 && + !strings::safe_strto32(tokens.back(), &dumm_port)) { + StrAppend(&s, ":", tokens.back()); + } + nodes_to_preserve.push_back(s); + } cp.input_graph_def = &item.graph; - cp.output_names = &item.fetch; + cp.output_names = &nodes_to_preserve; cp.max_batch_size = maximum_batch_size_; cp.max_workspace_size_bytes = maximum_workspace_size_; cp.output_graph_def = optimized_graph; |