aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc')
-rw-r--r--tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc20
1 files changed, 19 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
index ec9dbfa13b..044c736c03 100644
--- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -232,8 +233,25 @@ tensorflow::Status TRTOptimizationPass::Optimize(
tensorflow::grappler::GraphProperties static_graph_properties(item);
TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));
tensorflow::tensorrt::convert::ConversionParams cp;
+
+ std::vector<string> nodes_to_preserve;
+ for (const auto& n : item.NodesToPreserve()) {
+ auto tokens = str_util::Split(n, ":");
+ string s = tokens.at(0);
+ for (int i = 1; i < tokens.size() - 1; ++i) {
+ StrAppend(&s, ":", tokens.at(i));
+ }
+ int dumm_port = -1;
+ // If the last token is not an integer, it must be part of the name.
+ // Otherwise it is port number.
+ if (tokens.size() > 1 &&
+ !strings::safe_strto32(tokens.back(), &dumm_port)) {
+ StrAppend(&s, ":", tokens.back());
+ }
+ nodes_to_preserve.push_back(s);
+ }
cp.input_graph_def = &item.graph;
- cp.output_names = &item.fetch;
+ cp.output_names = &nodes_to_preserve;
cp.max_batch_size = maximum_batch_size_;
cp.max_workspace_size_bytes = maximum_workspace_size_;
cp.output_graph_def = optimized_graph;