diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-16 16:40:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-16 16:40:19 -0700 |
commit | 4310aa0d37c40a1841a321de8dcbb9e87f6ca2b2 (patch) | |
tree | dcead328e97cb84ea73fa8a37720f6f17b199f34 /tensorflow/contrib/tensorrt/convert | |
parent | eedee8236e7693f921723ad942baef7b61b3ceda (diff) | |
parent | 4a24f07a2c4d1f6bd9df5b7432506d1742e81da2 (diff) |
Merge pull request #20851 from samikama:KeepInputs
PiperOrigin-RevId: 204828682
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert')
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index 5bb0ffc797..044c736c03 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -236,7 +237,18 @@ tensorflow::Status TRTOptimizationPass::Optimize( std::vector<string> nodes_to_preserve; for (const auto& n : item.NodesToPreserve()) { auto tokens = str_util::Split(n, ":"); - nodes_to_preserve.push_back(tokens.at(0)); + string s = tokens.at(0); + for (int i = 1; i < tokens.size() - 1; ++i) { + StrAppend(&s, ":", tokens.at(i)); + } + int dumm_port = -1; + // If the last token is not an integer, it must be part of the name. + // Otherwise it is port number. + if (tokens.size() > 1 && + !strings::safe_strto32(tokens.back(), &dumm_port)) { + StrAppend(&s, ":", tokens.back()); + } + nodes_to_preserve.push_back(s); } cp.input_graph_def = &item.graph; cp.output_names = &nodes_to_preserve; |