aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/convert
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-16 16:40:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-16 16:40:19 -0700
commit4310aa0d37c40a1841a321de8dcbb9e87f6ca2b2 (patch)
treedcead328e97cb84ea73fa8a37720f6f17b199f34 /tensorflow/contrib/tensorrt/convert
parenteedee8236e7693f921723ad942baef7b61b3ceda (diff)
parent4a24f07a2c4d1f6bd9df5b7432506d1742e81da2 (diff)
Merge pull request #20851 from samikama:KeepInputs
PiperOrigin-RevId: 204828682
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert')
-rw-r--r--tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc14
1 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
index 5bb0ffc797..044c736c03 100644
--- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -236,7 +237,18 @@ tensorflow::Status TRTOptimizationPass::Optimize(
std::vector<string> nodes_to_preserve;
for (const auto& n : item.NodesToPreserve()) {
auto tokens = str_util::Split(n, ":");
- nodes_to_preserve.push_back(tokens.at(0));
+ string s = tokens.at(0);
+ for (int i = 1; i < tokens.size() - 1; ++i) {
+ StrAppend(&s, ":", tokens.at(i));
+ }
+ int dumm_port = -1;
+ // If the last token is not an integer, it must be part of the name.
+ // Otherwise it is port number.
+ if (tokens.size() > 1 &&
+ !strings::safe_strto32(tokens.back(), &dumm_port)) {
+ StrAppend(&s, ":", tokens.back());
+ }
+ nodes_to_preserve.push_back(s);
}
cp.input_graph_def = &item.graph;
cp.output_names = &nodes_to_preserve;